Skip to content

Commit

Permalink
Temporary optimizer support for ort format models in non-minimal build
Browse files Browse the repository at this point in the history
  • Loading branch information
baijumeswani authored Jun 28, 2023
1 parent 960e320 commit efeb667
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 15 deletions.
2 changes: 1 addition & 1 deletion include/onnxruntime/core/session/onnxruntime_cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1385,7 +1385,7 @@ struct Value : detail::ValueImpl<OrtValue> {
* \param value - the value to be wrapped.
*/
template <typename T>
static Value CreateOpaque(const char* domain, const char* type_name, const T&); ///< Wraps OrtApi::CreateOpaqueValue
static Value CreateOpaque(const char* domain, const char* type_name, const T& value); ///< Wraps OrtApi::CreateOpaqueValue

#if !defined(DISABLE_SPARSE_TENSORS)
/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -638,6 +638,7 @@ struct OrtTrainingApi {
* As a result, it is required that the checkpoint state outlive the lifetime of the training session.
*
* \param[in] checkpoint_buffer Buffer containing the checkpoint bytes.
* \param[in] num_bytes Number of bytes in the checkpoint buffer.
* \param[out] checkpoint_state Checkpoint state that contains the states of the training session.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
Expand Down
36 changes: 22 additions & 14 deletions orttraining/orttraining/training_api/optimizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License.

#include "orttraining/training_api/optimizer.h"
#include "core/flatbuffers/flatbuffers_utils.h"
#include "core/framework/execution_provider.h"
#include "core/framework/TensorSeq.h"
#include "core/providers/cpu/cpu_execution_provider.h"
Expand Down Expand Up @@ -60,29 +61,36 @@ Status GraphInputsAreExpected(gsl::span<std::string> actual_graph_inputs,
} // namespace

std::unique_ptr<OptimizerAlgorithmBase> OptimizerAlorithmFactory::CreateInstance(
const std::string& optim_path_or_bytes, int32_t& group_count) {
const std::string& optim_path, int32_t& group_count) {
std::map<std::pair<std::string, std::string>, int32_t> opt_type_to_freq_map;
#if !defined(ORT_MINIMAL_BUILD)
std::shared_ptr<Model> model;
ORT_ENFORCE(Model::Load(ToWideString(optim_path_or_bytes), model, nullptr,
logging::LoggingManager::DefaultLogger())
.IsOK());
Graph& graph = model->MainGraph();
for (auto& node : graph.Nodes()) {
if (node.Domain() == kMSDomain && (node.OpType() == "AdamWOptimizer" || node.OpType() == "SGDOptimizerV2")) {
auto domain_type_pair = std::make_pair(node.Domain(), node.OpType());
if (opt_type_to_freq_map.find(domain_type_pair) == opt_type_to_freq_map.end()) {
opt_type_to_freq_map[domain_type_pair] = 0;
}
if (const auto optim_path_str = ToPathString(optim_path);
fbs::utils::IsOrtFormatModel(optim_path_str)) {
// TODO (baijumeswani): Figure out the best way to extract the optimizer type
// from an ort format model.
opt_type_to_freq_map[std::make_pair(kMSDomain, "AdamWOptimizer")] = 1;
} else {
std::shared_ptr<Model> model;
ORT_ENFORCE(Model::Load(optim_path_str, model, nullptr,
logging::LoggingManager::DefaultLogger())
.IsOK());
Graph& graph = model->MainGraph();
for (auto& node : graph.Nodes()) {
if (node.Domain() == kMSDomain && (node.OpType() == "AdamWOptimizer" || node.OpType() == "SGDOptimizerV2")) {
auto domain_type_pair = std::make_pair(node.Domain(), node.OpType());
if (opt_type_to_freq_map.find(domain_type_pair) == opt_type_to_freq_map.end()) {
opt_type_to_freq_map[domain_type_pair] = 0;
}

opt_type_to_freq_map[domain_type_pair] += 1;
opt_type_to_freq_map[domain_type_pair] += 1;
}
}
}
#else
// TODO (baijumeswani): Figure out the best way to extract the optimizer type
// from the model (either onnx model or ort format model) or from the checkpoint.
// For now, assume that the optimizer type is AdamWOptimizer in a minimal build.
ORT_UNUSED_PARAMETER(optim_path_or_bytes);
ORT_UNUSED_PARAMETER(optim_path);

opt_type_to_freq_map[std::make_pair(kMSDomain, "AdamWOptimizer")] = 1;
#endif
Expand Down

0 comments on commit efeb667

Please sign in to comment.