Skip to content

Commit

Permalink
[IREE-EP] Do not take initializers as function args in IR
Browse files Browse the repository at this point in the history
  • Loading branch information
vinayakdsci committed Sep 10, 2024
1 parent 94a11e8 commit 82fd6d5
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 45 deletions.
26 changes: 4 additions & 22 deletions onnxruntime/core/providers/iree/compiler/jit_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ CompilerInvocation::~CompilerInvocation() {
ireeCompilerInvocationDestroy(inv);
}

common::Status CompilerInvocation::ImportSubgraph(const ONNX_NAMESPACE::ModelProto &model_proto, const onnxruntime::GraphViewer& graph_view, const std::string& func_name) {
common::Status CompilerInvocation::ImportSubgraph(const onnxruntime::GraphViewer& graph_view, const std::string& func_name) {
// Note that we just use a synthetic top-level ModelProto and forego main
// graph initialization. Since we are operating on a subgraph view, we
// initialize from the backing Graph proto but initialize it ourselves.
Expand All @@ -158,8 +158,6 @@ common::Status CompilerInvocation::ImportSubgraph(const ONNX_NAMESPACE::ModelPro
opset_import->set_version(it.second);
}

// Unforgivably sharp edge: There is a ToGraphProto() that returns a value and another that returns a reference.
// And they differ by const-ness. We need to make sure we get the reference, obviously, so we assign it explicitly.
ONNX_NAMESPACE::GraphProto graph_proto;
GraphViewerToProto(graph_view, graph_proto, true, true);
// LOGS(session.logger, INFO) << " full graph: " << graph_proto.DebugString();
Expand Down Expand Up @@ -195,25 +193,9 @@ common::Status CompilerInvocation::ImportSubgraph(const ONNX_NAMESPACE::ModelPro
model_info.error_message(), ConsumeDiagnostics());
}

// Import each node. Note that the importer uses references internally and expects nodes to be located at fixed
// memory locations for the life of iteration. So we materialize them into a fixed vector first. This is because
// the onnxruntime does not keep the serialized proto form sync'd on its own.
auto node_indices = graph_view.GetNodesInTopologicalOrder();
std::vector<ONNX_NAMESPACE::NodeProto> nodes(node_indices.size());
for (size_t i = 0; i < node_indices.size(); ++i) {
graph_view.GetNode(node_indices[i])->ToProto(nodes[i]);
}
imp.ImportNoneNode();
for (const auto& node : nodes) {
if (torch_mlir_onnx::failed(imp.ImportNode(node))) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "Failed to import node '", node.name(), "': ",
model_info.error_message(), " (node:\n", node.DebugString(), "\n)", ConsumeDiagnostics());
}
}

// Finalize.
if (torch_mlir_onnx::failed(imp.FinalizeGraph())) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, model_info.error_message(), ConsumeDiagnostics());
if (torch_mlir_onnx::failed(imp.ImportAll())) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "Failed to import nodes", ": ",
model_info.error_message(), ConsumeDiagnostics());
}

// Verify the function at the point of import because we have better diagnostics.
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/iree/compiler/jit_compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ struct CompilerInvocation {
~CompilerInvocation();

// Imports a subgraph as a public function.
common::Status ImportSubgraph(const ONNX_NAMESPACE::ModelProto &model_proto, const onnxruntime::GraphViewer& graph_view, const std::string& func_name);
common::Status ImportSubgraph(const onnxruntime::GraphViewer& graph_view, const std::string& func_name);

// Compile and output a VMFB.
common::Status CompileAndOutputVMFB(iree_compiler_output_t* output);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -556,8 +556,8 @@ Status NodeImporter::DefineFunction(std::optional<std::string> name,
std::vector<MlirType> input_types;
std::vector<MlirLocation> input_locs;
std::vector<MlirType> output_types;
for (auto input : graph_info_.graph_proto().input()) {
MlirType t = cc_.ConvertTypeProto(input.type());
for (auto input : graph_info_.graph_viewer().GetInputs()) {
MlirType t = cc_.ConvertTypeProto(*input->TypeAsProto());
if (mlirTypeIsNull(t)) {
return failure();
}
Expand Down Expand Up @@ -588,8 +588,9 @@ Status NodeImporter::DefineFunction(std::optional<std::string> name,
mlirRegionAppendOwnedBlock(bodyRegion, body_block_);

// Map the block args to names and store for evaluation.
for (int i = 0, e = graph_info_.graph_proto().input().size(); i < e; ++i) {
std::string_view name = graph_info_.graph_proto().input()[i].name();
for (int i = 0, e = graph_info_.graph_viewer().GetInputs().size(); i < e;
++i) {
std::string_view name = graph_info_.graph_viewer().GetInputs()[i]->Name();
MlirValue value = mlirBlockGetArgument(body_block_, i);
nv_map_[name] = value;
}
Expand Down Expand Up @@ -655,9 +656,19 @@ Status NodeImporter::ImportAll() {
if (failed(ImportInitializer(it.second)))
return failure();
}
for (auto it : graph_info_.graph_proto().node()) {
if (failed(ImportNode(it)))
return failure();
ImportNoneNode();

auto node_indices = graph_info_.graph_viewer().GetNodesInTopologicalOrder();
std::vector<ONNX_NAMESPACE::NodeProto> nodes(node_indices.size());
for (size_t i = 0; i < node_indices.size(); ++i) {
graph_info_.graph_viewer().GetNode(node_indices[i])->ToProto(nodes[i]);
}

for (const auto &node : nodes) {
if (torch_mlir_onnx::failed(ImportNode(node))) {
return SetError("Failed to import node '" + node.name() +
"': " + "(node:\n" + node.DebugString() + "\n)");
}
}

return FinalizeGraph();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -217,13 +217,13 @@ class NodeImporter {
void ImportNoneNode();
/// Import nodes one at a time. Must complete with a call to FinalizeGraph.
Status ImportNode(const onnx::NodeProto &node);
Status ImportInitializer(const onnx::TensorProto &initializer);
Status FinalizeGraph();

void DebugDumpModule();

private:
void PopulateGraphAttrs(MlirOperation container_op);
Status ImportInitializer(const onnx::TensorProto &initializer);
MlirAttribute ImportGeneralAttribute(const onnx::AttributeProto &onnx_attr);

// Special-form nodes.
Expand Down
16 changes: 2 additions & 14 deletions onnxruntime/core/providers/iree/iree_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,9 @@
#include "core/framework/compute_capability.h"
#include "core/framework/fallback_cpu_capability.h"
#include "core/framework/kernel_registry.h"
#include "core/graph/graph_proto_serializer.h"
#include "core/graph/graph_utils.h"
#include "core/graph/graph_viewer.h"

#include "core/graph/model.h"
#include "core/providers/iree/compiler/jit_compiler.h"

#include <cassert>
Expand Down Expand Up @@ -124,18 +122,8 @@ common::Status IREEExecutionProvider::Compile(const std::vector<FusedNodeAndGrap
for (auto& fused_node_graph : fused_nodes_and_graphs) {
const GraphViewer& graph_view = fused_node_graph.filtered_graph;
const Node& fused_node = fused_node_graph.fused_node;
const std::string func_name = fused_node.Name();
Model model(graph_view.Name(), true, ModelMetaData(), PathString(),
IOnnxRuntimeOpSchemaRegistryList(), graph_view.DomainToVersionMap(),
std::vector<ONNX_NAMESPACE::FunctionProto>(), *GetLogger());
ONNX_NAMESPACE::ModelProto model_proto = model.ToProto();

GraphViewerToProto(graph_view, *model_proto.mutable_graph(), true, true);
auto opset = model_proto.add_opset_import();
opset->set_domain(kOnnxDomain);
opset->set_version(graph_view.DomainToVersionMap().at(kOnnxDomain));

ORT_RETURN_IF_ERROR(inv.ImportSubgraph(model_proto, graph_view, func_name));
const std::string& func_name = fused_node.Name();
ORT_RETURN_IF_ERROR(inv.ImportSubgraph(graph_view, func_name));
// The fully qualified name is the {module_name}.{func_name}. This is what we look up at
// runtime.
std::string fq_name(module_name);
Expand Down

0 comments on commit 82fd6d5

Please sign in to comment.