diff --git a/onnxruntime/core/providers/iree/compiler/torch-mlir-import-onnx/OnnxImporter.cpp b/onnxruntime/core/providers/iree/compiler/torch-mlir-import-onnx/OnnxImporter.cpp
index 0ba67cf33fd4c..3cbc435cac989 100644
--- a/onnxruntime/core/providers/iree/compiler/torch-mlir-import-onnx/OnnxImporter.cpp
+++ b/onnxruntime/core/providers/iree/compiler/torch-mlir-import-onnx/OnnxImporter.cpp
@@ -638,8 +638,7 @@ Status NodeImporter::ImportAll() {
 
   for (const auto &node : nodes) {
     if (torch_mlir_onnx::failed(ImportNode(node))) {
-      return SetError("Failed to import node '" + node.name() +
-                      "': " + "(node:\n" + node.DebugString() + "\n)");
+      return failure();
     }
   }
 
@@ -728,7 +727,8 @@ Status NodeImporter::ImportGeneralNode(const onnx::NodeProto &node) {
     if (found_it == nv_map_.end()) {
       std::string msg = "Non topologically produced ONNX node input '";
       msg.append(input_name);
-      msg.append("'");
+      msg.append("': ");
+      msg.append(node.DebugString());
       return SetError(std::move(msg));
     }
     input_values.push_back(found_it->second);
@@ -739,8 +739,9 @@ Status NodeImporter::ImportGeneralNode(const onnx::NodeProto &node) {
   for (auto &output_name : node.output()) {
     const onnx::TypeProto *type_proto =
         graph_info_.graph_viewer().GetNodeArg(output_name)->TypeAsProto();
-    if (!type_proto)
-      return failure();
+    if (!type_proto) {
+      return SetError("Failed to obtain TypeProto for tensor");
+    }
 
     MlirType t = cc_.ConvertTypeProto(*type_proto);
     if (mlirTypeIsNull(t))
@@ -906,38 +907,83 @@ Status NodeImporter::ImportConstantOfShapeNode(const onnx::NodeProto &node) {
     return mlirRankedTensorTypeGet(shape.size(), shape.data(), element_type,
                                    /*encoding*/ {nullptr});
   };
+  const bool has_raw_data = tensor_proto.has_raw_data();
   MlirAttribute splat_attr = {nullptr};
+  size_t out_size;
   switch (tensor_proto.data_type()) {
-  case onnx::TensorProto::DataType::TensorProto_DataType_FLOAT:
+  case onnx::TensorProto::DataType::TensorProto_DataType_FLOAT: {
+    const float *data = nullptr;
+    if (has_raw_data) {
+      data = graph_info_.GetOptionalRawData<float>(tensor_proto, out_size);
+      ORT_ENFORCE(data, "GetOptionalRawData() returned null for tensor proto: ",
+                  tensor_proto.DebugString());
+    }
     splat_attr = mlirDenseElementsAttrFloatSplatGet(
-        tensorTypeFor(mlirF32TypeGet(context_)), tensor_proto.float_data(0));
+        tensorTypeFor(mlirF32TypeGet(context_)),
+        has_raw_data ? data[0] : tensor_proto.float_data(0));
     break;
-  case onnx::TensorProto::DataType::TensorProto_DataType_INT32:
-    splat_attr = mlirDenseElementsAttrFloatSplatGet(
+  }
+  case onnx::TensorProto::DataType::TensorProto_DataType_INT32: {
+    const int32_t *data = nullptr;
+    if (has_raw_data) {
+      data = graph_info_.GetOptionalRawData<int32_t>(tensor_proto, out_size);
+      ORT_ENFORCE(data, "GetOptionalRawData() returned null for tensor proto: ",
+                  tensor_proto.DebugString());
+    }
+    splat_attr = mlirDenseElementsAttrInt32SplatGet(
         tensorTypeFor(mlirIntegerTypeSignedGet(context_, 32)),
-        tensor_proto.int32_data(0));
+        has_raw_data ? data[0] : tensor_proto.int32_data(0));
     break;
-  case onnx::TensorProto::DataType::TensorProto_DataType_INT64:
-    splat_attr = mlirDenseElementsAttrFloatSplatGet(
+  }
+  case onnx::TensorProto::DataType::TensorProto_DataType_INT64: {
+    const int64_t *data = nullptr;
+    if (has_raw_data) {
+      data = graph_info_.GetOptionalRawData<int64_t>(tensor_proto, out_size);
+      ORT_ENFORCE(data, "GetOptionalRawData() returned null for tensor proto: ",
+                  tensor_proto.DebugString());
+    }
+    splat_attr = mlirDenseElementsAttrInt64SplatGet(
         tensorTypeFor(mlirIntegerTypeSignedGet(context_, 64)),
-        tensor_proto.int64_data(0));
+        has_raw_data ? data[0] : tensor_proto.int64_data(0));
     break;
-  case onnx::TensorProto::DataType::TensorProto_DataType_DOUBLE:
-    splat_attr = mlirDenseElementsAttrFloatSplatGet(
-        tensorTypeFor(mlirF64TypeGet(context_)), tensor_proto.double_data(0));
+  }
+  case onnx::TensorProto::DataType::TensorProto_DataType_DOUBLE: {
+    const double *data = nullptr;
+    if (has_raw_data) {
+      data = graph_info_.GetOptionalRawData<double>(tensor_proto, out_size);
+      ORT_ENFORCE(data, "GetOptionalRawData() returned null for tensor proto: ",
+                  tensor_proto.DebugString());
+    }
+    splat_attr = mlirDenseElementsAttrDoubleSplatGet(
+        tensorTypeFor(mlirF64TypeGet(context_)),
+        has_raw_data ? data[0] : tensor_proto.double_data(0));
     break;
-  case onnx::TensorProto::DataType::TensorProto_DataType_UINT64:
-    splat_attr = mlirDenseElementsAttrFloatSplatGet(
+  }
+  case onnx::TensorProto::DataType::TensorProto_DataType_UINT64: {
+    const uint64_t *data = nullptr;
+    if (has_raw_data) {
+      data = graph_info_.GetOptionalRawData<uint64_t>(tensor_proto, out_size);
+      ORT_ENFORCE(data, "GetOptionalRawData() returned null for tensor proto: ",
+                  tensor_proto.DebugString());
+    }
+    splat_attr = mlirDenseElementsAttrUInt64SplatGet(
         tensorTypeFor(mlirIntegerTypeUnsignedGet(context_, 64)),
-        tensor_proto.uint64_data(0));
+        has_raw_data ? data[0] : tensor_proto.uint64_data(0));
     break;
-  case onnx::TensorProto::DataType::TensorProto_DataType_UINT32:
-    // Special case: inline data is stored in uint64.
-    splat_attr = mlirDenseElementsAttrFloatSplatGet(
+  }
+  case onnx::TensorProto::DataType::TensorProto_DataType_UINT32: {
+    const uint32_t *data = nullptr;
+    if (has_raw_data) {
+      data = graph_info_.GetOptionalRawData<uint32_t>(tensor_proto, out_size);
+      ORT_ENFORCE(data, "GetOptionalRawData() returned null for tensor proto: ",
+                  tensor_proto.DebugString());
+    }
+    splat_attr = mlirDenseElementsAttrUInt32SplatGet(
         tensorTypeFor(mlirIntegerTypeUnsignedGet(context_, 32)),
-        tensor_proto.uint64_data(0));
+        has_raw_data ? data[0] : tensor_proto.uint64_data(0));
     break;
   }
+  }
 
   if (mlirAttributeIsNull(splat_attr)) {
     std::string message =
@@ -958,8 +1004,7 @@ Status NodeImporter::ImportConstantOfShapeNode(const onnx::NodeProto &node) {
       toMlirNamedAttribute("value", splat_attr));
   MlirValue result = mlirOperationGetResult(op, 0);
 
-  // Export to the nv_map.
-  auto inserted = nv_map_.insert(std::make_pair(name, result));
+  auto inserted = nv_map_.emplace(node.output(0), result);
   if (!inserted.second) {
     std::string msg = "Multiple nodes produced a value for '";
     msg.append(name);
@@ -973,8 +1018,17 @@ Status NodeImporter::ImportConstantOfShapeNode(const onnx::NodeProto &node) {
 
 Status NodeImporter::GetImmediateShapeTensor(const std::string &name,
                                              std::vector<int64_t> &shape) {
-  const onnx::TensorProto &tp =
-      *graph_info_.graph_viewer().GetConstantInitializer(name, false);
+  const onnx::TensorProto *tensor =
+      graph_info_.graph_viewer().GetConstantInitializer(name, false);
+  if (!tensor) {
+    std::string msg = "Could not find the immediate shape tensor ";
+    msg.append(name);
+    msg.append(" in constant graph initializers. It was possibly produced "
+               "dynamically.");
+    return SetError(msg);
+  }
+  const onnx::TensorProto &tp = *tensor;
+
   shape.clear();
 
   // Since this is being interpreted as a shape, we only support some limited
diff --git a/onnxruntime/core/providers/iree/compiler/torch-mlir-import-onnx/OnnxImporter.h b/onnxruntime/core/providers/iree/compiler/torch-mlir-import-onnx/OnnxImporter.h
index 733aefbda2583..e916da6b8a7d3 100644
--- a/onnxruntime/core/providers/iree/compiler/torch-mlir-import-onnx/OnnxImporter.h
+++ b/onnxruntime/core/providers/iree/compiler/torch-mlir-import-onnx/OnnxImporter.h
@@ -93,6 +93,10 @@ class GraphInfo {
     return nullptr;
   }
 
+  std::unordered_map<std::string, const onnx::ValueInfoProto *> &
+  value_info_map() {
+    return value_info_map_;
+  }
   std::vector<const onnx::ValueInfoProto *> &inputs() { return inputs_; }
   std::unordered_map<std::string, const onnx::ValueInfoProto *> &
   input_map() {
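
Note on a dependency of the hunks above: the splat cases call a GraphInfo::GetOptionalRawData<T>() accessor whose definition is not part of this section. As a rough sketch of the contract the call sites imply (the body below is illustrative, not the patch's actual implementation), such a helper would return a typed view of the tensor's raw_data() bytes and report the element count:

// Illustrative sketch of the helper the call sites above assume; the real
// definition lives elsewhere in this patch.
template <typename T>
const T *GetOptionalRawData(const onnx::TensorProto &tensor_proto,
                            size_t &out_size) {
  // No raw payload: callers fall back to the typed inline fields.
  if (!tensor_proto.has_raw_data())
    return nullptr;
  const std::string &raw = tensor_proto.raw_data();
  // Reject payloads whose byte length is not a whole number of elements.
  if (raw.size() % sizeof(T) != 0)
    return nullptr;
  out_size = raw.size() / sizeof(T);
  // raw_data() holds fixed-width little-endian values, so this direct
  // reinterpretation is only valid as-is on little-endian hosts.
  return reinterpret_cast<const T *>(raw.data());
}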
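
The UINT32 case reads its inline fallback from uint64_data(0) rather than a dedicated uint32 field because onnx.proto defines none; inline (non-raw) values map onto TensorProto fields as follows:

// Inline storage fields in onnx::TensorProto, per onnx.proto:
//   FLOAT  -> float_data()
//   INT32  -> int32_data()
//   INT64  -> int64_data()
//   DOUBLE -> double_data()
//   UINT32 -> uint64_data()  // uint32 values are widened into uint64_data
//   UINT64 -> uint64_data()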