[WebNN EP] Remove activation fusion (#20635)
The WebNN spec has removed the activation option for conv and
batchNormalization, so we no longer need the additional activation fusion
in the WebNN EP.

[edit by fdwr] Note that fusion is handled in the browser now, which knows
more about the backend platform version and can more safely decide which
fusions are possible (e.g., for the DirectML backend, whether softmax and
gelu can fuse successfully with their base operator).
Honry authored May 15, 2024
1 parent d1e66f0 commit f5bfbd6
Showing 6 changed files with 30 additions and 175 deletions.
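For illustration, a minimal sketch of the pattern this commit settles on (not code from this diff; builder, input, and filter are hypothetical emscripten::val handles): the EP now emits a convolution and its trailing activation as two separate WebNN ops, and the browser's WebNN implementation decides internally whether to fuse them.

#include <emscripten/val.h>

// Sketch only: conv2d and relu are built as two ordinary WebNN ops; any
// fusion happens inside the browser's WebNN backend, not in the EP.
emscripten::val BuildConvRelu(emscripten::val builder,
                              emscripten::val input,
                              emscripten::val filter) {
  emscripten::val options = emscripten::val::object();
  // The WebNN spec no longer defines an "activation" option on conv2d,
  // so the EP cannot (and need not) request the fusion itself.
  emscripten::val conv_out =
      builder.call<emscripten::val>("conv2d", input, filter, options);
  return builder.call<emscripten::val>("relu", conv_out);
}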
@@ -32,40 +32,35 @@ Status ActivationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
   emscripten::val input = model_builder.GetOperand(node.InputDefs()[0]->Name());
   emscripten::val output = emscripten::val::object();
 
-  if (Contains(model_builder.GetFusedActivations(), node.InputDefs()[0]->Name())) {
-    LOGS_DEFAULT(VERBOSE) << op_type << " Node [" << node.Name() << "] fused";
-    output = input;
+  NodeAttrHelper helper(node);
+  emscripten::val options = emscripten::val::object();
+  if (op_type == "Elu") {
+    options.set("alpha", helper.Get("alpha", 1.0f));
+    output = model_builder.GetBuilder().call<emscripten::val>("elu", input, options);
+  } else if (op_type == "Gelu") {
+    output = model_builder.GetBuilder().call<emscripten::val>("gelu", input, options);
+  } else if (op_type == "HardSigmoid") {
+    options.set("alpha", helper.Get("alpha", 0.2f));
+    options.set("beta", helper.Get("beta", 0.5f));
+    output = model_builder.GetBuilder().call<emscripten::val>("hardSigmoid", input, options);
+  } else if (op_type == "HardSwish") {
+    output = model_builder.GetBuilder().call<emscripten::val>("hardSwish", input);
+  } else if (op_type == "LeakyRelu") {
+    options.set("alpha", helper.Get("alpha", 0.0f));
+    output = model_builder.GetBuilder().call<emscripten::val>("leakyRelu", input, options);
+  } else if (op_type == "Relu") {
+    output = model_builder.GetBuilder().call<emscripten::val>("relu", input);
+  } else if (op_type == "Sigmoid") {
+    output = model_builder.GetBuilder().call<emscripten::val>("sigmoid", input);
+  } else if (op_type == "Softplus") {
+    output = model_builder.GetBuilder().call<emscripten::val>("softplus", input);
+  } else if (op_type == "Softsign") {
+    output = model_builder.GetBuilder().call<emscripten::val>("softsign", input);
+  } else if (op_type == "Tanh") {
+    output = model_builder.GetBuilder().call<emscripten::val>("tanh", input);
   } else {
-    NodeAttrHelper helper(node);
-    emscripten::val options = emscripten::val::object();
-    if (op_type == "Elu") {
-      options.set("alpha", helper.Get("alpha", 1.0f));
-      output = model_builder.GetBuilder().call<emscripten::val>("elu", input, options);
-    } else if (op_type == "Gelu") {
-      output = model_builder.GetBuilder().call<emscripten::val>("gelu", input, options);
-    } else if (op_type == "HardSigmoid") {
-      options.set("alpha", helper.Get("alpha", 0.2f));
-      options.set("beta", helper.Get("beta", 0.5f));
-      output = model_builder.GetBuilder().call<emscripten::val>("hardSigmoid", input, options);
-    } else if (op_type == "HardSwish") {
-      output = model_builder.GetBuilder().call<emscripten::val>("hardSwish", input);
-    } else if (op_type == "LeakyRelu") {
-      options.set("alpha", helper.Get("alpha", 0.0f));
-      output = model_builder.GetBuilder().call<emscripten::val>("leakyRelu", input, options);
-    } else if (op_type == "Relu") {
-      output = model_builder.GetBuilder().call<emscripten::val>("relu", input);
-    } else if (op_type == "Sigmoid") {
-      output = model_builder.GetBuilder().call<emscripten::val>("sigmoid", input);
-    } else if (op_type == "Softplus") {
-      output = model_builder.GetBuilder().call<emscripten::val>("softplus", input);
-    } else if (op_type == "Softsign") {
-      output = model_builder.GetBuilder().call<emscripten::val>("softsign", input);
-    } else if (op_type == "Tanh") {
-      output = model_builder.GetBuilder().call<emscripten::val>("tanh", input);
-    } else {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-                             "ActivationOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type);
-    }
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                           "ActivationOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type);
   }
 
   model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output));
@@ -52,13 +52,7 @@ Status ClipOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
   options.set("minValue", minValue);
   options.set("maxValue", maxValue);
   emscripten::val input = model_builder.GetOperand(input_name);
-  emscripten::val output = emscripten::val::object();
-  if (Contains(model_builder.GetFusedActivations(), input_name)) {
-    LOGS_DEFAULT(VERBOSE) << "Clip Node [" << node.Name() << "] fused";
-    output = input;
-  } else {
-    output = model_builder.GetBuilder().call<emscripten::val>("clamp", input, options);
-  }
+  emscripten::val output = model_builder.GetBuilder().call<emscripten::val>("clamp", input, options);
 
   model_builder.AddOperand(output_name, std::move(output));
   return Status::OK();
@@ -138,11 +138,6 @@ common::Status SetConvBaseOptions(ModelBuilder& model_builder,
     options.set("bias", model_builder.GetOperand(input_defs[2]->Name()));
   }
 
-  emscripten::val activation = model_builder.FindActivation(node, *node.OutputDefs()[0]);
-  if (emscripten::val::null() != activation) {
-    options.set("activation", activation);
-  }
-
   return Status::OK();
 }
 
@@ -79,10 +79,7 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder
     if (model_builder.GetPreferredLayout() == DataLayout::NHWC) {
       options.set("axis", rank - 1);
     }
-    emscripten::val activation = model_builder.FindActivation(node, *node.OutputDefs()[0]);
-    if (emscripten::val::null() != activation) {
-      options.set("activation", activation);
-    }
 
     output = model_builder.GetBuilder().call<emscripten::val>("batchNormalization", input, mean, variance, options);
   } else if (op_type == "LayerNormalization") {
     int64_t axis = helper.Get("axis", -1);
112 changes: 0 additions & 112 deletions onnxruntime/core/providers/webnn/builders/model_builder.cc
@@ -31,7 +31,6 @@ ModelBuilder::ModelBuilder(const GraphViewer& graph_viewer, const logging::Logge
 
 Status ModelBuilder::Initialize() {
   PreprocessInitializers();
-  PreprocessActivations();
   ORT_RETURN_IF_ERROR(RegisterInitializers());
   ORT_RETURN_IF_ERROR(RegisterModelInputs());
   ORT_RETURN_IF_ERROR(AddOperations());
@@ -78,79 +77,6 @@ void ModelBuilder::PreprocessInitializers() {
   }
 }
 
-void ModelBuilder::PreprocessActivations() {
-  const auto& node_indices = graph_viewer_.GetNodesInTopologicalOrder();
-
-  if (wnn_device_type_ == WebnnDeviceType::CPU) {
-    // WebNN CPU currently only supports "Relu" and "Clip" fusion.
-    supported_activation_nodes_ = {"Clip", "Relu"};
-  } else {
-    supported_activation_nodes_ = {
-        // Temporarily disable clamp fusion for WebNN GPU as which is not supported yet.
-        // "Clip",
-        "Elu",
-        "Gelu",
-        "HardSigmoid",
-        "HardSwish",
-        "Relu",
-        "LeakyRelu",
-        "Sigmoid",
-        "Softplus",
-        "Softsign",
-        "Tanh",
-    };
-  }
-
-  for (size_t i = 0; i < node_indices.size(); i++) {
-    const auto* node(graph_viewer_.GetNode(node_indices[i]));
-    const auto& op_type(node->OpType());
-
-    // Ignore unsupported activation nodes.
-    if (!Contains(supported_activation_nodes_, op_type)) {
-      continue;
-    }
-
-    if (op_type == "Clip") {
-      float minValue, maxValue;
-      GetClipMinMax(GetInitializerTensors(), *node, minValue, maxValue, logger_);
-      emscripten::val options = emscripten::val::object();
-      options.set("minValue", minValue);
-      options.set("maxValue", maxValue);
-      activation_nodes_.emplace(node->Index(), wnn_builder_.call<emscripten::val>("clamp", options));
-    } else if (op_type == "Elu") {
-      NodeAttrHelper helper(*node);
-      emscripten::val options = emscripten::val::object();
-      options.set("alpha", helper.Get("alpha", 1.0f));
-      activation_nodes_.emplace(node->Index(), wnn_builder_.call<emscripten::val>("elu", options));
-    } else if (op_type == "Gelu") {
-      activation_nodes_.emplace(node->Index(), wnn_builder_.call<emscripten::val>("gelu"));
-    } else if (op_type == "HardSigmoid") {
-      NodeAttrHelper helper(*node);
-      emscripten::val options = emscripten::val::object();
-      options.set("alpha", helper.Get("alpha", 0.2f));
-      options.set("beta", helper.Get("beta", 0.5f));
-      activation_nodes_.emplace(node->Index(), wnn_builder_.call<emscripten::val>("hardSigmoid", options));
-    } else if (op_type == "HardSwish") {
-      activation_nodes_.emplace(node->Index(), wnn_builder_.call<emscripten::val>("hardSwish"));
-    } else if (op_type == "Relu") {
-      activation_nodes_.emplace(node->Index(), wnn_builder_.call<emscripten::val>("relu"));
-    } else if (op_type == "LeakyRelu") {
-      NodeAttrHelper helper(*node);
-      emscripten::val options = emscripten::val::object();
-      options.set("alpha", helper.Get("alpha", 0.0f));
-      activation_nodes_.emplace(node->Index(), wnn_builder_.call<emscripten::val>("leakyRelu", options));
-    } else if (op_type == "Sigmoid") {
-      activation_nodes_.emplace(node->Index(), wnn_builder_.call<emscripten::val>("sigmoid"));
-    } else if (op_type == "Softplus") {
-      activation_nodes_.emplace(node->Index(), wnn_builder_.call<emscripten::val>("softplus"));
-    } else if (op_type == "Softsign") {
-      activation_nodes_.emplace(node->Index(), wnn_builder_.call<emscripten::val>("softsign"));
-    } else if (op_type == "Tanh") {
-      activation_nodes_.emplace(node->Index(), wnn_builder_.call<emscripten::val>("tanh"));
-    }
-  }
-}
-
 Status ModelBuilder::RegisterInitializers() {
   for (const auto& pair : GetInitializerTensors()) {
     const auto& tensor = *pair.second;
@@ -421,44 +347,6 @@ Status ModelBuilder::Compile(std::unique_ptr<Model>& model) {
   return Status::OK();
 }
 
-emscripten::val ModelBuilder::FindActivation(const Node& node, const NodeArg& output) {
-  emscripten::val fused_op = emscripten::val::null();
-  for (auto it = node.OutputEdgesBegin(), end = node.OutputEdgesEnd(); it != end; ++it) {
-    const auto& dst_node = it->GetNode();
-    const auto* dst_input = dst_node.InputDefs()[it->GetDstArgIndex()];
-    if (!Contains(supported_activation_nodes_, dst_node.OpType())) {
-      return emscripten::val::null();
-    }
-    if (Contains(activation_nodes_, dst_node.Index())) {
-      if (&output == dst_input) {
-        fused_op = activation_nodes_.at(dst_node.Index());
-      }
-    } else {
-      // If there is any other non-relu node using the output
-      // will add relu separately.
-      if (&output == dst_input) {
-        return emscripten::val::null();
-      }
-    }
-  }
-
-  // If output is a graph output, will add relu separately.
-  if (fused_op != emscripten::val::null()) {
-    for (const auto* graph_output : graph_viewer_.GetOutputs()) {
-      if (&output == graph_output) {
-        return emscripten::val::null();
-      }
-    }
-
-    LOGS_DEFAULT(VERBOSE) << "Node [" << node.Name() << "] type [" << node.OpType()
-                          << "], fused the output [" << output.Name() << "]";
-
-    fused_activations_.insert(output.Name());
-  }
-
-  return fused_op;
-}
-
 void ModelBuilder::AddScalarOutput(const std::string& output_name) {
   scalar_outputs_.insert(output_name);
 }
14 changes: 0 additions & 14 deletions onnxruntime/core/providers/webnn/builders/model_builder.h
@@ -44,11 +44,6 @@ class ModelBuilder {
   Status AddOperandFromPersistMemoryBuffer(
       const std::string& name, const void* buffer,
      const size_t size, const std::vector<uint32_t> shape, const int32_t data_type);
-  // Find if an output has a fuseable activation (e.g., Relu).
-  emscripten::val FindActivation(const Node& node, const NodeArg& output);
-
-  const InlinedHashSet<std::string>&
-  GetFusedActivations() const { return fused_activations_; }
 
   DataLayout GetPreferredLayout() const { return preferred_layout_; }
 
@@ -82,22 +77,13 @@ class ModelBuilder {
   InlinedHashSet<std::string> skipped_initializers_;
   InlinedHashSet<std::string> skipped_inputs_;
 
-  InlinedHashSet<std::string> fused_activations_;
-
-  InlinedHashSet<std::string> supported_activation_nodes_;
-
   uint32_t name_token_{0};
   InlinedHashSet<std::string> unique_names_;
 
-  // All activation nodes (e.g., Relu) as a map <NodeIndex, FusionOperator>.
-  InlinedHashMap<NodeIndex, emscripten::val> activation_nodes_;
-
   // Convert the onnx model to WebNN operands
   Status Initialize() ORT_MUST_USE_RESULT;
 
   void PreprocessInitializers();
-  // Preprocess all the activation nodes (e.g., Relu) for easy query later.
-  void PreprocessActivations();
 
   // Copy and process all the initializers to WebNN constants.
   Status RegisterInitializers() ORT_MUST_USE_RESULT;