Skip to content

Commit

Permalink
Revert "[IREE EP][Importer] Fix IR import for onnx.ConstantOfShape"
Browse files Browse the repository at this point in the history
This reverts commit e6877af.
  • Loading branch information
Shukla-Gaurav authored Sep 23, 2024
1 parent e6877af commit cbfca6b
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 85 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -638,7 +638,8 @@ Status NodeImporter::ImportAll() {

for (const auto &node : nodes) {
if (torch_mlir_onnx::failed(ImportNode(node))) {
return failure();
return SetError("Failed to import node '" + node.name() +
"': " + "(node:\n" + node.DebugString() + "\n)");
}
}

Expand Down Expand Up @@ -727,8 +728,7 @@ Status NodeImporter::ImportGeneralNode(const onnx::NodeProto &node) {
if (found_it == nv_map_.end()) {
std::string msg = "Non topologically produced ONNX node input '";
msg.append(input_name);
msg.append("': ");
msg.append(node.DebugString());
msg.append("'");
return SetError(std::move(msg));
}
input_values.push_back(found_it->second);
Expand All @@ -739,9 +739,8 @@ Status NodeImporter::ImportGeneralNode(const onnx::NodeProto &node) {
for (auto &output_name : node.output()) {
const onnx::TypeProto *type_proto =
graph_info_.graph_viewer().GetNodeArg(output_name)->TypeAsProto();
if (!type_proto) {
return SetError("Failed to obtain TypeProto for tensor");
}
if (!type_proto)
return failure();

MlirType t = cc_.ConvertTypeProto(*type_proto);
if (mlirTypeIsNull(t))
Expand Down Expand Up @@ -907,83 +906,38 @@ Status NodeImporter::ImportConstantOfShapeNode(const onnx::NodeProto &node) {
return mlirRankedTensorTypeGet(shape.size(), shape.data(), element_type,
/*encoding*/ {nullptr});
};
const bool has_raw_data = tensor_proto.has_raw_data();
MlirAttribute splat_attr = {nullptr};
size_t out_size;
switch (tensor_proto.data_type()) {
case onnx::TensorProto::DataType::TensorProto_DataType_FLOAT: {
const float *data = nullptr;
if (has_raw_data) {
data = graph_info_.GetOptionalRawData<float>(tensor_proto, out_size);
ORT_ENFORCE(data, "GetOptionalRawData() returned null for tensor proto: ",
tensor_proto.DebugString());
}
case onnx::TensorProto::DataType::TensorProto_DataType_FLOAT:
splat_attr = mlirDenseElementsAttrFloatSplatGet(
tensorTypeFor(mlirF32TypeGet(context_)),
has_raw_data ? data[0] : tensor_proto.float_data(0));
tensorTypeFor(mlirF32TypeGet(context_)), tensor_proto.float_data(0));
break;
}
case onnx::TensorProto::DataType::TensorProto_DataType_INT32: {
const int32_t *data = nullptr;
if (has_raw_data) {
data = graph_info_.GetOptionalRawData<int32_t>(tensor_proto, out_size);
ORT_ENFORCE(data, "GetOptionalRawData() returned null for tensor proto: ",
tensor_proto.DebugString());
}
splat_attr = mlirDenseElementsAttrInt32SplatGet(
case onnx::TensorProto::DataType::TensorProto_DataType_INT32:
splat_attr = mlirDenseElementsAttrFloatSplatGet(
tensorTypeFor(mlirIntegerTypeSignedGet(context_, 32)),
has_raw_data ? data[0] : tensor_proto.int32_data(0));
tensor_proto.int32_data(0));
break;
}
case onnx::TensorProto::DataType::TensorProto_DataType_INT64: {
const int64_t *data = nullptr;
if (has_raw_data) {
data = graph_info_.GetOptionalRawData<int64_t>(tensor_proto, out_size);
ORT_ENFORCE(data, "GetOptionalRawData() returned null for tensor proto: ",
tensor_proto.DebugString());
}
splat_attr = mlirDenseElementsAttrInt64SplatGet(
case onnx::TensorProto::DataType::TensorProto_DataType_INT64:
splat_attr = mlirDenseElementsAttrFloatSplatGet(
tensorTypeFor(mlirIntegerTypeSignedGet(context_, 64)),
has_raw_data ? data[0] : tensor_proto.int64_data(0));
tensor_proto.int64_data(0));
break;
}
case onnx::TensorProto::DataType::TensorProto_DataType_DOUBLE: {
const double *data = nullptr;
if (has_raw_data) {
data = graph_info_.GetOptionalRawData<double>(tensor_proto, out_size);
ORT_ENFORCE(data, "GetOptionalRawData() returned null for tensor proto: ",
tensor_proto.DebugString());
}
splat_attr = mlirDenseElementsAttrDoubleSplatGet(
tensorTypeFor(mlirF64TypeGet(context_)),
has_raw_data ? data[0] : tensor_proto.double_data(0));
case onnx::TensorProto::DataType::TensorProto_DataType_DOUBLE:
splat_attr = mlirDenseElementsAttrFloatSplatGet(
tensorTypeFor(mlirF64TypeGet(context_)), tensor_proto.double_data(0));
break;
}
case onnx::TensorProto::DataType::TensorProto_DataType_UINT64: {
const uint64_t *data = nullptr;
if (has_raw_data) {
data = graph_info_.GetOptionalRawData<uint64_t>(tensor_proto, out_size);
ORT_ENFORCE(data, "GetOptionalRawData() returned null for tensor proto: ",
tensor_proto.DebugString());
}
splat_attr = mlirDenseElementsAttrUInt64SplatGet(
case onnx::TensorProto::DataType::TensorProto_DataType_UINT64:
splat_attr = mlirDenseElementsAttrFloatSplatGet(
tensorTypeFor(mlirIntegerTypeUnsignedGet(context_, 64)),
has_raw_data ? data[0] : tensor_proto.uint64_data(0));
tensor_proto.uint64_data(0));
break;
}
case onnx::TensorProto::DataType::TensorProto_DataType_UINT32: {
const uint32_t *data = nullptr;
if (has_raw_data) {
data = graph_info_.GetOptionalRawData<uint32_t>(tensor_proto, out_size);
ORT_ENFORCE(data, "GetOptionalRawData() returned null for tensor proto: ",
tensor_proto.DebugString());
}
splat_attr = mlirDenseElementsAttrUInt32SplatGet(
case onnx::TensorProto::DataType::TensorProto_DataType_UINT32:
// Special case: inline data is stored in uint64.
splat_attr = mlirDenseElementsAttrFloatSplatGet(
tensorTypeFor(mlirIntegerTypeUnsignedGet(context_, 32)),
has_raw_data ? data[0] : tensor_proto.float_data(0));
tensor_proto.uint64_data(0));
break;
}
}

if (mlirAttributeIsNull(splat_attr)) {
std::string message =
Expand All @@ -1004,7 +958,8 @@ Status NodeImporter::ImportConstantOfShapeNode(const onnx::NodeProto &node) {
toMlirNamedAttribute("value", splat_attr));
MlirValue result = mlirOperationGetResult(op, 0);

auto inserted = nv_map_.emplace(node.output(0), result);
// Export to the nv_map.
auto inserted = nv_map_.insert(std::make_pair(name, result));
if (!inserted.second) {
std::string msg = "Multiple nodes produced a value for '";
msg.append(name);
Expand All @@ -1018,17 +973,8 @@ Status NodeImporter::ImportConstantOfShapeNode(const onnx::NodeProto &node) {

Status NodeImporter::GetImmediateShapeTensor(const std::string &name,
std::vector<int64_t> &shape) {
const onnx::TensorProto *tensor =
graph_info_.graph_viewer().GetConstantInitializer(name, false);
if (!tensor) {
std::string msg = "Could not find the immediate shape tensor ";
msg.append(name);
msg.append(" in constant graph initializers. It was possibly produced "
"dynamically.");
return SetError(msg);
}
const onnx::TensorProto &tp = *tensor;

const onnx::TensorProto &tp =
*graph_info_.graph_viewer().GetConstantInitializer(name, false);
shape.clear();

// Since this is being interpreted as a shape, we only support some limited
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,6 @@ class GraphInfo {
return nullptr;
}

std::unordered_map<std::string_view, const onnx::ValueInfoProto &> &
value_info_map() {
return value_info_map_;
}
std::vector<const onnx::ValueInfoProto *> &inputs() { return inputs_; }
std::unordered_map<std::string_view, const onnx::ValueInfoProto &> &
input_map() {
Expand Down

0 comments on commit cbfca6b

Please sign in to comment.