[IREE EP][Importer] Fix IR import for onnx.ConstantOfShape #11

Merged
@@ -638,8 +638,7 @@ Status NodeImporter::ImportAll() {
 
   for (const auto &node : nodes) {
     if (torch_mlir_onnx::failed(ImportNode(node))) {
-      return SetError("Failed to import node '" + node.name() +
-                      "': " + "(node:\n" + node.DebugString() + "\n)");
+      return failure();
    }
  }
 
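With ImportNode already recording a message through SetError, returning the bare failure() avoids wrapping the same diagnostic twice. A minimal sketch of the Status/failure() idiom assumed here (the importer's actual Status type may differ):

// Hedged sketch: a LogicalResult-style status type, assuming the importer
// records the human-readable message separately via SetError().
class Status {
public:
  static Status success() { return Status(true); }
  static Status failure() { return Status(false); }
  bool is_success() const { return is_success_; }

private:
  explicit Status(bool is_success) : is_success_(is_success) {}
  bool is_success_;
};

inline Status success() { return Status::success(); }
inline Status failure() { return Status::failure(); }
inline bool succeeded(Status status) { return status.is_success(); }
inline bool failed(Status status) { return !status.is_success(); }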
@@ -728,7 +727,8 @@ Status NodeImporter::ImportGeneralNode(const onnx::NodeProto &node) {
     if (found_it == nv_map_.end()) {
       std::string msg = "Non topologically produced ONNX node input '";
       msg.append(input_name);
-      msg.append("'");
+      msg.append("': ");
+      msg.append(node.DebugString());
       return SetError(std::move(msg));
     }
     input_values.push_back(found_it->second);
@@ -739,8 +739,9 @@ Status NodeImporter::ImportGeneralNode(const onnx::NodeProto &node) {
   for (auto &output_name : node.output()) {
     const onnx::TypeProto *type_proto =
         graph_info_.graph_viewer().GetNodeArg(output_name)->TypeAsProto();
-    if (!type_proto)
-      return failure();
+    if (!type_proto) {
+      return SetError("Failed to obtain TypeProto for tensor");
+    }
 
     MlirType t = cc_.ConvertTypeProto(*type_proto);
     if (mlirTypeIsNull(t))
@@ -906,38 +907,77 @@ Status NodeImporter::ImportConstantOfShapeNode(const onnx::NodeProto &node) {
     return mlirRankedTensorTypeGet(shape.size(), shape.data(), element_type,
                                    /*encoding*/ {nullptr});
   };
+  const bool has_raw_data = tensor_proto.has_raw_data();
   MlirAttribute splat_attr = {nullptr};
+  size_t out_size;
   switch (tensor_proto.data_type()) {
-  case onnx::TensorProto::DataType::TensorProto_DataType_FLOAT:
+  case onnx::TensorProto::DataType::TensorProto_DataType_FLOAT: {
+    const float *data = {0};
+    if (has_raw_data) {
+      data = graph_info_.GetOptionalRawData<float>(tensor_proto, out_size);
+      ORT_ENFORCE(data);
+    }
     splat_attr = mlirDenseElementsAttrFloatSplatGet(
-        tensorTypeFor(mlirF32TypeGet(context_)), tensor_proto.float_data(0));
+        tensorTypeFor(mlirF32TypeGet(context_)),
+        has_raw_data ? data[0] : tensor_proto.float_data(0));
     break;
-  case onnx::TensorProto::DataType::TensorProto_DataType_INT32:
-    splat_attr = mlirDenseElementsAttrFloatSplatGet(
+  }
+  case onnx::TensorProto::DataType::TensorProto_DataType_INT32: {
+    const int32_t *data = {0};
+    if (has_raw_data) {
+      data = graph_info_.GetOptionalRawData<int32_t>(tensor_proto, out_size);
+      ORT_ENFORCE(data);
+    }
+    splat_attr = mlirDenseElementsAttrInt32SplatGet(
         tensorTypeFor(mlirIntegerTypeSignedGet(context_, 32)),
-        tensor_proto.int32_data(0));
+        has_raw_data ? data[0] : tensor_proto.int32_data(0));
     break;
-  case onnx::TensorProto::DataType::TensorProto_DataType_INT64:
-    splat_attr = mlirDenseElementsAttrFloatSplatGet(
+  }
+  case onnx::TensorProto::DataType::TensorProto_DataType_INT64: {
+    const int64_t *data = {0};
+    if (has_raw_data) {
+      data = graph_info_.GetOptionalRawData<int64_t>(tensor_proto, out_size);
+      ORT_ENFORCE(data);
+    }
+    splat_attr = mlirDenseElementsAttrInt64SplatGet(
         tensorTypeFor(mlirIntegerTypeSignedGet(context_, 64)),
-        tensor_proto.int64_data(0));
+        has_raw_data ? data[0] : tensor_proto.int64_data(0));
     break;
-  case onnx::TensorProto::DataType::TensorProto_DataType_DOUBLE:
-    splat_attr = mlirDenseElementsAttrFloatSplatGet(
-        tensorTypeFor(mlirF64TypeGet(context_)), tensor_proto.double_data(0));
+  }
+  case onnx::TensorProto::DataType::TensorProto_DataType_DOUBLE: {
+    const double *data = {0};
+    if (has_raw_data) {
+      data = graph_info_.GetOptionalRawData<double>(tensor_proto, out_size);
+      ORT_ENFORCE(data);
+    }
+    splat_attr = mlirDenseElementsAttrDoubleSplatGet(
+        tensorTypeFor(mlirF64TypeGet(context_)),
+        has_raw_data ? data[0] : tensor_proto.double_data(0));
     break;
-  case onnx::TensorProto::DataType::TensorProto_DataType_UINT64:
-    splat_attr = mlirDenseElementsAttrFloatSplatGet(
+  }
+  case onnx::TensorProto::DataType::TensorProto_DataType_UINT64: {
+    const uint64_t *data = {0};
+    if (has_raw_data) {
+      data = graph_info_.GetOptionalRawData<uint64_t>(tensor_proto, out_size);
+      ORT_ENFORCE(data);
+    }
+    splat_attr = mlirDenseElementsAttrUInt64SplatGet(
         tensorTypeFor(mlirIntegerTypeUnsignedGet(context_, 64)),
-        tensor_proto.uint64_data(0));
+        has_raw_data ? data[0] : tensor_proto.uint64_data(0));
     break;
-  case onnx::TensorProto::DataType::TensorProto_DataType_UINT32:
-    // Special case: inline data is stored in uint64.
-    splat_attr = mlirDenseElementsAttrFloatSplatGet(
+  }
+  case onnx::TensorProto::DataType::TensorProto_DataType_UINT32: {
+    const uint32_t *data = {0};
+    if (has_raw_data) {
+      data = graph_info_.GetOptionalRawData<uint32_t>(tensor_proto, out_size);
+      ORT_ENFORCE(data);
+    }
+    splat_attr = mlirDenseElementsAttrUInt32SplatGet(
         tensorTypeFor(mlirIntegerTypeUnsignedGet(context_, 32)),
-        tensor_proto.uint64_data(0));
+        has_raw_data ? data[0] : tensor_proto.uint64_data(0));
     break;
+  }
   }
 
   if (mlirAttributeIsNull(splat_attr)) {
     std::string message =
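ONNX TensorProto initializers carry their payload either as packed little-endian bytes in raw_data or in a type-specific repeated field (float_data, int32_data, int64_data, double_data, uint64_data; uint32 values share uint64_data), which is why every case above now branches on has_raw_data. A minimal sketch of what a GetOptionalRawData-style helper plausibly does, assuming onnx/onnx_pb.h for the TensorProto definition (the actual GraphInfo member may differ):

// Hedged sketch, not the actual helper: reinterpret a TensorProto's
// raw_data bytes as T, returning nullptr when the tensor has no raw data
// or the byte count is not a multiple of sizeof(T). Assumes a
// little-endian host, matching the ONNX raw_data convention.
template <typename T>
const T *GetOptionalRawDataSketch(const onnx::TensorProto &tp,
                                  size_t &out_size) {
  if (!tp.has_raw_data() || tp.raw_data().size() % sizeof(T) != 0)
    return nullptr;
  out_size = tp.raw_data().size() / sizeof(T);
  return reinterpret_cast<const T *>(tp.raw_data().data());
}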
@@ -958,8 +998,7 @@ Status NodeImporter::ImportConstantOfShapeNode(const onnx::NodeProto &node) {
                               toMlirNamedAttribute("value", splat_attr));
   MlirValue result = mlirOperationGetResult(op, 0);
 
-  // Export to the nv_map.
-  auto inserted = nv_map_.insert(std::make_pair(name, result));
+  auto inserted = nv_map_.emplace(node.output(0), result);
   if (!inserted.second) {
     std::string msg = "Multiple nodes produced a value for '";
     msg.append(name);
@@ -973,8 +1012,17 @@ Status NodeImporter::ImportConstantOfShapeNode(const onnx::NodeProto &node) {
 
 Status NodeImporter::GetImmediateShapeTensor(const std::string &name,
                                              std::vector<int64_t> &shape) {
-  const onnx::TensorProto &tp =
-      *graph_info_.graph_viewer().GetConstantInitializer(name, false);
+  const onnx::TensorProto *tensor =
+      graph_info_.graph_viewer().GetConstantInitializer(name, false);
+  if (!tensor) {
+    std::string msg = "Could not find the immediate shape tensor ";
+    msg.append(name);
+    msg.append(" in constant graph initializers. It was possibly produced "
+               "dynamically.");
+    return SetError(msg);
+  }
+  const onnx::TensorProto &tp = *tensor;
 
   shape.clear();
 
   // Since this is being interpreted as a shape, we only support some limited
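The hunk is cut off below that comment, but the new null check matters: ConstantOfShape takes its shape from a constant initializer, and a dynamically produced shape input makes GetConstantInitializer return nullptr, which the old code dereferenced unconditionally. A hedged sketch of the limited extraction the truncated body plausibly performs, assuming an int64 shape tensor (the real importer may support more cases):

// Hedged sketch of the shape extraction, reading the initializer either
// from packed raw_data bytes or from the repeated int64_data field.
if (tp.data_type() ==
    onnx::TensorProto::DataType::TensorProto_DataType_INT64) {
  if (tp.has_raw_data()) {
    const auto *d = reinterpret_cast<const int64_t *>(tp.raw_data().data());
    shape.assign(d, d + tp.raw_data().size() / sizeof(int64_t));
  } else {
    shape.assign(tp.int64_data().begin(), tp.int64_data().end());
  }
}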
@@ -93,6 +93,10 @@ class GraphInfo {
     return nullptr;
   }
 
+  std::unordered_map<std::string_view, const onnx::ValueInfoProto &> &
+  value_info_map() {
+    return value_info_map_;
+  }
   std::vector<const onnx::ValueInfoProto *> &inputs() { return inputs_; }
   std::unordered_map<std::string_view, const onnx::ValueInfoProto &> &
   input_map() {
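The new value_info_map() accessor exposes the graph's value_info entries alongside the existing inputs() and input_map() accessors, presumably so callers can recover type information for intermediate tensors by name. A hypothetical usage sketch (the helper and its name are illustrative, not part of the PR):

// Hypothetical usage: fall back to the graph's value_info when a
// NodeArg carries no TypeProto.
const onnx::TypeProto *LookupTypeSketch(GraphInfo &graph_info,
                                        std::string_view tensor_name) {
  auto &vi_map = graph_info.value_info_map();
  auto it = vi_map.find(tensor_name);
  if (it == vi_map.end())
    return nullptr;
  return &it->second.type();
}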