[CoreML MLProgram] Support Float16 (1/N) #22068

Merged · 39 commits · Sep 30, 2024

Changes from 18 commits

Commits (39)
1767fee
support coremlfp16
wejoncy Sep 10, 2024
bb99008
support unary and binary ops
wejoncy Sep 11, 2024
4e866d1
format
wejoncy Sep 11, 2024
0611bf5
more ops
wejoncy Sep 11, 2024
3944fd6
fix
wejoncy Sep 12, 2024
4f935e7
unify UT
wejoncy Sep 18, 2024
d765095
gemm conv support
wejoncy Sep 20, 2024
8129643
gemm/conv test
wejoncy Sep 20, 2024
9d665f3
address comments
wejoncy Sep 20, 2024
dbf25b9
build issue
wejoncy Sep 20, 2024
293b9f2
fix crash test
wejoncy Sep 21, 2024
4b344a5
lint
wejoncy Sep 21, 2024
ca581bc
Update onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc
wejoncy Sep 25, 2024
e9b2a42
Update onnxruntime/core/providers/coreml/builders/impl/builder_utils.h
wejoncy Sep 25, 2024
8fa2b48
Update onnxruntime/core/providers/coreml/builders/impl/gemm_op_builde…
wejoncy Sep 25, 2024
154a399
Update onnxruntime/core/providers/coreml/builders/model_builder.h
wejoncy Sep 25, 2024
b7c6078
Update onnxruntime/test/providers/cpu/math/gemm_test.cc
wejoncy Sep 25, 2024
fe3a3a3
address comments && add tolerance
wejoncy Sep 25, 2024
43f6e19
add comments to explain convfp16 test
wejoncy Sep 26, 2024
7ef5e1e
format
wejoncy Sep 26, 2024
97e87ad
add curly brace for code block
wejoncy Sep 26, 2024
48c98ec
add conv fp16 with intilizer=true
wejoncy Sep 26, 2024
c9e75c9
qnn convfp16
wejoncy Sep 26, 2024
a8e5485
fix qnn
wejoncy Sep 26, 2024
749940d
add UT for the other ops
wejoncy Sep 27, 2024
753efb6
d
wejoncy Sep 27, 2024
ea70f1c
more uts
wejoncy Sep 27, 2024
46090ad
address comments
wejoncy Sep 27, 2024
a66f7b5
fix
wejoncy Sep 27, 2024
531d564
fix
wejoncy Sep 27, 2024
d88606d
convtranspose ut
wejoncy Sep 29, 2024
f70a934
f
wejoncy Sep 29, 2024
2203294
Update onnxruntime/test/providers/cpu/tensor/space_depth_ops_test.cc
wejoncy Sep 30, 2024
6d709d0
more UTS
wejoncy Sep 30, 2024
c80e312
add doc for slice
wejoncy Sep 30, 2024
8b2bf8e
code spell fix
wejoncy Sep 30, 2024
3326719
format
wejoncy Sep 30, 2024
0dccf64
address review comments
wejoncy Sep 30, 2024
9ab55d4
slice coreml version >= 6
wejoncy Sep 30, 2024
38 changes: 31 additions & 7 deletions onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <set>
#include "core/providers/common.h"
#include "core/providers/coreml/builders/helper.h"
#include "core/providers/coreml/builders/impl/base_op_builder.h"
@@ -12,6 +13,15 @@
namespace onnxruntime {
namespace coreml {

// Once all ops support FP16 we can remove this. Until then, keep a set of ops used to
// filter the supported ones.
static std::set<std::string> Float16Ops = {

[cpplint warning, base_op_builder.cc:18] Add #include <string> for string [build/include_what_you_use]
"Add", "Mul", "Sub", "Div", "Pow", "Sqrt", "Reciprocal",
"Sigmoid", "Tanh", "Relu", "LeakyRelu", "Concat", "GridSample", "GlobalAveragePool",
"Clip", "DepthToSpace", "Resize", "Slice", "Conv",
"ConvTranspose", "GlobalMaxPool", "Gemm", "MatMul",
"AveragePool", "MaxPool", "Reshape", "Split", "Transpose"};

namespace {
// TODO, move this to shared_library
bool HasExternalInitializer(const InitializedTensorSet& initializers, const Node& node,
@@ -83,8 +93,9 @@
}

/* static */
bool BaseOpBuilder::IsInputFloat(const Node& node, size_t idx, const OpBuilderInputParams& /*input_params*/,
const logging::Logger& logger) {
bool BaseOpBuilder::IsInputDtypeSupport(const Node& node, size_t idx,
[[maybe_unused]] const OpBuilderInputParams& input_params,
const logging::Logger& logger) {
if (idx >= node.InputDefs().size()) {
LOGS(logger, VERBOSE) << "Input index [" << idx << "] is out of range";
return false;
@@ -94,20 +105,33 @@

int32_t input_type = ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED;

// currently only float is supported
if (!GetType(input, input_type, logger) || input_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
LOGS(logger, VERBOSE) << "[" << node.OpType() << "] Input type: [" << input_type << "] is not currently supported";
if (!GetType(input, input_type, logger)) {
LOGS(logger, VERBOSE) << "[" << node.OpType() << "] Get Input type failed";
return false;
}

return true;
// float is supported
if (input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
return true;
}

// FP16 is only supported via MLProgram
#if defined(COREML_ENABLE_MLPROGRAM)
if (input_params.create_mlprogram && input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16 &&
Float16Ops.count(node.OpType())) {
return true;
}
#endif

LOGS(logger, VERBOSE) << "[" << node.OpType() << "] Input type: [" << input_type << "] is not currently supported";
return false;
}

bool BaseOpBuilder::HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& input_params,
const logging::Logger& logger) const {
// We only check the type of input 0 by default
// specific op builder can override this
return IsInputFloat(node, 0, input_params, logger);
return IsInputDtypeSupport(node, 0, input_params, logger);
}

bool BaseOpBuilder::HasSupportedOpSet(const Node& node, const logging::Logger& logger) const {
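For readers skimming the diff, here is a condensed, standalone restatement of the gate above (a sketch only; the real IsInputDtypeSupport takes a Node and OpBuilderInputParams and logs the rejection reason): fp32 is always accepted, while fp16 is accepted only when an MLProgram is being created and the op is in the Float16Ops allow-list.

#include <cstdint>
#include <set>
#include <string>

// Simplified stand-in for BaseOpBuilder::IsInputDtypeSupport (illustration, not ORT code).
bool IsDtypeSupported(const std::string& op_type, int32_t input_type, bool create_mlprogram) {
  static const std::set<std::string> float16_ops = {"Add", "Mul", "Gemm", "Conv"};  // abbreviated list
  constexpr int32_t kFloat = 1;     // ONNX TensorProto_DataType_FLOAT
  constexpr int32_t kFloat16 = 10;  // ONNX TensorProto_DataType_FLOAT16
  if (input_type == kFloat) {
    return true;  // fp32 is supported for every op
  }
  // fp16 is only supported on the MLProgram path, and only for allow-listed ops.
  return create_mlprogram && input_type == kFloat16 && float16_ops.count(op_type) > 0;
}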
onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h
@@ -32,9 +32,9 @@ class BaseOpBuilder : public IOpBuilder {
: allow_empty_tensor_as_input_(allow_empty_tensor_as_input) {
}

// currently we only support float
static bool IsInputFloat(const Node& node, size_t idx, const OpBuilderInputParams& input_params,
const logging::Logger& logger);
// currently we support float/float16
static bool IsInputDtypeSupport(const Node& node, size_t idx, const OpBuilderInputParams& input_params,
const logging::Logger& logger);

private:
virtual bool IsOpSupportedImpl(const Node& /*node*/, const OpBuilderInputParams& /*input_params*/,
onnxruntime/core/providers/coreml/builders/impl/binary_op_builder.cc
@@ -73,7 +73,7 @@ Status BinaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
} else if (op_type == "Sub") {
coreml_op_type = "sub";
} else if (op_type == "Div") {
// we only support fp32 currently. when we add support for integers we need to check the type and use
// we support fp32/fp16 currently. when we add support for integers we need to check the type and use
// "floor_div" or "real_div" accordingly
coreml_op_type = "real_div";
} else if (op_type == "Pow") {
@@ -138,9 +138,9 @@ bool BinaryOpBuilder::HasSupportedInputsImpl(const Node& node, const OpBuilderIn
const logging::Logger& logger) const {
// Add/Sub/Mul/Div spec says inputs must be of the same type.
// Pow spec says inputs can be different types.
// We only support float for all of these inputs.
if (!IsInputFloat(node, 0, input_params, logger) ||
((node.OpType() == "Pow") && !IsInputFloat(node, 1, input_params, logger))) {
// We support float/float16 for all of these inputs.
if (!IsInputDtypeSupport(node, 0, input_params, logger) ||
((node.OpType() == "Pow") && !IsInputDtypeSupport(node, 1, input_params, logger))) {
return false;
}

32 changes: 32 additions & 0 deletions onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc
@@ -96,6 +96,9 @@
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
CreateCoreMLWeight(weight, unpacked_tensor.DataAsSpan<float>());
break;
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
CreateCoreMLWeight(weight, unpacked_tensor.DataAsSpan<MLFloat16>());
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT32:
CreateCoreMLWeight(weight, unpacked_tensor.DataAsSpan<int32_t>());
break;
@@ -114,6 +117,11 @@
weight.mutable_floatvalue()->Assign(data.begin(), data.end());
}

void CreateCoreMLWeight(CoreML::Specification::WeightParams& weight, gsl::span<const MLFloat16> data) {
const char* data_byte_ptr = (const char*)(data.data());
weight.mutable_float16value()->assign(data_byte_ptr, data_byte_ptr + data.size_bytes());
}

namespace {
template <typename T>
void CreateCoreMLWeightConvertingDataToFloats(CoreML::Specification::WeightParams& weight, gsl::span<const T> data) {
@@ -123,6 +131,15 @@
[](T v) { return narrow<float>(v); });
*weight.mutable_floatvalue() = std::move(weight_floats);
}

template <typename T>
void CreateCoreMLWeightConvertingDataToFloat16s(CoreML::Specification::WeightParams& weight, gsl::span<const T> data) {
std::vector<MLFloat16> weight_float16s{};
weight_float16s.reserve(data.size());
std::transform(data.begin(), data.end(), std::back_inserter(weight_float16s),

[cpplint warning, builder_utils.cc:139] Add #include <algorithm> for transform [build/include_what_you_use]
[](T v) { return MLFloat16(narrow<float>(v)); });
CreateCoreMLWeight(weight, weight_float16s);
}
} // namespace

void CreateCoreMLWeight(CoreML::Specification::WeightParams& weight, gsl::span<const int32_t> data) {
@@ -195,6 +212,13 @@
tensor_value.mutable_floats()->mutable_values()->Add(data.begin(), data.end());
}

template <>
void CopyDataToTensorValue<MLFloat16>(MILSpec::TensorValue& tensor_value, gsl::span<const MLFloat16> data) {
const char* begin = reinterpret_cast<const char*>(data.data());
const char* end = begin + (data.size() * sizeof(MLFloat16));
tensor_value.mutable_bytes()->mutable_values()->assign(begin, end);
}

template <>
void CopyDataToTensorValue<int32_t>(MILSpec::TensorValue& tensor_value, gsl::span<const int32_t> data) {
tensor_value.mutable_ints()->mutable_values()->Add(data.begin(), data.end());
@@ -290,6 +314,14 @@
// explicit specializations for types we handle so the implementation can be in the .cc file
template MILSpec::Value CreateTensorValue<int64_t, int32_t>(gsl::span<const int64_t> data,
std::optional<gsl::span<const int64_t>> shape);
template MILSpec::Value CreateTensorValue<float, float>(gsl::span<const float> data,
std::optional<gsl::span<const int64_t>> shape);
template MILSpec::Value CreateTensorValue<MLFloat16, MLFloat16>(gsl::span<const MLFloat16> data,
std::optional<gsl::span<const int64_t>> shape);
template MILSpec::Value CreateTensorValue<bool, bool>(gsl::span<const bool> data,
std::optional<gsl::span<const int64_t>> shape);
template MILSpec::Value CreateTensorValue<std::string, std::string>(gsl::span<const std::string> data,
std::optional<gsl::span<const int64_t>> shape);

template MILSpec::Value CreateScalarTensorValue(const float& data);
template MILSpec::Value CreateScalarTensorValue(const int32_t& data);
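A note on the FP16 weight handling above: unlike fp32, which is stored in a typed repeated float field, fp16 values are written into a raw bytes field, which is why CreateCoreMLWeight and CopyDataToTensorValue<MLFloat16> copy data.size() * sizeof(half) bytes verbatim. Below is a standalone illustration of that byte-copy idiom (plain uint16_t stands in for MLFloat16; this is not ORT code):

#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

int main() {
  // 1.0, 2.0, 3.0 encoded as IEEE-754 half-precision bit patterns.
  std::vector<uint16_t> half_bits = {0x3C00, 0x4000, 0x4200};
  std::string bytes;  // stand-in for the protobuf bytes field (float16value / TensorValue bytes)
  const char* begin = reinterpret_cast<const char*>(half_bits.data());
  bytes.assign(begin, begin + half_bits.size() * sizeof(uint16_t));
  assert(bytes.size() == half_bits.size() * 2);  // two bytes per fp16 element
  return 0;
}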
onnxruntime/core/providers/coreml/builders/impl/builder_utils.h
@@ -41,6 +41,9 @@ Status CreateCoreMLWeight(CoreML::Specification::WeightParams& weight, const ONN
// Copy the float array to a coreml weight
void CreateCoreMLWeight(CoreML::Specification::WeightParams& weight, gsl::span<const float> data);

// Copy the MLFloat16 array to a coreml weight
void CreateCoreMLWeight(CoreML::Specification::WeightParams& weight, gsl::span<const MLFloat16> data);

// Copy the int32_t array to a coreml weight
void CreateCoreMLWeight(CoreML::Specification::WeightParams& weight, gsl::span<const int32_t> data);

63 changes: 42 additions & 21 deletions onnxruntime/core/providers/coreml/builders/impl/gemm_op_builder.cc
@@ -70,16 +70,17 @@ void GemmOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Nod
}
}

// This is an internal function, requires input tensor to be 2d float tensor
// TODO, add support of other data types
static Status GetTensorFloatDataTransposed(const ONNX_NAMESPACE::TensorProto& tensor,
std::vector<float>& transposed_data) {
// This is an internal function; it requires the input tensor to be a 2-D float/float16 tensor
template <typename T>
static Status GetTensorDataTransposed(const ONNX_NAMESPACE::TensorProto& tensor,
std::vector<T>& transposed_data) {
Initializer unpacked_tensor(tensor);
auto src_data = unpacked_tensor.DataAsSpan<float>();
const auto src_data = unpacked_tensor.DataAsSpan<T>();
const auto& tensor_shape = tensor.dims();
auto x_t = SafeInt<size_t>(tensor_shape[0]);
auto y_t = SafeInt<size_t>(tensor_shape[1]);
transposed_data.resize(x_t * y_t);

for (size_t x = 0; x < x_t; x++) {
for (size_t y = 0; y < y_t; y++) {
transposed_data[y * x_t + x] = src_data[x * y_t + y];
@@ -121,8 +122,9 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
// B is {K, N} in ONNX spec by default, or {N, K} in Gemm if transB is true
const auto K = transB ? b1 : b0;
const auto N = transB ? b0 : b1;

// the input dtype has already been checked, so it is guaranteed to exist.
#if defined(COREML_ENABLE_MLPROGRAM)
auto input_dtype = a.TypeAsProto()->tensor_type().elem_type();
if (model_builder.CreateMLProgram()) {
using namespace CoreML::Specification::MILSpec;

@@ -136,13 +138,19 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
if (transB) {
AddOperationInput(*gemm_op, "weight", b.Name());
} else {
// transpose from {K, N} to {N, K}
std::vector<float> weight_nk;
std::vector<int64_t> weight_nk_shape = {N, K};
ORT_RETURN_IF_ERROR(GetTensorFloatDataTransposed(*b_initializer, weight_nk));

AddOperationInput(*gemm_op, "weight",
model_builder.AddConstant(gemm_op->type(), b.Name() + "_t", weight_nk, weight_nk_shape));
// transpose from {K, N} to {N, K}
if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
std::vector<float> weight_nk;  // B transposed from {K, N} to {N, K}
ORT_RETURN_IF_ERROR(GetTensorDataTransposed(*b_initializer, weight_nk));
AddOperationInput(*gemm_op, "weight",
model_builder.AddConstant(gemm_op->type(), b.Name() + "_t", weight_nk, weight_nk_shape));
} else { // TensorProto_DataType_FLOAT16
std::vector<MLFloat16> weight_nk;  // B transposed from {K, N} to {N, K}
ORT_RETURN_IF_ERROR(GetTensorDataTransposed(*b_initializer, weight_nk));
AddOperationInput(*gemm_op, "weight",
model_builder.AddConstant(gemm_op->type(), b.Name() + "_t", weight_nk, weight_nk_shape));
}
}

if (input_defs.size() == 3) {
@@ -155,15 +163,28 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
AddOperationInput(*gemm_op, "bias", bias_arg.Name());
} else {
Initializer unpacked_tensor(bias);
auto bias_data = unpacked_tensor.DataAsSpan<float>();
std::string_view bias_data_name;
if (bias_data.size() == 1) {
// expand scalar to N
std::vector<float> expanded_bias_data(N, bias_data[0]);
bias_data_name = model_builder.AddConstant(gemm_op->type(), "bias", expanded_bias_data);
} else {
// can use data as-is but need to adjust shape (inferred by AddConstant as {bias_data.size()})
bias_data_name = model_builder.AddConstant(gemm_op->type(), "bias", bias_data);

if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
auto bias_data = unpacked_tensor.DataAsSpan<float>();
if (bias_data.size() == 1) {
// expand scalar to N
std::vector<float> expanded_bias_data(N, bias_data[0]);
bias_data_name = model_builder.AddConstant(gemm_op->type(), "bias", expanded_bias_data);
} else {
// can use data as-is but need to adjust shape (inferred by AddConstant as {bias_data.size()})
bias_data_name = model_builder.AddConstant(gemm_op->type(), "bias", bias_data);
}
} else { // TensorProto_DataType_FLOAT16
auto bias_data = unpacked_tensor.DataAsSpan<MLFloat16>();
if (bias_data.size() == 1) {
// expand scalar to N
std::vector<MLFloat16> expanded_bias_data(N, bias_data[0]);
bias_data_name = model_builder.AddConstant(gemm_op->type(), "bias", expanded_bias_data);
} else {
// can use data as-is but need to adjust shape (inferred by AddConstant as {bias_data.size()})
bias_data_name = model_builder.AddConstant(gemm_op->type(), "bias", bias_data);
}
}

AddOperationInput(*gemm_op, "bias", bias_data_name);
@@ -202,7 +223,7 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_inner_product->mutable_weights(), *b_initializer));
} else {
std::vector<float> b_transposed;
ORT_RETURN_IF_ERROR(GetTensorFloatDataTransposed(*b_initializer, b_transposed));
ORT_RETURN_IF_ERROR(GetTensorDataTransposed(*b_initializer, b_transposed));
CreateCoreMLWeight(*coreml_inner_product->mutable_weights(), b_transposed);
}

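To make the weight handling in GemmOpBuilder easier to follow: when transB is 0, the ONNX B initializer is laid out row-major as {K, N}, but the weight is consumed as {N, K}, hence GetTensorDataTransposed. Here is a standalone sketch of that transpose, simplified to plain vectors instead of TensorProto/Initializer:

#include <cstddef>
#include <vector>

// Transpose a row-major {K, N} matrix into a row-major {N, K} matrix,
// mirroring GetTensorDataTransposed above.
template <typename T>
std::vector<T> TransposeKxN(const std::vector<T>& src, size_t K, size_t N) {
  std::vector<T> dst(K * N);
  for (size_t k = 0; k < K; ++k) {
    for (size_t n = 0; n < N; ++n) {
      dst[n * K + k] = src[k * N + n];  // element (k, n) becomes element (n, k)
    }
  }
  return dst;
}

// Example: the 2x3 matrix {1, 2, 3, 4, 5, 6} becomes the 3x2 matrix {1, 4, 2, 5, 3, 6}.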
59 changes: 47 additions & 12 deletions onnxruntime/core/providers/coreml/builders/impl/unary_op_builder.cc
@@ -3,6 +3,7 @@

#include "core/providers/common.h"

#include "core/providers/coreml/builders/impl/builder_utils.h"
#include "core/providers/coreml/builders/helper.h"
#include "core/providers/coreml/builders/impl/base_op_builder.h"
#include "core/providers/coreml/builders/model_builder.h"
@@ -14,28 +15,62 @@
class UnaryOpBuilder : public BaseOpBuilder {
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
const logging::Logger& logger) const override;
bool SupportsMLProgram() const override { return true; }
};

Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
const logging::Logger& /* logger */) const {
const auto& op_type(node.OpType());
const auto& input_defs(node.InputDefs());

std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = model_builder.CreateNNLayer(node);
#if defined(COREML_ENABLE_MLPROGRAM)
if (model_builder.CreateMLProgram()) {
using namespace CoreML::Specification::MILSpec;

[cpplint warning, unary_op_builder.cc:28] Do not use namespace using-directives. Use using-declarations instead. [build/namespaces]

if (op_type == "Sqrt") {
layer->mutable_unary()->set_type(COREML_SPEC::UnaryFunctionLayerParams::SQRT);
} else if (op_type == "Reciprocal") {
layer->mutable_unary()->set_type(COREML_SPEC::UnaryFunctionLayerParams::INVERSE);
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"UnaryOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type);
}
std::string_view coreml_op_type;
if (op_type == "Sqrt") {
coreml_op_type = "sqrt";
} else if (op_type == "Reciprocal") {
coreml_op_type = "inverse";
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"UnaryOpBuilder::AddToModelBuilderImpl, unexpected op: ", op_type);
}

std::unique_ptr<Operation> op = model_builder.CreateOperation(node, coreml_op_type);
AddOperationInput(*op, "x", input_defs[0]->Name());
if (op_type == "Reciprocal") {
float epsilon = 1e-4; // epsilon: const T (Optional, default=1e-4)
auto dtype = node.InputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
if (dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
AddOperationInput(*op, "epsilon", model_builder.AddScalarConstant(op->type(), "epsilon", epsilon));
} else if (dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
AddOperationInput(*op, "epsilon", model_builder.AddScalarConstant(op->type(), "epsilon", MLFloat16(epsilon)));
}
}

AddOperationOutput(*op, *node.OutputDefs()[0]);

*layer->mutable_input()->Add() = input_defs[0]->Name();
*layer->mutable_output()->Add() = node.OutputDefs()[0]->Name();
model_builder.AddOperation(std::move(op));
} else // NOLINT
#endif // defined (COREML_ENABLE_MLPROGRAM)
{
std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = model_builder.CreateNNLayer(node);

model_builder.AddLayer(std::move(layer));
if (op_type == "Sqrt") {
layer->mutable_unary()->set_type(COREML_SPEC::UnaryFunctionLayerParams::SQRT);
} else if (op_type == "Reciprocal") {
layer->mutable_unary()->set_type(COREML_SPEC::UnaryFunctionLayerParams::INVERSE);
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"UnaryOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type);
}

*layer->mutable_input()->Add() = input_defs[0]->Name();
*layer->mutable_output()->Add() = node.OutputDefs()[0]->Name();

model_builder.AddLayer(std::move(layer));
}
return Status::OK();
}

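One subtlety in the Reciprocal mapping above: the MIL inverse op takes an epsilon input whose dtype must match x, which is why the builder adds either a float or an MLFloat16 scalar constant. A rough reference for the computation follows (assumption: MIL's inverse adds epsilon to the denominator for numerical stability, i.e. y = 1 / (x + epsilon)); T is float or double here, while the builder wraps the same 1e-4 constant in MLFloat16 for fp16 inputs.

#include <vector>

// Reference reciprocal-with-epsilon (illustration, not ORT or CoreML code).
template <typename T>
std::vector<T> ReciprocalWithEpsilon(const std::vector<T>& x, T epsilon = static_cast<T>(1e-4)) {
  std::vector<T> y;
  y.reserve(x.size());
  for (const T& v : x) {
    y.push_back(static_cast<T>(1) / (v + epsilon));  // epsilon guards near-zero inputs
  }
  return y;
}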