Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WebNN EP] Support Einsum op #19558

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
#include "core/providers/webnn/builders/model_builder.h"
#include "core/providers/webnn/builders/op_builder_factory.h"

#include "base_op_builder.h"

Check warning on line 15 in onnxruntime/core/providers/webnn/builders/impl/einsum_op_builder.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Include the directory when naming header files [build/include_subdir] [4] Raw Output: onnxruntime/core/providers/webnn/builders/impl/einsum_op_builder.cc:15: Include the directory when naming header files [build/include_subdir] [4]

namespace onnxruntime {
namespace webnn {
Expand Down Expand Up @@ -62,7 +62,7 @@
};

bool ParseEquationComponents(const Node& node,
const std::string& equation,
const std::string_view& equation,
fdwr marked this conversation as resolved.
Show resolved Hide resolved
std::vector<uint32_t>& label_indices,
std::vector<Component>& components,
std::vector<uint32_t>& output_dimensions,
Expand All @@ -73,7 +73,7 @@
// Read first to last character in equation, looking for letters, commas, and one arrow.
// The ellipsis is not supported.
std::map<char, uint32_t> label_maps;
std::set<char> repeated_labels;

Check warning on line 76 in onnxruntime/core/providers/webnn/builders/impl/einsum_op_builder.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <set> for set<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/webnn/builders/impl/einsum_op_builder.cc:76: Add #include <set> for set<> [build/include_what_you_use] [4]

num_labels = 0;
Component current_component = {};
Expand Down Expand Up @@ -563,8 +563,8 @@

// transpose output
std::vector<uint32_t> output_labels_sorted(kept_axes.begin(), kept_axes.end());
std::map<uint32_t, uint32_t> mapping;

Check warning on line 566 in onnxruntime/core/providers/webnn/builders/impl/einsum_op_builder.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <map> for map<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/webnn/builders/impl/einsum_op_builder.cc:566: Add #include <map> for map<> [build/include_what_you_use] [4]
std::sort(output_labels_sorted.begin(), output_labels_sorted.end());

Check warning on line 567 in onnxruntime/core/providers/webnn/builders/impl/einsum_op_builder.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <algorithm> for sort [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/webnn/builders/impl/einsum_op_builder.cc:567: Add #include <algorithm> for sort [build/include_what_you_use] [4]

auto equals = [](std::vector<uint32_t> a, gsl::span<const uint32_t> b) {
return std::equal(a.begin(), a.end(), b.begin(), b.end());
Expand Down Expand Up @@ -689,7 +689,7 @@
break;
}

model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output));

Check warning on line 692 in onnxruntime/core/providers/webnn/builders/impl/einsum_op_builder.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <utility> for move [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/webnn/builders/impl/einsum_op_builder.cc:692: Add #include <utility> for move [build/include_what_you_use] [4]
return Status::OK();
}

Expand All @@ -702,7 +702,7 @@
const auto& input_defs = node.InputDefs();

if (input_defs.size() > 2) {
// TODO: Support more than two inputs.

Check warning on line 705 in onnxruntime/core/providers/webnn/builders/impl/einsum_op_builder.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2] Raw Output: onnxruntime/core/providers/webnn/builders/impl/einsum_op_builder.cc:705: Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2]
LOGS(logger, VERBOSE) << "EinSum only supports up to two inputs.";
return false;
}
Expand Down Expand Up @@ -759,7 +759,7 @@
const auto equation = helper.Get("equation", std::string(" "));
std::vector<uint32_t> label_indices;
std::vector<Component> components;
std::vector<uint32_t> output_dimensions;

Check warning on line 762 in onnxruntime/core/providers/webnn/builders/impl/einsum_op_builder.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <vector> for vector<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/webnn/builders/impl/einsum_op_builder.cc:762: Add #include <vector> for vector<> [build/include_what_you_use] [4]
uint32_t num_labels;

if (!ParseEquationComponents(node, equation, label_indices,
Expand All @@ -784,7 +784,7 @@
}
}

void CreateEinsumOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {

Check warning on line 787 in onnxruntime/core/providers/webnn/builders/impl/einsum_op_builder.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <string> for string [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/webnn/builders/impl/einsum_op_builder.cc:787: Add #include <string> for string [build/include_what_you_use] [4]
op_registrations.builders.push_back(std::make_unique<EinsumOpBuilder>());
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());
}
Expand Down
Loading