Skip to content

Commit

Permalink
Fix some more builds/tests
Browse files Browse the repository at this point in the history
  • Loading branch information
skottmckay committed Dec 24, 2024
1 parent 275f762 commit 002b6cc
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 72 deletions.
8 changes: 4 additions & 4 deletions onnxruntime/core/framework/tensor_type_and_shape.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ ORT_API_STATUS_IMPL(OrtApis::SetTensorElementType, _Inout_ OrtTensorTypeAndShape
ORT_API_STATUS_IMPL(OrtApis::SetDimensions, OrtTensorTypeAndShapeInfo* info,
_In_ const int64_t* dim_values, size_t dim_count) {
API_IMPL_BEGIN
if (std::any_of(dim_values, dim_values + dim_count, [](int64_t v) { return v < -1; })) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "dim_values must be -1 (symbolic dimension) or larger.");
}

auto num_dims = std::max(dim_count, info->dim_params.size());

// make shape and dim_values consistent
Expand Down Expand Up @@ -84,10 +88,6 @@ ORT_API_STATUS_IMPL(OrtApis::GetDimensionsCount, _In_ const struct OrtTensorType

ORT_API_STATUS_IMPL(OrtApis::GetDimensions, _In_ const struct OrtTensorTypeAndShapeInfo* info,
_Out_ int64_t* dim_values, size_t dim_values_length) {
if (std::any_of(dim_values, dim_values + dim_values_length, [](int64_t v) { return v < -1; })) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "dim_values must be -1 (symbolic dimension) or larger.");
}

info->shape.CopyDims(dim_values, dim_values_length);
return nullptr;
}
Expand Down
61 changes: 30 additions & 31 deletions onnxruntime/core/graph/model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -739,6 +739,36 @@ Status Model::Load(int fd, const PathString& model_path, std::shared_ptr<Model>&
return Status::OK();
}

// static
common::Status Model::LoadFromModelBuilderApiModel(const OrtModel& graph_api_model,
const IOnnxRuntimeOpSchemaRegistryList* local_registries,
const ModelOptions& options,
const logging::Logger& logger,
std::unique_ptr<Model>& model) {
model = std::make_unique<Model>();
model->model_proto_.set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
// The optimizer Initializer class requires a path if external data is used, however in the Graph API usage the
// external data is pointing to pre-allocated memory and does not require a path. Set a dummy value to make it happy.
model->model_path_ = std::filesystem::path("_GRAPH_API_MODEL_");

auto schema_registry = std::make_shared<SchemaRegistryManager>();
if (local_registries != nullptr) {
for (const auto& schema_collection : *local_registries) {
schema_registry->RegisterRegistry(schema_collection);
}
}

ORT_RETURN_IF_ERROR(Graph::LoadFromModelBuilderApiModel(*graph_api_model.graph,
*model,
graph_api_model.domain_to_version,
schema_registry,
options.strict_shape_type_inference,
logger,
model->graph_));

return Status::OK();
}

Status Model::Save(Model& model, int p_fd) {
if (p_fd < 0) {
return Status(ONNXRUNTIME, INVALID_ARGUMENT, "<p_fd> is less than 0.");
Expand Down Expand Up @@ -918,35 +948,4 @@ common::Status Model::LoadFromOrtFormat(const fbs::Model& fbs_model,
#endif
return Status::OK();
}

// static
common::Status Model::LoadFromModelBuilderApiModel(const OrtModel& graph_api_model,
const IOnnxRuntimeOpSchemaRegistryList* local_registries,
const ModelOptions& options,
const logging::Logger& logger,
std::unique_ptr<Model>& model) {
model = std::make_unique<Model>();
model->model_proto_.set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
// The optimizer Initializer class requires a path if external data is used, however in the Graph API usage the
// external data is pointing to pre-allocated memory and does not require a path. Set a dummy value to make it happy.
model->model_path_ = std::filesystem::path("_GRAPH_API_MODEL_");

auto schema_registry = std::make_shared<SchemaRegistryManager>();
if (local_registries != nullptr) {
for (const auto& schema_collection : *local_registries) {
schema_registry->RegisterRegistry(schema_collection);
}
}

ORT_RETURN_IF_ERROR(Graph::LoadFromModelBuilderApiModel(*graph_api_model.graph,
*model,
graph_api_model.domain_to_version,
schema_registry,
options.strict_shape_type_inference,
logger,
model->graph_));

return Status::OK();
}

} // namespace onnxruntime
93 changes: 65 additions & 28 deletions onnxruntime/test/shared_lib/test_model_builder_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "gtest/gtest.h"
#include "gmock/gmock.h"

#include "core/common/narrow.h"
#include "core/graph/constants.h"
#include "core/session/onnxruntime_c_api.h"
#include "core/session/onnxruntime_cxx_api.h"
Expand Down Expand Up @@ -75,6 +76,19 @@ OrtNode* CreateNode(const OrtModelBuilderApi& api,
&node));
return node;
}

// convenience func to convert initalizer lists to gsl::span
OrtNode* CreateNode(const OrtModelBuilderApi& api,
const char* operator_name, const char* node_name,
const std::initializer_list<const char*> input_names,
const std::initializer_list<const char*> output_names,
const std::initializer_list<OrtOpAttr*> attributes = {},
const char* domain_name = onnxruntime::kOnnxDomain) {
std::vector<const char*> inputs(input_names);
std::vector<const char*> outputs(output_names);
std::vector<OrtOpAttr*> attrs(attributes);
return CreateNode(api, operator_name, node_name, inputs, outputs, attrs, domain_name);
}
} // namespace

struct TestAllocator : public OrtAllocator {
Expand Down Expand Up @@ -192,18 +206,28 @@ TEST(ModelBuilderAPITest, Basic_CApi) {
Ort::ThrowOnError(graph_api.AddNodeToGraph(graph, node));
node = nullptr; // graph now owns node

// Y input
std::vector<int64_t> y_dims = {2, 3};
deleter.weights.emplace_back(
std::make_unique<std::vector<float>>(std::initializer_list<float>{1.0f, 2.0f, 3.0f,
4.0f, 5.0f, 6.0f}));
auto& y_values = *deleter.weights.back();

if (use_constant_node) {
// Test that a Constant node is converted to an intializer

// create an attribute for the Y input
// create Constant node that produces "Y" output with the value_floats attribute
ASSERT_FALSE(true) << "Not implemented";
OrtOpAttr* value_attr = nullptr;
int bytes = onnxruntime::narrow<int>(y_values.size() * sizeof(y_values[0]));
Ort::ThrowOnError(api.CreateOpAttr("value", y_values.data(), bytes, ORT_OP_ATTR_FLOAT, &value_attr));

node = CreateNode(graph_api, "Constant", "Y_constant", {}, {"Y"}, {value_attr});
Ort::ThrowOnError(graph_api.AddNodeToGraph(graph, node));
node = nullptr; // graph now owns node
} else {
// create an initializer for the Y input. add to `weights` so the memory remains valid
OrtValue* y_tensor = nullptr;
std::vector<int64_t> y_dims = {2, 3};
deleter.weights.emplace_back(
std::make_unique<std::vector<float>>(std::initializer_list<float>{1.0f, 2.0f, 3.0f,
4.0f, 5.0f, 6.0f}));
auto& y_values = *deleter.weights.back();
auto info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

// if you use this API the initializer data MUST remain valid for the lifetime of the InferenceSession
Expand All @@ -226,29 +250,31 @@ TEST(ModelBuilderAPITest, Basic_CApi) {
}
};

OrtModel* model = nullptr;
build_model(false, model);
auto run_test = [&](bool use_constant_node) -> void {
OrtModel* model = nullptr;
build_model(use_constant_node, model);

ASSERT_NE(model, nullptr) << "build_model should have created a model";
ASSERT_NE(model, nullptr) << "build_model should have created a model";

std::vector<Input<float>> inputs(1);
auto& input = inputs[0];
input.name = "X";
input.dims = {3, 2};
input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
std::vector<Input<float>> inputs(1);
auto& input = inputs[0];
input.name = "X";
input.dims = {3, 2};
input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};

std::vector<int64_t> expected_dims = {3, 3};
ModelBuilderAPI::Model cxx_model(model);
auto session = CreateSession(*ort_env, cxx_model);
std::vector<int64_t> expected_dims = {3, 3};
ModelBuilderAPI::Model cxx_model(model);
auto session = CreateSession(*ort_env, cxx_model);

TestInference<float>(session, inputs, "Z", expected_dims,
{18.0f, 24.0f, 30.0f,
38.0f, 52.0f, 66.0f,
58.0f, 80.0f, 102.0f});
TestInference<float>(session, inputs, "Z", expected_dims,
{18.0f, 24.0f, 30.0f,
38.0f, 52.0f, 66.0f,
58.0f, 80.0f, 102.0f});

api.ReleaseSession(session.release());
api.ReleaseSession(session.release());

ASSERT_EQ(deleter.weights.size(), 0) << "All weights should have been freed";
ASSERT_EQ(deleter.weights.size(), 0) << "All weights should have been freed";
};
}

TEST(ModelBuilderAPITest, Basic_CxxApi) {
Expand Down Expand Up @@ -420,12 +446,23 @@ TEST(ModelBuilderAPITest, BasicModelEdit_CxxApi) {
}
}

TEST(ModelBuilderAPITest, InvalidDimension) {
try {
std::vector<int64_t> input_dims = {-2, 2};
TensorTypeAndShapeInfo tensor_type_info(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,
input_dims);
// invalid dim of -2 should cause exception
TypeInfo::CreateTensorInfo(tensor_type_info.GetConst());
FAIL();
} catch (const Ort::Exception& e) {
ASSERT_STREQ(e.what(), "dim_values must be -1 (symbolic dimension) or larger.");
}
}

/*
Tests required
- Constant node is converted to initializer
- Attempt to create invalid model
- Edit and change outputs
- Invalid edit
- Edit where we change a subset of inputs or outputs.
- Create invalid model. Graph::Resolve should fail.
- Invalid edit. Graph::Resolve should fail.
- All the non-tensor Create*TypeInfo functions need to be validated
*/
18 changes: 9 additions & 9 deletions winml/adapter/winml_adapter_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -593,13 +593,13 @@ ORT_API_STATUS_IMPL(
input.set_name(input_name);

if (info->type == ONNXType::ONNX_TYPE_TENSOR) {
auto num_dims = info->data->shape.NumDimensions();
auto num_dims = info->tensor_type_info->shape.NumDimensions();
CreateTypeProto_Tensor(
input.mutable_type()->mutable_tensor_type(),
input_name,
(num_dims == 0) ? nullptr : &info->data->shape[0],
(num_dims == 0) ? nullptr : &info->tensor_type_info->shape[0],
num_dims,
ONNXTensorElementDataTypeToTensorProto_DataType(info->data->type)
ONNXTensorElementDataTypeToTensorProto_DataType(info->tensor_type_info->type)
);
}
return nullptr;
Expand All @@ -619,12 +619,12 @@ ORT_API_STATUS_IMPL(
ONNX_NAMESPACE::TensorProto& input = *graph.add_initializer();
input.set_name(input_name);

auto num_dims = info->data->shape.NumDimensions();
auto num_dims = info->tensor_type_info->shape.NumDimensions();
for (size_t i = 0; i < num_dims; i++) {
input.add_dims(info->data->shape[i]);
input.add_dims(info->tensor_type_info->shape[i]);
}

input.set_data_type(ONNXTensorElementDataTypeToTensorProto_DataType(info->data->type));
input.set_data_type(ONNXTensorElementDataTypeToTensorProto_DataType(info->tensor_type_info->type));
auto tensor = value->GetMutable<onnxruntime::Tensor>();
input.set_raw_data(tensor->DataRaw(), tensor->SizeInBytes());

Expand All @@ -645,9 +645,9 @@ ORT_API_STATUS_IMPL(
CreateTypeProto_Tensor(
output.mutable_type()->mutable_tensor_type(),
output_name,
&info->data->shape[0],
info->data->shape.NumDimensions(),
ONNXTensorElementDataTypeToTensorProto_DataType(info->data->type)
&info->tensor_type_info->shape[0],
info->tensor_type_info->shape.NumDimensions(),
ONNXTensorElementDataTypeToTensorProto_DataType(info->tensor_type_info->type)
);
}
return nullptr;
Expand Down

0 comments on commit 002b6cc

Please sign in to comment.