Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
wejoncy committed Sep 12, 2024
1 parent 09b547d commit e134133
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ using namespace CoreML::Specification;
namespace onnxruntime {
namespace coreml {

// Once all ops are supportted FP16, we can remove it. Before that, we keep a set of ops to
// filter suppported ones.
static std::set<const std::string> Float16Ops = {

Check warning on line 18 in onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <string> for string [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc:18: Add #include <string> for string [build/include_what_you_use] [4]
"Add", "Mul", "Sub", "Div", "Pow", "Sqrt", "Reciprocal",
"Sigmoid", "Tanh", "Relu", "LeakyRelu", "Concat", "GridSample", "GlobalAveragePool",
Expand Down Expand Up @@ -110,11 +112,13 @@ bool BaseOpBuilder::IsInputDtypeSupport(const Node& node, size_t idx, const OpBu
return true;
}

// only support MLProgram for FP16
#if defined(COREML_ENABLE_MLPROGRAM)
if (input_params.create_mlprogram && input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16 && Float16Ops.count(node.OpType())) {

Check warning on line 117 in onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc:117: Lines should be <= 120 characters long [whitespace/line_length] [2]
return true;
}
#endif

LOGS(logger, VERBOSE) << "[" << node.OpType() << "] Input type: [" << input_type << "] is not currently supported";
return false;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class BaseOpBuilder : public IOpBuilder {
: allow_empty_tensor_as_input_(allow_empty_tensor_as_input) {
}

// currently we only support float
// currently we support float/float16
static bool IsInputDtypeSupport(const Node& node, size_t idx, const OpBuilderInputParams& input_params,
const logging::Logger& logger);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
if (model_builder.CreateMLProgram()) {
using namespace CoreML::Specification::MILSpec;

Check warning on line 28 in onnxruntime/core/providers/coreml/builders/impl/unary_op_builder.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Do not use namespace using-directives. Use using-declarations instead. [build/namespaces] [5] Raw Output: onnxruntime/core/providers/coreml/builders/impl/unary_op_builder.cc:28: Do not use namespace using-directives. Use using-declarations instead. [build/namespaces] [5]

// https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#module-coremltools.converters.mil.mil.ops.defs.iOS15.elementwise_binary
std::string_view coreml_op_type;
if (op_type == "Sqrt") {
coreml_op_type = "sqrt";
Expand Down

0 comments on commit e134133

Please sign in to comment.