diff --git a/onnxruntime/core/providers/get_execution_providers.cc b/onnxruntime/core/providers/get_execution_providers.cc
index 61c035bc29ed5..0f2f26f5ef8c9 100644
--- a/onnxruntime/core/providers/get_execution_providers.cc
+++ b/onnxruntime/core/providers/get_execution_providers.cc
@@ -186,6 +186,14 @@ constexpr ProviderInfo kProvidersInPriorityOrder[] =
             true,
 #else
             false,
+#endif
+        },
+        {
+            kIreeExecutionProvider,
+#ifdef USE_IREE
+            true,
+#else
+            false,
 #endif
         },
         {kCpuExecutionProvider, true},  // kCpuExecutionProvider is always last
diff --git a/onnxruntime/core/providers/iree/compiler/jit_compiler.cc b/onnxruntime/core/providers/iree/compiler/jit_compiler.cc
index 159a2338b1490..cdb2594a9dab9 100644
--- a/onnxruntime/core/providers/iree/compiler/jit_compiler.cc
+++ b/onnxruntime/core/providers/iree/compiler/jit_compiler.cc
@@ -159,7 +159,7 @@ common::Status CompilerInvocation::ImportSubgraph(const onnxruntime::GraphViewer
   }
 
   ONNX_NAMESPACE::GraphProto graph_proto;
-  GraphViewerToProto(graph_view, graph_proto, true, true);
+  GraphViewerToProto(graph_view, graph_proto, false, false);
   // LOGS(session.logger, INFO) << "  full graph: " << graph_proto.DebugString();
 
   // Set up for subgraph import.
diff --git a/onnxruntime/core/providers/iree/compiler/torch-mlir-import-onnx/OnnxImporter.cpp b/onnxruntime/core/providers/iree/compiler/torch-mlir-import-onnx/OnnxImporter.cpp
index 04fad20dc8727..e28c0844e6845 100644
--- a/onnxruntime/core/providers/iree/compiler/torch-mlir-import-onnx/OnnxImporter.cpp
+++ b/onnxruntime/core/providers/iree/compiler/torch-mlir-import-onnx/OnnxImporter.cpp
@@ -246,11 +246,6 @@ const onnx::TypeProto *GraphInfo::FindTypeProtoForName(std::string_view name) {
   }
 #endif
   return graph_viewer_.GetNodeArg(std::string{name})->TypeAsProto();
-  // std::string msg = "No type information associated with '";
-  // msg.append(name);
-  // msg.append("'. Run shape inference?");
-  // model_info_.SetError(std::move(msg));
-  // return nullptr;
 }
 
 // ---------------------------------------------------------------------------//
@@ -556,7 +551,7 @@ Status NodeImporter::DefineFunction(std::optional<std::string> name,
   std::vector<MlirType> input_types;
   std::vector<MlirLocation> input_locs;
   std::vector<MlirType> output_types;
-  for (auto input : graph_info_.graph_viewer().GetInputs()) {
+  for (auto *input : graph_info_.graph_viewer().GetInputs()) {
     MlirType t = cc_.ConvertTypeProto(*input->TypeAsProto());
     if (mlirTypeIsNull(t)) {
       return failure();
@@ -650,12 +645,6 @@ void NodeImporter::PopulateGraphAttrs(MlirOperation container_op) {
 }
 
 Status NodeImporter::ImportAll() {
-  // TODO: Consider pulling in initializers on demand since there can be so
-  // much unused crap.
-  for (auto it : graph_info_.initializer_map()) {
-    if (failed(ImportInitializer(it.second)))
-      return failure();
-  }
   ImportNoneNode();
 
   auto node_indices = graph_info_.graph_viewer().GetNodesInTopologicalOrder();
@@ -708,8 +697,11 @@ Status NodeImporter::ImportInitializer(const onnx::TensorProto &initializer) {
     return failure();
 
   MlirOperation op = createMlirOperationAtEnd(
-      body_block_, "torch.vtensor.literal", loc, vtensor_type,
-      toMlirNamedAttribute("value", value_attr));
+      body_block_, "torch.operator", loc, vtensor_type,
+      toMlirNamedAttribute(
+          "name",
+          mlirStringAttrGet(context_, toMlirStringRef("onnx.Constant"))),
+      toMlirNamedAttribute("torch.onnx.value", value_attr));
 
   MlirValue result = mlirOperationGetResult(op, 0);
   auto inserted = nv_map_.insert(std::make_pair(name, result));
@@ -744,6 +736,11 @@ Status NodeImporter::ImportGeneralNode(const onnx::NodeProto &node) {
   // Map inputs to values.
   std::vector<MlirValue> input_values;
   for (auto &input_name : node.input()) {
+    if (auto inp = graph_info_.graph_viewer().GetConstantInitializer(input_name,
+                                                                    false)) {
+      ImportInitializer(*inp);
+    }
+
     auto found_it = nv_map_.find(input_name);
     if (found_it == nv_map_.end()) {
       std::string msg = "Non topologically produced ONNX node input '";
@@ -993,15 +990,8 @@ Status NodeImporter::ImportConstantOfShapeNode(const onnx::NodeProto &node) {
 
 Status NodeImporter::GetImmediateShapeTensor(const std::string &name,
                                              std::vector<int64_t> &shape) {
-  auto found_it = graph_info_.initializer_map().find(name);
-  if (found_it == graph_info_.initializer_map().end()) {
-    std::string message = "An immediate shape value for '";
-    message.append(name);
-    message.append("' was required but it is dynamically produced");
-    return SetError(std::move(message));
-  }
-
-  const onnx::TensorProto &tp = found_it->second;
+  const onnx::TensorProto &tp =
+      *graph_info_.graph_viewer().GetConstantInitializer(name, false);
   shape.clear();
 
   // Since this is being interpreted as a shape, we only support some limited
@@ -1066,7 +1056,7 @@ void NodeImporter::DebugDumpModule() {
     fwrite(sr.data, sizeof(char), sr.length, stderr);
   };
   MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
-  mlirOpPrintingFlagsEnableDebugInfo(flags, true, false);
+  mlirOpPrintingFlagsEnableDebugInfo(flags, false, true);
   mlirOperationPrintWithFlags(module_op_, flags, callback, nullptr);
   mlirOpPrintingFlagsDestroy(flags);
 }
diff --git a/onnxruntime/core/providers/iree/iree_execution_provider.cc b/onnxruntime/core/providers/iree/iree_execution_provider.cc
index 70094efa23788..772475224deaa 100644
--- a/onnxruntime/core/providers/iree/iree_execution_provider.cc
+++ b/onnxruntime/core/providers/iree/iree_execution_provider.cc
@@ -76,10 +76,6 @@ std::vector<std::unique_ptr<ComputeCapability>> IREEExecutionProvider::GetCapability(
     inputs.push_back(nodeArgPtr->Name());
   }
 
-  for (auto& name : required_initializers) {
-    inputs.push_back(name);
-  }
-
   for (auto& nodeArgPtr : graph_viewer.GetOutputs()) {
     outputs.push_back(nodeArgPtr->Name());
   }
diff --git a/onnxruntime/core/providers/shared_library/provider_api.h b/onnxruntime/core/providers/shared_library/provider_api.h
index b84825236a453..54ad641983006 100644
--- a/onnxruntime/core/providers/shared_library/provider_api.h
+++ b/onnxruntime/core/providers/shared_library/provider_api.h
@@ -274,6 +274,7 @@ constexpr const char* kMIGraphXExecutionProvider = "MIGraphXExecutionProvider";
 constexpr const char* kQnnExecutionProvider = "QNNExecutionProvider";
 constexpr const char* kCpuExecutionProvider = "CPUExecutionProvider";
 constexpr const char* kAzureExecutionProvider = "AzureExecutionProvider";
+constexpr const char* kIreeExecutionProvider = "IreeExecutionProvider";
 
 template <typename T>
 using IAllocatorUniquePtr = std::unique_ptr<T, std::function<void(T*)>>;