From cbfca6bd6f719eebc35d4f9cbbff83f069cf7ab6 Mon Sep 17 00:00:00 2001 From: Gaurav Shukla Date: Mon, 23 Sep 2024 18:52:19 +0530 Subject: [PATCH] Revert "[IREE EP][Importer] Fix IR import for onnx.ConstantOfShape" This reverts commit e6877af41ab177c6dcd0ec2c0575b0f244f3f136. --- .../torch-mlir-import-onnx/OnnxImporter.cpp | 108 +++++------------- .../torch-mlir-import-onnx/OnnxImporter.h | 4 - 2 files changed, 27 insertions(+), 85 deletions(-) diff --git a/onnxruntime/core/providers/iree/compiler/torch-mlir-import-onnx/OnnxImporter.cpp b/onnxruntime/core/providers/iree/compiler/torch-mlir-import-onnx/OnnxImporter.cpp index 3cbc435cac989..0ba67cf33fd4c 100644 --- a/onnxruntime/core/providers/iree/compiler/torch-mlir-import-onnx/OnnxImporter.cpp +++ b/onnxruntime/core/providers/iree/compiler/torch-mlir-import-onnx/OnnxImporter.cpp @@ -638,7 +638,8 @@ Status NodeImporter::ImportAll() { for (const auto &node : nodes) { if (torch_mlir_onnx::failed(ImportNode(node))) { - return failure(); + return SetError("Failed to import node '" + node.name() + + "': " + "(node:\n" + node.DebugString() + "\n)"); } } @@ -727,8 +728,7 @@ Status NodeImporter::ImportGeneralNode(const onnx::NodeProto &node) { if (found_it == nv_map_.end()) { std::string msg = "Non topologically produced ONNX node input '"; msg.append(input_name); - msg.append("': "); - msg.append(node.DebugString()); + msg.append("'"); return SetError(std::move(msg)); } input_values.push_back(found_it->second); @@ -739,9 +739,8 @@ Status NodeImporter::ImportGeneralNode(const onnx::NodeProto &node) { for (auto &output_name : node.output()) { const onnx::TypeProto *type_proto = graph_info_.graph_viewer().GetNodeArg(output_name)->TypeAsProto(); - if (!type_proto) { - return SetError("Failed to obtain TypeProto for tensor"); - } + if (!type_proto) + return failure(); MlirType t = cc_.ConvertTypeProto(*type_proto); if (mlirTypeIsNull(t)) @@ -907,83 +906,38 @@ Status NodeImporter::ImportConstantOfShapeNode(const onnx::NodeProto &node) { return mlirRankedTensorTypeGet(shape.size(), shape.data(), element_type, /*encoding*/ {nullptr}); }; - const bool has_raw_data = tensor_proto.has_raw_data(); MlirAttribute splat_attr = {nullptr}; - size_t out_size; switch (tensor_proto.data_type()) { - case onnx::TensorProto::DataType::TensorProto_DataType_FLOAT: { - const float *data = nullptr; - if (has_raw_data) { - data = graph_info_.GetOptionalRawData(tensor_proto, out_size); - ORT_ENFORCE(data, "GetOptionalRawData() returned null for tensor proto: ", - tensor_proto.DebugString()); - } + case onnx::TensorProto::DataType::TensorProto_DataType_FLOAT: splat_attr = mlirDenseElementsAttrFloatSplatGet( - tensorTypeFor(mlirF32TypeGet(context_)), - has_raw_data ? data[0] : tensor_proto.float_data(0)); + tensorTypeFor(mlirF32TypeGet(context_)), tensor_proto.float_data(0)); break; - } - case onnx::TensorProto::DataType::TensorProto_DataType_INT32: { - const int32_t *data = nullptr; - if (has_raw_data) { - data = graph_info_.GetOptionalRawData(tensor_proto, out_size); - ORT_ENFORCE(data, "GetOptionalRawData() returned null for tensor proto: ", - tensor_proto.DebugString()); - } - splat_attr = mlirDenseElementsAttrInt32SplatGet( + case onnx::TensorProto::DataType::TensorProto_DataType_INT32: + splat_attr = mlirDenseElementsAttrFloatSplatGet( tensorTypeFor(mlirIntegerTypeSignedGet(context_, 32)), - has_raw_data ? data[0] : tensor_proto.int32_data(0)); + tensor_proto.int32_data(0)); break; - } - case onnx::TensorProto::DataType::TensorProto_DataType_INT64: { - const int64_t *data = nullptr; - if (has_raw_data) { - data = graph_info_.GetOptionalRawData(tensor_proto, out_size); - ORT_ENFORCE(data, "GetOptionalRawData() returned null for tensor proto: ", - tensor_proto.DebugString()); - } - splat_attr = mlirDenseElementsAttrInt64SplatGet( + case onnx::TensorProto::DataType::TensorProto_DataType_INT64: + splat_attr = mlirDenseElementsAttrFloatSplatGet( tensorTypeFor(mlirIntegerTypeSignedGet(context_, 64)), - has_raw_data ? data[0] : tensor_proto.int64_data(0)); + tensor_proto.int64_data(0)); break; - } - case onnx::TensorProto::DataType::TensorProto_DataType_DOUBLE: { - const double *data = nullptr; - if (has_raw_data) { - data = graph_info_.GetOptionalRawData(tensor_proto, out_size); - ORT_ENFORCE(data, "GetOptionalRawData() returned null for tensor proto: ", - tensor_proto.DebugString()); - } - splat_attr = mlirDenseElementsAttrDoubleSplatGet( - tensorTypeFor(mlirF64TypeGet(context_)), - has_raw_data ? data[0] : tensor_proto.double_data(0)); + case onnx::TensorProto::DataType::TensorProto_DataType_DOUBLE: + splat_attr = mlirDenseElementsAttrFloatSplatGet( + tensorTypeFor(mlirF64TypeGet(context_)), tensor_proto.double_data(0)); break; - } - case onnx::TensorProto::DataType::TensorProto_DataType_UINT64: { - const uint64_t *data = nullptr; - if (has_raw_data) { - data = graph_info_.GetOptionalRawData(tensor_proto, out_size); - ORT_ENFORCE(data, "GetOptionalRawData() returned null for tensor proto: ", - tensor_proto.DebugString()); - } - splat_attr = mlirDenseElementsAttrUInt64SplatGet( + case onnx::TensorProto::DataType::TensorProto_DataType_UINT64: + splat_attr = mlirDenseElementsAttrFloatSplatGet( tensorTypeFor(mlirIntegerTypeUnsignedGet(context_, 64)), - has_raw_data ? data[0] : tensor_proto.uint64_data(0)); + tensor_proto.uint64_data(0)); break; - } - case onnx::TensorProto::DataType::TensorProto_DataType_UINT32: { - const uint32_t *data = nullptr; - if (has_raw_data) { - data = graph_info_.GetOptionalRawData(tensor_proto, out_size); - ORT_ENFORCE(data, "GetOptionalRawData() returned null for tensor proto: ", - tensor_proto.DebugString()); - } - splat_attr = mlirDenseElementsAttrUInt32SplatGet( + case onnx::TensorProto::DataType::TensorProto_DataType_UINT32: + // Special case: inline data is stored in uint64. + splat_attr = mlirDenseElementsAttrFloatSplatGet( tensorTypeFor(mlirIntegerTypeUnsignedGet(context_, 32)), - has_raw_data ? data[0] : tensor_proto.float_data(0)); + tensor_proto.uint64_data(0)); break; } - } if (mlirAttributeIsNull(splat_attr)) { std::string message = @@ -1004,7 +958,8 @@ Status NodeImporter::ImportConstantOfShapeNode(const onnx::NodeProto &node) { toMlirNamedAttribute("value", splat_attr)); MlirValue result = mlirOperationGetResult(op, 0); - auto inserted = nv_map_.emplace(node.output(0), result); + // Export to the nv_map. + auto inserted = nv_map_.insert(std::make_pair(name, result)); if (!inserted.second) { std::string msg = "Multiple nodes produced a value for '"; msg.append(name); @@ -1018,17 +973,8 @@ Status NodeImporter::ImportConstantOfShapeNode(const onnx::NodeProto &node) { Status NodeImporter::GetImmediateShapeTensor(const std::string &name, std::vector &shape) { - const onnx::TensorProto *tensor = - graph_info_.graph_viewer().GetConstantInitializer(name, false); - if (!tensor) { - std::string msg = "Could not find the immediate shape tensor "; - msg.append(name); - msg.append(" in constant graph initializers. It was possibly produced " - "dynamically."); - return SetError(msg); - } - const onnx::TensorProto &tp = *tensor; - + const onnx::TensorProto &tp = + *graph_info_.graph_viewer().GetConstantInitializer(name, false); shape.clear(); // Since this is being interpreted as a shape, we only support some limited diff --git a/onnxruntime/core/providers/iree/compiler/torch-mlir-import-onnx/OnnxImporter.h b/onnxruntime/core/providers/iree/compiler/torch-mlir-import-onnx/OnnxImporter.h index e916da6b8a7d3..733aefbda2583 100644 --- a/onnxruntime/core/providers/iree/compiler/torch-mlir-import-onnx/OnnxImporter.h +++ b/onnxruntime/core/providers/iree/compiler/torch-mlir-import-onnx/OnnxImporter.h @@ -93,10 +93,6 @@ class GraphInfo { return nullptr; } - std::unordered_map & - value_info_map() { - return value_info_map_; - } std::vector &inputs() { return inputs_; } std::unordered_map & input_map() {