Skip to content

Commit

Permalink
Model Builder API
Browse files Browse the repository at this point in the history
- Create new model
- Augment existing model
  • Loading branch information
skottmckay committed Dec 23, 2024
1 parent c6ba7ed commit dece8b8
Show file tree
Hide file tree
Showing 33 changed files with 2,816 additions and 385 deletions.
6 changes: 6 additions & 0 deletions cmake/onnxruntime_unittests.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -511,6 +511,7 @@ set (onnxruntime_shared_lib_test_SRC

if (NOT onnxruntime_MINIMAL_BUILD)
list(APPEND onnxruntime_shared_lib_test_SRC ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_inference.cc)
list(APPEND onnxruntime_shared_lib_test_SRC ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_model_builder_api.cc)
endif()

if(onnxruntime_RUN_ONNX_TESTS)
Expand Down Expand Up @@ -1350,14 +1351,19 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
LIBS ${onnxruntime_shared_lib_test_LIBS}
DEPENDS ${all_dependencies}
)

target_include_directories(onnxruntime_shared_lib_test PRIVATE ${ONNXRUNTIME_ROOT})

if (onnxruntime_USE_CUDA)
target_include_directories(onnxruntime_shared_lib_test PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
target_sources(onnxruntime_shared_lib_test PRIVATE ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/cuda_ops.cu)
endif()

if (onnxruntime_USE_ROCM)
target_include_directories(onnxruntime_shared_lib_test PRIVATE ${onnxruntime_ROCM_HOME}/include)
target_compile_definitions(onnxruntime_shared_lib_test PRIVATE __HIP_PLATFORM_AMD__)
endif()

if (CMAKE_SYSTEM_NAME STREQUAL "Android")
target_sources(onnxruntime_shared_lib_test PRIVATE
"${ONNXRUNTIME_ROOT}/core/platform/android/cxa_demangle.cc"
Expand Down
30 changes: 28 additions & 2 deletions include/onnxruntime/core/graph/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@
#include "core/graph/node_arg.h"
#include "core/graph/ort_format_load_options.h"

// Type from Graph API in ORT C API so can't be in a namespace
struct OrtGraph;

namespace onnxruntime {
class Graph;
struct IndexedSubGraph;
Expand Down Expand Up @@ -763,6 +766,10 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi
*/
bool GetInitializedTensor(const std::string& tensor_name, const ONNX_NAMESPACE::TensorProto*& value) const;

/** Populate `value` if an externally allocated OrtValue exists for an initializer with the given name.
*/
bool GetOrtValueInitializer(const std::string& name, OrtValue& value) const;

/** Gets all the initializer tensors in this Graph. */
const InitializedTensorSet& GetAllInitializedTensors() const noexcept { return name_to_initial_tensor_; }

Expand Down Expand Up @@ -1430,6 +1437,16 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi
const OrtFormatLoadOptions& load_options,
const logging::Logger& logger, std::unique_ptr<Graph>& graph);

static Status LoadFromModelBuilderApiModel(const OrtGraph& api_graph,
const Model& owning_model,
const std::unordered_map<std::string, int>& domain_to_version,
IOnnxRuntimeOpSchemaCollectionPtr schema_registry,
bool strict_shape_type_inference,
const logging::Logger& logger,
std::unique_ptr<Graph>& graph);

Status UpdateUsingModelBuilderApiModel(const OrtModel& api_model);

#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
const RuntimeOptimizationRecordContainer& RuntimeOptimizations() const {
return runtime_optimizations_;
Expand Down Expand Up @@ -1630,7 +1647,8 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi
// Implementation for initializer replacement
Status ReplaceInitializedTensorImpl(ONNX_NAMESPACE::TensorProto new_initializer, bool is_external);

std::vector<NodeArg*> CreateNodeArgs(const google::protobuf::RepeatedPtrField<std::string>& names,
template <typename StringRange> // range-initializer returning std::string
std::vector<NodeArg*> CreateNodeArgs(const StringRange& names,
const ArgNameToTypeMap& name_to_type_map);

void ToGraphProtoInternal(ONNX_NAMESPACE::GraphProto& graph_proto) const;
Expand Down Expand Up @@ -1694,6 +1712,8 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi
return nodes_[node_index].get();
}

Status LoadFromModelBuilderApiModel(const OrtGraph& api_graph, bool updating_existing_graph = false);

const Model& owning_model_;

// GraphProto to store name, version, initializer.
Expand All @@ -1708,6 +1728,11 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi

InitializedTensorSet name_to_initial_tensor_;

// Initializers that are external to the Graph. e.g. created using Model Builder API from existing memory.
// As we need to convert to TensorProto for the optimizers to work and keep the deleter information we store them
// in the Graph instance and retrieve during session state finalization.
std::unordered_map<std::string, OrtValue> ortvalue_initializers_;

std::unordered_set<std::reference_wrapper<const std::string>,
std::hash<std::string>, std::equal_to<std::string>>
sparse_tensor_names_;
Expand Down Expand Up @@ -1744,6 +1769,7 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi
// in some case, a fused sub-graph will happens multiple times in one model, we use a map
// to store reusable-schema in lookup.
InlinedHashMap<std::string, std::reference_wrapper<ONNX_NAMESPACE::OpSchema>> reusable_fused_schema_map_;

#endif // !defined(ORT_MINIMAL_BUILD)

// Graph nodes.
Expand Down Expand Up @@ -1806,7 +1832,7 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi
std::unordered_map<std::string, std::unordered_set<NodeIndex>> node_arg_to_consumer_nodes_;
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)

const std::unordered_map<std::string, int> domain_to_version_;
std::unordered_map<std::string, int> domain_to_version_;

// Model IR version.
Version ir_version_{ONNX_NAMESPACE::Version::IR_VERSION};
Expand Down
6 changes: 6 additions & 0 deletions include/onnxruntime/core/graph/graph_viewer.h
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,12 @@ class GraphViewer {
IOnnxRuntimeOpSchemaCollectionPtr GetSchemaRegistry() const { return graph_->GetSchemaRegistry(); }
#endif

/** Populate `value` if an externally allocated OrtValue exists for an initializer with the given name.
*/
bool GetOrtValueInitializer(const std::string& name, OrtValue& value) const {
return graph_->GetOrtValueInitializer(name, value);
}

private:
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(GraphViewer);
GraphViewer(const Graph& graph, const IndexedSubGraph* filter_info);
Expand Down
Loading

0 comments on commit dece8b8

Please sign in to comment.