Skip to content

Commit

Permalink
[IREE EP][Importer] Fix ElementsAttr iteration error
Browse files Browse the repository at this point in the history
  • Loading branch information
vinayakdsci committed Sep 19, 2024
1 parent 82fd6d5 commit 8c17229
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 29 deletions.
8 changes: 8 additions & 0 deletions onnxruntime/core/providers/get_execution_providers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,14 @@ constexpr ProviderInfo kProvidersInPriorityOrder[] =
true,
#else
false,
#endif
},
{
kIreeExecutionProvider,
#ifdef USE_IREE
true,
#else
false,
#endif
},
{kCpuExecutionProvider, true}, // kCpuExecutionProvider is always last
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/iree/compiler/jit_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ common::Status CompilerInvocation::ImportSubgraph(const onnxruntime::GraphViewer
}

ONNX_NAMESPACE::GraphProto graph_proto;
GraphViewerToProto(graph_view, graph_proto, true, true);
GraphViewerToProto(graph_view, graph_proto, false, false);
// LOGS(session.logger, INFO) << " full graph: " << graph_proto.DebugString();

// Set up for subgraph import.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,11 +246,6 @@ const onnx::TypeProto *GraphInfo::FindTypeProtoForName(std::string_view name) {
}
#endif
return graph_viewer_.GetNodeArg(std::string{name})->TypeAsProto();
// std::string msg = "No type information associated with '";
// msg.append(name);
// msg.append("'. Run shape inference?");
// model_info_.SetError(std::move(msg));
// return nullptr;
}

// ---------------------------------------------------------------------------//
Expand Down Expand Up @@ -556,7 +551,7 @@ Status NodeImporter::DefineFunction(std::optional<std::string> name,
std::vector<MlirType> input_types;
std::vector<MlirLocation> input_locs;
std::vector<MlirType> output_types;
for (auto input : graph_info_.graph_viewer().GetInputs()) {
for (auto *input : graph_info_.graph_viewer().GetInputs()) {
MlirType t = cc_.ConvertTypeProto(*input->TypeAsProto());
if (mlirTypeIsNull(t)) {
return failure();
Expand Down Expand Up @@ -650,12 +645,6 @@ void NodeImporter::PopulateGraphAttrs(MlirOperation container_op) {
}

Status NodeImporter::ImportAll() {
// TODO: Consider pulling in initializers on demand since there can be so
// much unused crap.
for (auto it : graph_info_.initializer_map()) {
if (failed(ImportInitializer(it.second)))
return failure();
}
ImportNoneNode();

auto node_indices = graph_info_.graph_viewer().GetNodesInTopologicalOrder();
Expand Down Expand Up @@ -708,8 +697,11 @@ Status NodeImporter::ImportInitializer(const onnx::TensorProto &initializer) {
return failure();

MlirOperation op = createMlirOperationAtEnd(
body_block_, "torch.vtensor.literal", loc, vtensor_type,
toMlirNamedAttribute("value", value_attr));
body_block_, "torch.operator", loc, vtensor_type,
toMlirNamedAttribute(
"name",
mlirStringAttrGet(context_, toMlirStringRef("onnx.Constant"))),
toMlirNamedAttribute("torch.onnx.value", value_attr));
MlirValue result = mlirOperationGetResult(op, 0);

auto inserted = nv_map_.insert(std::make_pair(name, result));
Expand Down Expand Up @@ -744,6 +736,11 @@ Status NodeImporter::ImportGeneralNode(const onnx::NodeProto &node) {
// Map inputs to values.
std::vector<MlirValue> input_values;
for (auto &input_name : node.input()) {
if (auto inp = graph_info_.graph_viewer().GetConstantInitializer(input_name,
false)) {
ImportInitializer(*inp);
}

auto found_it = nv_map_.find(input_name);
if (found_it == nv_map_.end()) {
std::string msg = "Non topologically produced ONNX node input '";
Expand Down Expand Up @@ -993,15 +990,8 @@ Status NodeImporter::ImportConstantOfShapeNode(const onnx::NodeProto &node) {

Status NodeImporter::GetImmediateShapeTensor(const std::string &name,
std::vector<int64_t> &shape) {
auto found_it = graph_info_.initializer_map().find(name);
if (found_it == graph_info_.initializer_map().end()) {
std::string message = "An immediate shape value for '";
message.append(name);
message.append("' was required but it is dynamically produced");
return SetError(std::move(message));
}

const onnx::TensorProto &tp = found_it->second;
const onnx::TensorProto &tp =
*graph_info_.graph_viewer().GetConstantInitializer(name, false);
shape.clear();

// Since this is being interpreted as a shape, we only support some limited
Expand Down Expand Up @@ -1066,7 +1056,7 @@ void NodeImporter::DebugDumpModule() {
fwrite(sr.data, sizeof(char), sr.length, stderr);
};
MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
mlirOpPrintingFlagsEnableDebugInfo(flags, true, false);
mlirOpPrintingFlagsEnableDebugInfo(flags, false, true);
mlirOperationPrintWithFlags(module_op_, flags, callback, nullptr);
mlirOpPrintingFlagsDestroy(flags);
}
4 changes: 0 additions & 4 deletions onnxruntime/core/providers/iree/iree_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,6 @@ std::vector<std::unique_ptr<ComputeCapability>> IREEExecutionProvider::GetCapabi
inputs.push_back(nodeArgPtr->Name());
}

for (auto& name : required_initializers) {
inputs.push_back(name);
}

for (auto& nodeArgPtr : graph_viewer.GetOutputs()) {
outputs.push_back(nodeArgPtr->Name());
}
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/core/providers/shared_library/provider_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,7 @@ constexpr const char* kMIGraphXExecutionProvider = "MIGraphXExecutionProvider";
constexpr const char* kQnnExecutionProvider = "QNNExecutionProvider";
constexpr const char* kCpuExecutionProvider = "CPUExecutionProvider";
constexpr const char* kAzureExecutionProvider = "AzureExecutionProvider";
constexpr const char* kIreeExecutionProvider = "IreeExecutionProvider";

template <typename T>
using IAllocatorUniquePtr = std::unique_ptr<T, std::function<void(T*)>>;
Expand Down

0 comments on commit 8c17229

Please sign in to comment.