[IREE EP] Fix ElementsAttr iteration error from MLIR. #10

Merged
8 changes: 8 additions & 0 deletions onnxruntime/core/providers/get_execution_providers.cc
@@ -186,6 +186,14 @@ constexpr ProviderInfo kProvidersInPriorityOrder[] =
true,
#else
false,
#endif
},
{
kIreeExecutionProvider,
#ifdef USE_IREE
true,
#else
false,
#endif
},
{kCpuExecutionProvider, true}, // kCpuExecutionProvider is always last
30 changes: 8 additions & 22 deletions onnxruntime/core/providers/iree/compiler/jit_compiler.cc
@@ -5,6 +5,7 @@
// (which require a pre-compilation step).

#include "core/providers/iree/compiler/jit_compiler.h"
#include "core/graph/graph_proto_serializer.h"
#include "core/providers/iree/compiler/torch-mlir-import-onnx/OnnxImporter.h"
#include "mlir-c/BuiltinAttributes.h"

@@ -157,13 +158,12 @@ common::Status CompilerInvocation::ImportSubgraph(const onnxruntime::GraphViewer
opset_import->set_version(it.second);
}

// Unforgivably sharp edge: There is a ToGraphProto() that returns a value and another that returns a reference.
// And they differ by const-ness. We need to make sure we get the reference, obviously, so we assign it explicitly.
const ONNX_NAMESPACE::GraphProto& graph_proto = graph_view.GetGraph().ToGraphProto();
ONNX_NAMESPACE::GraphProto graph_proto;
GraphViewerToProto(graph_view, graph_proto, false, false);
// LOGS(session.logger, INFO) << " full graph: " << graph_proto.DebugString();

// Set up for subgraph import.
torch_mlir_onnx::GraphInfo subgraph_info(model_info, graph_proto);
torch_mlir_onnx::GraphInfo subgraph_info(graph_view, model_info, graph_proto);
if (torch_mlir_onnx::failed(subgraph_info.Initialize())) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, model_info.error_message());
}
@@ -193,24 +193,10 @@ common::Status CompilerInvocation::ImportSubgraph(const onnxruntime::GraphViewer
model_info.error_message(), ConsumeDiagnostics());
}

// Import each node. Note that the importer uses references internally and expects nodes to be located at fixed
// memory locations for the life of iteration. So we materialize them into a fixed vector first. This is because
// the onnxruntime does not keep the serialized proto form sync'd on its own.
auto node_indices = graph_view.GetNodesInTopologicalOrder();
std::vector<ONNX_NAMESPACE::NodeProto> nodes(node_indices.size());
for (size_t i = 0; i < node_indices.size(); ++i) {
graph_view.GetNode(node_indices[i])->ToProto(nodes[i]);
}
for (const auto& node : nodes) {
if (torch_mlir_onnx::failed(imp.ImportNode(node))) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "Failed to import node '", node.name(), "': ",
model_info.error_message(), " (node:\n", node.DebugString(), "\n)", ConsumeDiagnostics());
}
}

// Finalize.
if (torch_mlir_onnx::failed(imp.FinalizeGraph())) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, model_info.error_message(), ConsumeDiagnostics());
if (torch_mlir_onnx::failed(imp.ImportAll())) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "Failed to import nodes",
": ", model_info.error_message(),
ConsumeDiagnostics());
}

// Verify the function at the point of import because we have better diagnostics.
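Side note on the serialization switch above: Graph::ToGraphProto() serializes the whole underlying Graph (the deleted comment warned about its value- vs. reference-returning overloads), while GraphViewerToProto serializes only what the GraphViewer exposes, which is what a fused-subgraph EP needs. A minimal sketch of the new step, assuming the flag names of GraphViewerToProto's signature and a hypothetical SerializeSubgraph helper:

#include "core/graph/graph_proto_serializer.h"
#include "core/graph/graph_viewer.h"

// Sketch: serialize only the nodes/inputs/outputs visible through the viewer.
// The two `false` flags mirror the call in jit_compiler.cc: skip initializers
// (now imported on demand) and skip outer-scope args.
ONNX_NAMESPACE::GraphProto SerializeSubgraph(const onnxruntime::GraphViewer& graph_view) {
  ONNX_NAMESPACE::GraphProto graph_proto;
  onnxruntime::GraphViewerToProto(graph_view, graph_proto,
                                  /*include_initializer=*/false,
                                  /*include_outer_scope_args=*/false);
  return graph_proto;
}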
onnxruntime/core/providers/iree/compiler/torch-mlir-import-onnx/OnnxImporter.cpp
@@ -154,11 +154,11 @@ void ModelInfo::DebugDumpProto() {
fprintf(stderr, "%s\n", debug_string.c_str());
}

Status ModelInfo::Initialize() {
Status ModelInfo::Initialize(const onnxruntime::GraphViewer &gv) {
if (!model_proto_.has_graph()) {
return SetError("ONNX ModelProto has no main graph");
}
main_graph_ = std::make_unique<GraphInfo>(*this, model_proto_.graph());
main_graph_ = std::make_unique<GraphInfo>(gv, *this, model_proto_.graph());
if (failed(main_graph_->Initialize())) {
return failure();
}
@@ -228,33 +228,25 @@ Status GraphInfo::Initialize() {
}

const onnx::TypeProto *GraphInfo::FindTypeProtoForName(std::string_view name) {
// Node outputs don't typically have type information, but shape inference
// will associate them in the value_info. If not there, it may be a
// graph output, which must have type information.
{
auto it = value_info_map_.find(name);
if (it != value_info_map_.end()) {
return &it->second.type();
}
}
{
auto it = output_map_.find(name);
if (it != output_map_.end()) {
return &it->second.type();
}
}

std::string msg = "No type information associated with '";
msg.append(name);
msg.append("'. Run shape inference?");
model_info_.SetError(std::move(msg));
return nullptr;
return graph_viewer_.GetNodeArg(std::string{name})->TypeAsProto();
}

// ---------------------------------------------------------------------------//
// ContextCache
// ---------------------------------------------------------------------------//

// Parses !torch.none to an MlirType (used as the result type of the
// torch.constant.none op created by ImportNoneNode).
MlirType ContextCache::GetNoneType() {
auto t =
mlirTypeParseGet(context_, mlirStringRefCreateFromCString("!torch.none"));
if (mlirTypeIsNull(t)) {
std::string message = "internal error: could not parse !torch.none type: ";
model_info_.SetError(std::move(message));
}
return t;
}

MlirType ContextCache::ConvertTypeProto(const onnx::TypeProto &tp) {
if (tp.has_tensor_type()) {
// Convert Tensor TypeProto.
@@ -392,8 +384,8 @@ ContextCache::ConvertTensorProtoToAttr(const onnx::TensorProto &tp) {
int8_conversion.reserve(tp.int32_data_size());
for (int32_t v : tp.int32_data())
int8_conversion.push_back(v);
return mlirDenseElementsAttrInt8Get(
tensor_type, int8_conversion.size(), int8_conversion.data());
return mlirDenseElementsAttrInt8Get(tensor_type, int8_conversion.size(),
int8_conversion.data());
}
case onnx::TensorProto::DataType::TensorProto_DataType_INT32:
return mlirDenseElementsAttrInt32Get(tensor_type, tp.int32_data_size(),
@@ -511,6 +503,19 @@ NodeImporter::NodeImporter(GraphInfo &graph_info, ContextCache &cc,
/*childLoc=*/{nullptr});
}

// Imports a !torch.none value to stand in for inputs named with the empty
// string (''), which is how ONNX marks an omitted optional input.
void NodeImporter::ImportNoneNode() {
auto it = nv_map_.find("");
if (it != nv_map_.end())
return;

MlirOperation new_op = createMlirOperationAtEnd(
body_block_, "torch.constant.none", default_loc_, cc_.GetNoneType());
MlirValue nne = mlirOperationGetResult(new_op, 0);
nv_map_.emplace("", nne);
}

Status NodeImporter::DefineFunction(std::optional<std::string> name,
MlirOperation *out_function_op) {
const onnx::GraphProto &p = graph_info_.graph_proto();
@@ -529,16 +534,16 @@ Status NodeImporter::DefineFunction(std::optional<std::string> name,
std::vector<MlirType> input_types;
std::vector<MlirLocation> input_locs;
std::vector<MlirType> output_types;
for (auto *input : graph_info_.inputs()) {
MlirType t = cc_.ConvertTypeProto(input->type());
for (auto *input : graph_info_.graph_viewer().GetInputs()) {
MlirType t = cc_.ConvertTypeProto(*input->TypeAsProto());
if (mlirTypeIsNull(t)) {
return failure();
}
input_types.push_back(t);
input_locs.push_back(default_loc_);
}
for (auto *output : graph_info_.outputs()) {
MlirType t = cc_.ConvertTypeProto(output->type());
for (auto output : graph_info_.graph_proto().output()) {
MlirType t = cc_.ConvertTypeProto(output.type());
if (mlirTypeIsNull(t)) {
return failure();
}
@@ -561,8 +566,9 @@ Status NodeImporter::DefineFunction(std::optional<std::string> name,
mlirRegionAppendOwnedBlock(bodyRegion, body_block_);

// Map the block args to names and store for evaluation.
for (int i = 0, e = graph_info_.inputs().size(); i < e; ++i) {
std::string_view name = graph_info_.inputs()[i]->name();
for (int i = 0, e = graph_info_.graph_viewer().GetInputs().size(); i < e;
++i) {
std::string_view name = graph_info_.graph_viewer().GetInputs()[i]->Name();
MlirValue value = mlirBlockGetArgument(body_block_, i);
nv_map_[name] = value;
}
@@ -622,15 +628,19 @@ void NodeImporter::PopulateGraphAttrs(MlirOperation container_op) {
}

Status NodeImporter::ImportAll() {
// TODO: Consider pulling in initializers on demand since there can be so
// much unused crap.
for (auto it : graph_info_.initializer_map()) {
if (failed(ImportInitializer(it.second)))
return failure();
ImportNoneNode();

auto node_indices = graph_info_.graph_viewer().GetNodesInTopologicalOrder();
std::vector<ONNX_NAMESPACE::NodeProto> nodes(node_indices.size());
for (size_t i = 0; i < node_indices.size(); ++i) {
graph_info_.graph_viewer().GetNode(node_indices[i])->ToProto(nodes[i]);
}
for (auto it : graph_info_.graph_proto().node()) {
if (failed(ImportNode(it)))
return failure();

for (const auto &node : nodes) {
if (torch_mlir_onnx::failed(ImportNode(node))) {
return SetError("Failed to import node '" + node.name() +
"': " + "(node:\n" + node.DebugString() + "\n)");
}
}

return FinalizeGraph();
@@ -640,8 +650,8 @@ Status NodeImporter::FinalizeGraph() {
// Lookup the outputs, which should all be in the nv_map if the graph was
// properly formed.
std::vector<MlirValue> output_values;
for (const auto *output : graph_info_.outputs()) {
std::string_view name = output->name();
for (const auto &output : graph_info_.graph_proto().output()) {
std::string_view name = output.name();
auto found_it = nv_map_.find(name);
if (found_it == nv_map_.end()) {
std::string msg = "Non topologically produced ONNX graph output '";
@@ -670,8 +680,11 @@ Status NodeImporter::ImportInitializer(const onnx::TensorProto &initializer) {
return failure();

MlirOperation op = createMlirOperationAtEnd(
body_block_, "torch.vtensor.literal", loc, vtensor_type,
toMlirNamedAttribute("value", value_attr));
body_block_, "torch.operator", loc, vtensor_type,
toMlirNamedAttribute(
"name",
mlirStringAttrGet(context_, toMlirStringRef("onnx.Constant"))),
toMlirNamedAttribute("torch.onnx.value", value_attr));
MlirValue result = mlirOperationGetResult(op, 0);

auto inserted = nv_map_.insert(std::make_pair(name, result));
@@ -706,6 +719,11 @@ Status NodeImporter::ImportGeneralNode(const onnx::NodeProto &node) {
// Map inputs to values.
std::vector<MlirValue> input_values;
for (auto &input_name : node.input()) {
if (auto inp = graph_info_.graph_viewer().GetConstantInitializer(input_name,
false)) {
ImportInitializer(*inp);
}

auto found_it = nv_map_.find(input_name);
if (found_it == nv_map_.end()) {
std::string msg = "Non topologically produced ONNX node input '";
@@ -720,7 +738,7 @@
std::vector<MlirType> output_types;
for (auto &output_name : node.output()) {
const onnx::TypeProto *type_proto =
graph_info_.FindTypeProtoForName(output_name);
graph_info_.graph_viewer().GetNodeArg(output_name)->TypeAsProto();
if (!type_proto)
return failure();

@@ -955,15 +973,8 @@ Status NodeImporter::ImportConstantOfShapeNode(const onnx::NodeProto &node) {

Status NodeImporter::GetImmediateShapeTensor(const std::string &name,
std::vector<int64_t> &shape) {
auto found_it = graph_info_.initializer_map().find(name);
if (found_it == graph_info_.initializer_map().end()) {
std::string message = "An immediate shape value for '";
message.append(name);
message.append("' was required but it is dynamically produced");
return SetError(std::move(message));
}

const onnx::TensorProto &tp = found_it->second;
const onnx::TensorProto &tp =
*graph_info_.graph_viewer().GetConstantInitializer(name, false);
shape.clear();

// Since this is being interpreted as a shape, we only support some limited
Expand Down Expand Up @@ -1028,7 +1039,7 @@ void NodeImporter::DebugDumpModule() {
fwrite(sr.data, sizeof(char), sr.length, stderr);
};
MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
mlirOpPrintingFlagsEnableDebugInfo(flags, true, false);
mlirOpPrintingFlagsEnableDebugInfo(flags, false, true);
mlirOperationPrintWithFlags(module_op_, flags, callback, nullptr);
mlirOpPrintingFlagsDestroy(flags);
}
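The input-resolution change in ImportGeneralNode implements the TODO deleted from ImportAll: constant initializers are now materialized on demand, the first time a node references them, rather than imported wholesale up front (this is also why iree_execution_provider.cc below stops adding required_initializers to the fused node's inputs). A sketch of the pattern, with a hypothetical ResolveInput helper and assumed success()/failure() Status helpers; the early nv_map_ check is added here for clarity, where the PR relies on ImportInitializer's own nv_map_ bookkeeping:

// Sketch, not the PR's exact code: materialize a constant initializer the
// first time a node references it.
Status NodeImporter::ResolveInput(const std::string &input_name, MlirValue &out) {
  auto found_it = nv_map_.find(input_name);
  if (found_it != nv_map_.end()) {
    out = found_it->second;  // produced by a prior node, block arg, or import
    return success();
  }
  // GetConstantInitializer returns nullptr when input_name is not a constant
  // initializer; `false` skips the outer-scope lookup.
  if (const onnx::TensorProto *tp =
          graph_info_.graph_viewer().GetConstantInitializer(input_name, false)) {
    if (failed(ImportInitializer(*tp)))
      return failure();
    out = nv_map_.at(input_name);
    return success();
  }
  return SetError("Non topologically produced ONNX node input '" + input_name + "'");
}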
onnxruntime/core/providers/iree/compiler/torch-mlir-import-onnx/OnnxImporter.h
@@ -17,9 +17,11 @@
// for class members/accessors because canonical protobuf coding presumes
// this kind of style.

#include "core/graph/graph_viewer.h"
#include "mlir-c/IR.h"
#include "onnx/onnx_pb.h"

#include <memory>

[cpplint] OnnxImporter.h:24: Found C++ system header after other header. Should be: OnnxImporter.h, c system, c++ system, other. [build/include_order] [4]
#include <optional>
#include <string_view>
#include <unordered_map>
@@ -64,8 +66,9 @@
// Accounting for a GraphProto.
class GraphInfo {
public:
GraphInfo(ModelInfo &model_info, const onnx::GraphProto &graph_proto)
: model_info_(model_info), graph_proto_(graph_proto) {}
GraphInfo(const onnxruntime::GraphViewer &gv, ModelInfo &model_info,
const onnx::GraphProto &graph_proto)
: graph_viewer_(gv), model_info_(model_info), graph_proto_(graph_proto) {}
ModelInfo &model_info() { return model_info_; }
const onnx::GraphProto &graph_proto() { return graph_proto_; }

@@ -101,12 +104,15 @@
return output_map_;
}

const onnxruntime::GraphViewer &graph_viewer() { return graph_viewer_; }

std::unordered_map<std::string_view, const onnx::TensorProto &> &
initializer_map() {
return initializer_map_;
}

private:
const onnxruntime::GraphViewer &graph_viewer_;
ModelInfo &model_info_;
const onnx::GraphProto &graph_proto_;

@@ -131,7 +137,7 @@
onnx::ModelProto &model_proto() { return model_proto_; }

/// Post-construction, failable initialization.
Status Initialize();
Status Initialize(const onnxruntime::GraphViewer &gv);

GraphInfo &main_graph() { return *main_graph_; }
const std::string &error_message() { return error_message_; }
@@ -157,6 +163,7 @@
: model_info_(model_info), context_(context) {}

MlirContext context() { return context_; }
MlirType GetNoneType();

/// Converts the TypeProto to an MlirType, returning a null type and
/// setting an error if not possible.
@@ -208,15 +215,17 @@
/// Imports all nodes topologically. Internally calls FinalizeGraph.
Status ImportAll();

/// Substitutes !torch.none in place of `''` labelled inputs.

[misspell] OnnxImporter.h:218: "labelled" is a misspelling of "labeled"
void ImportNoneNode();
/// Import nodes one at a time. Must complete with a call to FinalizeGraph.
Status ImportNode(const onnx::NodeProto &node);
Status ImportInitializer(const onnx::TensorProto &initializer);
Status FinalizeGraph();

void DebugDumpModule();

private:
void PopulateGraphAttrs(MlirOperation container_op);
Status ImportInitializer(const onnx::TensorProto &initializer);
MlirAttribute ImportGeneralAttribute(const onnx::AttributeProto &onnx_attr);

// Special-form nodes.
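For readers unfamiliar with the '' convention that ImportNoneNode handles: ONNX serializes an omitted optional input as an empty-string name. A toy illustration (hypothetical node, not from this PR) of what the importer sees:

// A Clip node with its optional `min` input omitted carries the inputs
// {"x", "", "max"}. Because ImportNoneNode() pre-seeds nv_map_[""] with the
// result of a torch.constant.none op, the generic input lookup in
// ImportGeneralNode resolves the gap without special casing.
onnx::NodeProto clip;
clip.set_op_type("Clip");
clip.add_input("x");
clip.add_input("");    // omitted optional `min`
clip.add_input("max");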
4 changes: 0 additions & 4 deletions onnxruntime/core/providers/iree/iree_execution_provider.cc
@@ -76,10 +76,6 @@ std::vector<std::unique_ptr<ComputeCapability>> IREEExecutionProvider::GetCapabi
inputs.push_back(nodeArgPtr->Name());
}

for (auto& name : required_initializers) {
inputs.push_back(name);
}

for (auto& nodeArgPtr : graph_viewer.GetOutputs()) {
outputs.push_back(nodeArgPtr->Name());
}
1 change: 1 addition & 0 deletions onnxruntime/core/providers/shared_library/provider_api.h
@@ -274,6 +274,7 @@ constexpr const char* kMIGraphXExecutionProvider = "MIGraphXExecutionProvider";
constexpr const char* kQnnExecutionProvider = "QNNExecutionProvider";
constexpr const char* kCpuExecutionProvider = "CPUExecutionProvider";
constexpr const char* kAzureExecutionProvider = "AzureExecutionProvider";
constexpr const char* kIreeExecutionProvider = "IreeExecutionProvider";

template <typename T>
using IAllocatorUniquePtr = std::unique_ptr<T, std::function<void(T*)>>;
2 changes: 1 addition & 1 deletion onnxruntime/python/onnxruntime_pybind_state.cc
@@ -1141,7 +1141,7 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
#endif
} else if (type == kIreeExecutionProvider) {
#if USE_IREE
const auto &it = provider_options_map.find(type);
const auto& it = provider_options_map.find(type);
ProviderOptions iree_option_map = ProviderOptions{};
if (it != provider_options_map.end()) {
iree_option_map = it->second;