Skip to content

Commit

Permalink
Support opSupportLimits for Gru
Browse files Browse the repository at this point in the history
  • Loading branch information
Honry committed Sep 11, 2024
1 parent 75319a1 commit 8b896d9
Showing 1 changed file with 3 additions and 16 deletions.
19 changes: 3 additions & 16 deletions onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class GruOpBuilder : public BaseOpBuilder {
private:
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
const WebnnDeviceType /*device_type*/, const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const override;
};

Expand Down Expand Up @@ -185,7 +185,7 @@ bool GruOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, c
return true;
}

bool GruOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
bool GruOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
const auto& op_type = node.OpType();
Expand All @@ -208,20 +208,7 @@ bool GruOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceTyp
return false;
}

std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types;
if (device_type == WebnnDeviceType::CPU) {
// WebNN CPU backend only support float32 input data type.
supported_data_types = {
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
};
} else if (device_type == WebnnDeviceType::GPU) {
supported_data_types = {
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
};
}

if (!IsSupportedDataType(input0_type, supported_data_types)) {
if (!IsSupportedDataType(input0_type, wnn_limits["gru"]["input"]["dataTypes"])) {
LOGS(logger, VERBOSE) << "[" << op_type
<< "] Input type: [" << input0_type
<< "] is not supported for now";
Expand Down

0 comments on commit 8b896d9

Please sign in to comment.