[WebNN EP] Use opSupportLimits to dynamically check data type support #22025

Merged: 5 commits, Sep 14, 2024. Diff below shows changes from 4 of the 5 commits.
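For context: WebNN's MLContext.opSupportLimits() returns a per-operator record in which each input/output operand lists the data types the backend actually supports. A minimal sketch of reading it from C++ via Embind, assuming an Emscripten build and an MLContext stored on globalThis under the illustrative name wnnContext (not from this PR):

#include <emscripten/val.h>
#include <iostream>

int main() {
  // Assumption: a WebNN MLContext was created in JS and stored as
  // globalThis.wnnContext (illustrative name, not from this PR).
  emscripten::val context = emscripten::val::global("wnnContext");
  emscripten::val limits = context.call<emscripten::val>("opSupportLimits");

  // Per the WebNN spec, e.g. limits["relu"]["input"]["dataTypes"] is a
  // JS array such as ["float32", "float16", ...].
  emscripten::val relu_types = limits["relu"]["input"]["dataTypes"];
  bool fp16 = relu_types.call<emscripten::val>("includes",
                                               emscripten::val("float16")).as<bool>();
  std::cout << "relu float16 input supported: " << std::boolalpha << fp16 << std::endl;
  return 0;
}

The diff below consumes exactly this structure through wnn_limits[webnn_op_type][name]["dataTypes"].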
59 changes: 51 additions & 8 deletions onnxruntime/core/providers/webnn/builders/helper.cc
@@ -45,12 +45,12 @@
return true;
}

bool IsNodeSupported(const Node& node, const GraphViewer& graph_viewer,
const WebnnDeviceType device_type, const logging::Logger& logger) {
bool IsNodeSupported(const Node& node, const GraphViewer& graph_viewer, const WebnnDeviceType device_type,
const emscripten::val& wnn_limits, const logging::Logger& logger) {
const auto& op_builders = GetOpBuilders();
if (Contains(op_builders, node.OpType())) {
const auto* op_builder = op_builders.at(node.OpType());
return op_builder->IsOpSupported(graph_viewer.GetAllInitializedTensors(), node, device_type, logger);
return op_builder->IsOpSupported(graph_viewer.GetAllInitializedTensors(), node, device_type, wnn_limits, logger);
} else {
return false;
}
@@ -86,6 +86,7 @@
std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_viewer,
const emscripten::val& wnn_builder,
const WebnnDeviceType device_type,
const emscripten::val& wnn_limits,
const logging::Logger& logger) {
std::vector<std::vector<size_t>> supported_node_groups;

@@ -105,7 +106,7 @@
// Firstly check if platform supports the WebNN op.
if (CheckSingleOp(node->OpType(), wnn_builder, device_type)) {
LOGS(logger, VERBOSE) << "Operator type: [" << node->OpType() << "] is supported by browser";
supported = IsNodeSupported(*node, graph_viewer, device_type, logger);
supported = IsNodeSupported(*node, graph_viewer, device_type, wnn_limits, logger);
}

LOGS(logger, VERBOSE) << "Operator type: [" << node->OpType()
@@ -130,10 +131,52 @@
return supported_node_groups;
}

bool IsSupportedDataType(const int32_t data_type,
const std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType>& supported_data_types) {
return std::find(supported_data_types.begin(), supported_data_types.end(), data_type) !=
supported_data_types.end();
bool AreInputDataTypesSame(const std::string& op_type,
const std::vector<int32_t>& input_types,
const logging::Logger& logger) {
for (size_t i = 1; i < input_types.size(); i++) {
if (input_types[0] != input_types[i]) {
LOGS(logger, VERBOSE) << "[" << op_type
<< "] Input data types should be the same.";
return false;
}
}
return true;
}

bool IsSupportedDataType(const int32_t onnx_data_type, const emscripten::val& webnn_supported_data_types) {
auto it = onnx_to_webnn_data_type_map.find(static_cast<ONNX_NAMESPACE::TensorProto_DataType>(onnx_data_type));
if (it == onnx_to_webnn_data_type_map.end())
return false;

std::string webnn_data_type = it->second;

// Check if WebNN supports the data type.
emscripten::val is_supported = webnn_supported_data_types.call<emscripten::val>("includes",
emscripten::val(webnn_data_type));
return is_supported.as<bool>();
}

// Check if the input or output data type of ONNX node is supported by the WebNN operator.
bool IsDataTypeSupportedByOp(const std::string& onnx_op_type,
const int32_t onnx_data_type,
const emscripten::val& wnn_limits,
const std::string& webnn_input_output_name,
const std::string& onnx_input_output_name,
const logging::Logger& logger) {
std::string webnn_op_type;

[cpplint warning, helper.cc:167] Add #include <string> for string [build/include_what_you_use]
if (!GetWebNNOpType(onnx_op_type, webnn_op_type))
return false;

if (!IsSupportedDataType(onnx_data_type, wnn_limits[webnn_op_type][webnn_input_output_name]["dataTypes"])) {
LOGS(logger, VERBOSE) << "[" << onnx_op_type
<< "] " << onnx_input_output_name
<< " type: [" << onnx_data_type
<< "] is not supported for now";
return false;
}

return true;
}

bool GetBidirectionalBroadcastShape(std::vector<int64_t>& shape_a,
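The rewritten IsSupportedDataType above is just an Array.prototype.includes call on the JS side. A self-contained sketch of exercising it against a hand-built array (illustrative only; in the EP the array comes from wnn_limits[op]["input"]["dataTypes"]):

// Illustrative check of IsSupportedDataType; not part of the PR.
#include <emscripten/val.h>
#include "core/providers/webnn/builders/helper.h"

namespace onnxruntime {
namespace webnn {

bool DemoIsSupported() {
  emscripten::val data_types = emscripten::val::array();
  data_types.call<void>("push", emscripten::val("float32"));
  data_types.call<void>("push", emscripten::val("float16"));
  // ONNX float32 maps to WebNN "float32", so this returns true.
  return IsSupportedDataType(ONNX_NAMESPACE::TensorProto_DataType_FLOAT, data_types);
}

}  // namespace webnn
}  // namespace onnxruntime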
43 changes: 31 additions & 12 deletions onnxruntime/core/providers/webnn/builders/helper.h
@@ -148,6 +148,7 @@
std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_viewer,
const emscripten::val& wnn_builder,
const WebnnDeviceType device_type,
const emscripten::val& wnn_limits,
const logging::Logger& logger);
static const InlinedHashMap<std::string, std::string> op_map = {
{"Abs", "abs"},
@@ -250,20 +251,38 @@
return true;
}

static const std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> webnn_supported_data_types = {
ONNX_NAMESPACE::TensorProto_DataType_BOOL,
ONNX_NAMESPACE::TensorProto_DataType_INT8,
ONNX_NAMESPACE::TensorProto_DataType_UINT8,
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
ONNX_NAMESPACE::TensorProto_DataType_INT32,
ONNX_NAMESPACE::TensorProto_DataType_INT64,
ONNX_NAMESPACE::TensorProto_DataType_UINT32,
ONNX_NAMESPACE::TensorProto_DataType_UINT64,
inline bool GetWebNNOpType(const std::string& op_type, std::string& webnn_op_type) {
auto it = op_map.find(op_type);
// Returns false if the op_type is not listed in the op_map.
if (it == op_map.end()) {
return false;
}
webnn_op_type = it->second;
return true;
}

static const InlinedHashMap<ONNX_NAMESPACE::TensorProto_DataType, std::string> onnx_to_webnn_data_type_map = {
{ONNX_NAMESPACE::TensorProto_DataType_BOOL, "uint8"},
{ONNX_NAMESPACE::TensorProto_DataType_INT8, "int8"},
{ONNX_NAMESPACE::TensorProto_DataType_UINT8, "uint8"},
{ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, "float16"},
{ONNX_NAMESPACE::TensorProto_DataType_FLOAT, "float32"},
{ONNX_NAMESPACE::TensorProto_DataType_INT32, "int32"},
{ONNX_NAMESPACE::TensorProto_DataType_INT64, "int64"},
{ONNX_NAMESPACE::TensorProto_DataType_UINT32, "uint32"},
{ONNX_NAMESPACE::TensorProto_DataType_UINT64, "uint64"},
};

bool IsSupportedDataType(const int32_t data_type,
const std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType>& supported_data_types);
bool AreInputDataTypesSame(const std::string& op_type,
const std::vector<int32_t>& input_types,
const logging::Logger& logger);
bool IsSupportedDataType(const int32_t onnx_data_type, const emscripten::val& webnn_supported_data_types);
bool IsDataTypeSupportedByOp(const std::string& onnx_op_type,
const int32_t onnx_data_type,
const emscripten::val& wnn_limits,
const std::string& webnn_input_output_name,
const std::string& onnx_input_output_name,

const logging::Logger& logger);
[cpplint warning, helper.h:284] Add #include <string> for string [build/include_what_you_use]

bool GetBidirectionalBroadcastShape(std::vector<int64_t>& shape_a,
std::vector<int64_t>& shape_b,
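Worth noting in the map above: ONNX BOOL maps to WebNN "uint8" because WebNN has no dedicated boolean operand type. A small sketch of the two lookups these declarations enable (hypothetical function, not part of the PR):

// Illustrative lookups only.
#include <string>
#include "core/providers/webnn/builders/helper.h"

namespace onnxruntime {
namespace webnn {

inline void DemoTypeLookups() {
  std::string webnn_op_type;
  if (GetWebNNOpType("Relu", webnn_op_type)) {
    // webnn_op_type is now "relu" (from op_map).
  }
  auto it = onnx_to_webnn_data_type_map.find(ONNX_NAMESPACE::TensorProto_DataType_BOOL);
  if (it != onnx_to_webnn_data_type_map.end()) {
    // it->second is "uint8": ONNX bool tensors are expressed as WebNN uint8.
  }
}

}  // namespace webnn
}  // namespace onnxruntime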
onnxruntime/core/providers/webnn/builders/impl/activation_op_builder.cc
@@ -21,8 +21,6 @@ class ActivationOpBuilder : public BaseOpBuilder {
// Operator support related.
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
WebnnDeviceType device_type, const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
const logging::Logger& logger) const override;
};

// Add operator related.
@@ -94,44 +92,6 @@ bool ActivationOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initi
return true;
}

bool ActivationOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
const logging::Logger& logger) const {
const auto& input = *node.InputDefs()[0];
const auto& op_type = node.OpType();
int32_t input_type;
if (!GetType(input, input_type, logger))
return false;

std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types;
// WebNN relu op supports float32, float16, int32, int8 input data types.
if (op_type == "Relu") {
supported_data_types = {
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
ONNX_NAMESPACE::TensorProto_DataType_INT32,
ONNX_NAMESPACE::TensorProto_DataType_INT8,
};
// WebNN CPU backend does not support int32 data type for relu.
if (device_type == WebnnDeviceType::CPU) {
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_INT32);
}
} else { // Others only support float32 and float16.
supported_data_types = {
ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
};
}

if (!IsSupportedDataType(input_type, supported_data_types)) {
LOGS(logger, VERBOSE) << "[" << op_type
<< "] Input type: [" << input_type
<< "] is not supported for now";
return false;
}

return true;
}

void CreateActivationOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
if (op_registrations.op_builder_map.count(op_type) > 0)
return;
onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc
@@ -22,8 +22,6 @@ class ArgMaxMinOpBuilder : public BaseOpBuilder {
// Operator support related.
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
WebnnDeviceType device_type, const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
const logging::Logger& logger) const override;
};

// Add operator related.
@@ -77,31 +75,6 @@ bool ArgMaxMinOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initia
return true;
}

bool ArgMaxMinOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
const logging::Logger& logger) const {
const auto& input = *node.InputDefs()[0];
const auto& op_type = node.OpType();
int32_t input_type;
if (!GetType(input, input_type, logger))
return false;

std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types = webnn_supported_data_types;
// WebNN CPU backend doesn't support int64, uint64 input data types for argMax and argMin.
if (device_type == WebnnDeviceType::CPU) {
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_INT64);
supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64);
}

if (!IsSupportedDataType(input_type, supported_data_types)) {
LOGS(logger, VERBOSE) << "[" << op_type
<< "] Input type: [" << input_type
<< "] is not supported for now";
return false;
}

return true;
}

void CreateArgMaxMinOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
if (op_registrations.op_builder_map.count(op_type) > 0)
return;
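These deletions (here and in the activation builder above) are the heart of the PR: hardcoded per-op, per-backend type sets, such as dropping int64/uint64 for argMax/argMin on the WebNN CPU backend, are replaced by whatever the backend reports through opSupportLimits(). A hedged sketch of the dynamic check that subsumes the old Relu special case (the operand name "input" matches the default used in base_op_builder.cc below):

// Sketch: what the old hardcoded Relu list reduces to under the new scheme.
#include "core/providers/webnn/builders/helper.h"

namespace onnxruntime {
namespace webnn {

bool IsReluInputTypeSupported(int32_t onnx_input_type,
                              const emscripten::val& wnn_limits,
                              const logging::Logger& logger) {
  // If the CPU backend cannot run relu on int32, its opSupportLimits()
  // simply omits "int32" from wnn_limits["relu"]["input"]["dataTypes"].
  return IsDataTypeSupportedByOp("Relu", onnx_input_type, wnn_limits,
                                 "input", "Input", logger);
}

}  // namespace webnn
}  // namespace onnxruntime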
52 changes: 25 additions & 27 deletions onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc
@@ -38,9 +38,9 @@ bool HasExternalInitializer(const InitializedTensorSet& initializers, const Node
Status BaseOpBuilder::AddToModelBuilder(ModelBuilder& model_builder, const Node& node,
const logging::Logger& logger) const {
ORT_RETURN_IF_NOT(
IsOpSupported(model_builder.GetInitializerTensors(), node, model_builder.GetWebnnDeviceType(), logger),
"Unsupported operator ",
node.OpType());
IsOpSupported(model_builder.GetInitializerTensors(), node, model_builder.GetWebnnDeviceType(),
model_builder.GetOpSupportLimits(), logger),
"Unsupported operator ", node.OpType());
ORT_RETURN_IF_ERROR(AddToModelBuilderImpl(model_builder, node, logger));
LOGS(logger, VERBOSE) << "Operator name: [" << node.Name()
<< "] type: [" << node.OpType() << "] was added";
@@ -50,8 +50,12 @@ Status BaseOpBuilder::AddToModelBuilder(ModelBuilder& model_builder, const Node&
// Operator support related.

bool BaseOpBuilder::IsOpSupported(const InitializedTensorSet& initializers, const Node& node,
const WebnnDeviceType device_type, const logging::Logger& logger) const {
if (!HasSupportedInputs(node, device_type, logger))
const WebnnDeviceType device_type, const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
if (!HasSupportedInputs(node, wnn_limits, logger))
return false;

if (!HasSupportedOutputsImpl(node, wnn_limits, logger))
return false;

// We do not support external initializers for now.
@@ -64,7 +68,7 @@ bool BaseOpBuilder::IsOpSupported(const InitializedTensorSet& initializers, cons
return IsOpSupportedImpl(initializers, node, device_type, logger);
}

bool BaseOpBuilder::HasSupportedInputs(const Node& node, const WebnnDeviceType device_type,
bool BaseOpBuilder::HasSupportedInputs(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
const auto node_name = MakeString("Node [", node.Name(), "] type [", node.OpType(), "]");
for (const auto* input : node.InputDefs()) {
@@ -73,39 +77,33 @@ bool BaseOpBuilder::HasSupportedInputs(const Node& node, const WebnnDeviceType d
}
}

// WebNN CPU backend (TFLite) will enable float16 input data type soon,
// temporarily fallback float16 input data type for WebNN CPU.
if (device_type == WebnnDeviceType::CPU) {
const auto& input = *node.InputDefs()[0];

int32_t input_type;
if (!GetType(input, input_type, logger))
return false;
if (input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16)
return false;
}

return HasSupportedInputsImpl(node, device_type, logger);
return HasSupportedInputsImpl(node, wnn_limits, logger);
}

bool BaseOpBuilder::HasSupportedInputsImpl(const Node& node,
const WebnnDeviceType /* device_type */,
const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
// We only check the type of input 0 by default, specific op builder can override this.
const auto& input = *node.InputDefs()[0];

const auto& op_type = node.OpType();
int32_t input_type;
if (!GetType(input, input_type, logger))
return false;

if (!IsSupportedDataType(input_type, webnn_supported_data_types)) {
LOGS(logger, VERBOSE) << "[" << node.OpType()
<< "] Input type: [" << input_type
<< "] is not supported for now";
return IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "Input", logger);
}

bool BaseOpBuilder::HasSupportedOutputsImpl(const Node& node,
const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
// We only check the type of output 0 by default, specific op builder can override this.
const auto& output = *node.OutputDefs()[0];
const auto& op_type = node.OpType();
int32_t output_type;
if (!GetType(output, output_type, logger))
return false;
}

return true;
return IsDataTypeSupportedByOp(op_type, output_type, wnn_limits, "output", "Output", logger);
}

bool BaseOpBuilder::HasSupportedOpSet(const Node& node,
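AddToModelBuilder above pulls the limits from model_builder.GetOpSupportLimits(). A plausible shape for that accessor, as a hedged sketch (the real ModelBuilder change is not shown in this diff; the class and member names here are assumptions):

// Hedged sketch, not the actual ModelBuilder: cache opSupportLimits()
// once per WebNN MLContext and hand out a reference to every check.
#include <emscripten/val.h>

class ModelBuilderSketch {
 public:
  explicit ModelBuilderSketch(const emscripten::val& wnn_context)
      : wnn_limits_(wnn_context.call<emscripten::val>("opSupportLimits")) {}

  const emscripten::val& GetOpSupportLimits() const { return wnn_limits_; }

 private:
  emscripten::val wnn_limits_;
};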
onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h
@@ -28,16 +28,19 @@ class BaseOpBuilder : public IOpBuilder {
// Operator support related.
public:
bool IsOpSupported(const InitializedTensorSet& initializers, const Node& node,
const WebnnDeviceType device_type, const logging::Logger& logger) const override;
const WebnnDeviceType device_type, const emscripten::val& wnn_limits,
const logging::Logger& logger) const override;

protected:
virtual bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& /* node */,
const WebnnDeviceType /* device_type */, const logging::Logger& /* logger */) const {
return true;
}

virtual bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
virtual bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const;
virtual bool HasSupportedOutputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const;

// ONNX Runtime only *guarantees* support for models stamped
// with opset version 7 or above for opset domain 'ai.onnx'.
@@ -50,7 +53,7 @@

private:
bool HasSupportedOpSet(const Node& node, const logging::Logger& logger) const;
bool HasSupportedInputs(const Node& node, const WebnnDeviceType device_type, const logging::Logger& logger) const;
bool HasSupportedInputs(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const;
};

} // namespace webnn
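With the new virtual signatures, a builder that needs more than the default input-0 check can override HasSupportedInputsImpl. A hypothetical helper such an override could delegate to, checking that two inputs agree on type and that the type appears in the op's limits (the WebNN operand name "a" is an assumption for illustration):

// Hypothetical helper; not part of the PR.
#include <vector>
#include "core/providers/webnn/builders/helper.h"

namespace onnxruntime {
namespace webnn {

bool CheckBinaryOpInputs(const Node& node, const emscripten::val& wnn_limits,
                         const logging::Logger& logger) {
  const auto& op_type = node.OpType();
  int32_t type0 = 0;
  int32_t type1 = 0;
  if (!GetType(*node.InputDefs()[0], type0, logger) ||
      !GetType(*node.InputDefs()[1], type1, logger)) {
    return false;
  }
  // Both inputs must share one data type...
  if (!AreInputDataTypesSame(op_type, {type0, type1}, logger)) {
    return false;
  }
  // ...and that type must appear in the op's WebNN input limits.
  return IsDataTypeSupportedByOp(op_type, type0, wnn_limits, "a", "Input", logger);
}

}  // namespace webnn
}  // namespace onnxruntime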