From 1767fee69881ab53a071d5e8fd07cafc00a3f7a0 Mon Sep 17 00:00:00 2001 From: wejoncy Date: Tue, 10 Sep 2024 17:03:22 +0800 Subject: [PATCH 01/39] support coremlfp16 --- .../coreml/builders/impl/base_op_builder.cc | 24 +++++++++++++++---- .../coreml/builders/impl/base_op_builder.h | 2 +- .../coreml/builders/impl/binary_op_builder.cc | 4 ++-- .../coreml/builders/impl/builder_utils.cc | 16 +++++++++++++ .../coreml/builders/impl/builder_utils.h | 3 +++ .../coreml/builders/model_builder.cc | 3 +++ .../core/providers/coreml/model/model.mm | 16 +++++++++++++ 7 files changed, 60 insertions(+), 8 deletions(-) diff --git a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc index 2cae85a0a1c8d..9de6e2c20c97a 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include #include "core/providers/common.h" #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" @@ -12,6 +13,10 @@ using namespace CoreML::Specification; namespace onnxruntime { namespace coreml { +static std::set Float16Ops = { + "Add", +}; + namespace { // TODO, move this to shared_library bool HasExternalInitializer(const InitializedTensorSet& initializers, const Node& node, @@ -83,7 +88,7 @@ bool BaseOpBuilder::HasSupportedInputs(const Node& node, const OpBuilderInputPar } /* static */ -bool BaseOpBuilder::IsInputFloat(const Node& node, size_t idx, const OpBuilderInputParams& /*input_params*/, +bool BaseOpBuilder::IsInputDtypeSupport(const Node& node, size_t idx, const OpBuilderInputParams& /*input_params*/, const logging::Logger& logger) { if (idx >= node.InputDefs().size()) { LOGS(logger, VERBOSE) << "Input index [" << idx << "] is out of range"; @@ -94,12 +99,21 @@ bool BaseOpBuilder::IsInputFloat(const Node& node, size_t idx, const OpBuilderIn int32_t input_type = ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED; - // currently only float is supported - if (!GetType(input, input_type, logger) || input_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { - LOGS(logger, VERBOSE) << "[" << node.OpType() << "] Input type: [" << input_type << "] is not currently supported"; + if (!GetType(input, input_type, logger)) { + LOGS(logger, VERBOSE) << "[" << node.OpType() << "] Get Input type failed"; return false; } + // float is supported + if (input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT){ + return true; + } + + if (input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16 && Float16Ops.count(node.OpType())) { + return true; + } + + LOGS(logger, VERBOSE) << "[" << node.OpType() << "] Input type: [" << input_type << "] is not currently supported"; return true; } @@ -107,7 +121,7 @@ bool BaseOpBuilder::HasSupportedInputsImpl(const Node& node, const OpBuilderInpu const logging::Logger& logger) const { // We only check the type of input 0 by default // specific op builder can override this - return IsInputFloat(node, 0, input_params, logger); + return IsInputDtypeSupport(node, 0, input_params, logger); } bool BaseOpBuilder::HasSupportedOpSet(const Node& node, const logging::Logger& logger) const { diff --git a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h index 071008520fbdc..6bd3c43f373cb 100644 
--- a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h +++ b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h @@ -33,7 +33,7 @@ class BaseOpBuilder : public IOpBuilder { } // currently we only support float - static bool IsInputFloat(const Node& node, size_t idx, const OpBuilderInputParams& input_params, + static bool IsInputDtypeSupport(const Node& node, size_t idx, const OpBuilderInputParams& input_params, const logging::Logger& logger); private: diff --git a/onnxruntime/core/providers/coreml/builders/impl/binary_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/binary_op_builder.cc index fb8e07633621f..3ecea9c3770fe 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/binary_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/binary_op_builder.cc @@ -139,8 +139,8 @@ bool BinaryOpBuilder::HasSupportedInputsImpl(const Node& node, const OpBuilderIn // Add/Sub/Mul/Div spec says inputs must be of the same type. // Pow spec says inputs can be different types. // We only support float for all of these inputs. - if (!IsInputFloat(node, 0, input_params, logger) || - ((node.OpType() == "Pow") && !IsInputFloat(node, 1, input_params, logger))) { + if (!IsInputDtypeSupport(node, 0, input_params, logger) || + ((node.OpType() == "Pow") && !IsInputDtypeSupport(node, 1, input_params, logger))) { return false; } diff --git a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc index e02186d3aee89..328f8b3279928 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc @@ -96,6 +96,9 @@ Status CreateCoreMLWeight(CoreML::Specification::WeightParams& weight, case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: CreateCoreMLWeight(weight, unpacked_tensor.DataAsSpan()); break; + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: + CreateCoreMLWeight(weight, unpacked_tensor.DataAsSpan()); + break; case ONNX_NAMESPACE::TensorProto_DataType_INT32: CreateCoreMLWeight(weight, unpacked_tensor.DataAsSpan()); break; @@ -114,6 +117,11 @@ void CreateCoreMLWeight(CoreML::Specification::WeightParams& weight, gsl::spanAssign(data.begin(), data.end()); } +void CreateCoreMLWeight(CoreML::Specification::WeightParams& weight, gsl::span data) { + const char* data_byte_ptr = (const char*)(data.data()); + weight.mutable_float16value()->assign(data_byte_ptr, data_byte_ptr+data.size_bytes()); +} + namespace { template void CreateCoreMLWeightConvertingDataToFloats(CoreML::Specification::WeightParams& weight, gsl::span data) { @@ -133,6 +141,8 @@ void CreateCoreMLWeight(CoreML::Specification::WeightParams& weight, gsl::span(gsl::span data, std::optional> shape); +template MILSpec::Value CreateTensorValue(gsl::span data, + std::optional> shape); +template MILSpec::Value CreateTensorValue(gsl::span data, + std::optional> shape); +template MILSpec::Value CreateTensorValue(gsl::span data, + std::optional> shape); template MILSpec::Value CreateScalarTensorValue(const float& data); template MILSpec::Value CreateScalarTensorValue(const int32_t& data); diff --git a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h index 475ce79b0a812..f25936e25a17f 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h +++ b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h @@ -41,6 +41,9 @@ Status 
CreateCoreMLWeight(CoreML::Specification::WeightParams& weight, const ONN // Copy the float array to a coreml weight void CreateCoreMLWeight(CoreML::Specification::WeightParams& weight, gsl::span data); +// Copy the float array to a coreml weight +void CreateCoreMLWeight(CoreML::Specification::WeightParams& weight, gsl::span data); + // Copy the int32_t array to a coreml weight void CreateCoreMLWeight(CoreML::Specification::WeightParams& weight, gsl::span data); diff --git a/onnxruntime/core/providers/coreml/builders/model_builder.cc b/onnxruntime/core/providers/coreml/builders/model_builder.cc index 9668bfcd09adf..7ecfad8493ea5 100644 --- a/onnxruntime/core/providers/coreml/builders/model_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/model_builder.cc @@ -811,6 +811,9 @@ Status ModelBuilder::RegisterModelInputOutput(const NodeArg& node_arg, bool is_i case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: multi_array->set_datatype(ArrayFeatureType::FLOAT32); break; + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: + multi_array->set_datatype(ArrayFeatureType::FLOAT16); + break; case ONNX_NAMESPACE::TensorProto_DataType_INT32: multi_array->set_datatype(ArrayFeatureType::INT32); break; diff --git a/onnxruntime/core/providers/coreml/model/model.mm b/onnxruntime/core/providers/coreml/model/model.mm index 68460ff7c9b31..60c93aa601622 100644 --- a/onnxruntime/core/providers/coreml/model/model.mm +++ b/onnxruntime/core/providers/coreml/model/model.mm @@ -120,6 +120,10 @@ Status CreateInputFeatureProvider(const std::unordered_map(mlmultiarray_buffer); + auto* dst_buffer = static_cast(tensor_buffer); + const auto block_byte_size = block_size * sizeof(uint16_t); + + for (int64_t idx = 0; idx < num_blocks; ++idx) { + memcpy(dst_buffer, src_buffer, block_byte_size); + src_buffer += stride; + dst_buffer += block_size; + } + break; + } case ONNX_NAMESPACE::TensorProto_DataType_INT32: { const auto* src_buffer = static_cast(mlmultiarray_buffer); auto* dst_buffer = static_cast(tensor_buffer); From bb9900882c08a16a16cb27e2555fd1b3c29bebbf Mon Sep 17 00:00:00 2001 From: wejoncy Date: Wed, 11 Sep 2024 03:49:00 -0700 Subject: [PATCH 02/39] support unary and binary ops --- .../coreml/builders/impl/base_op_builder.cc | 11 +-- .../coreml/builders/impl/binary_op_builder.cc | 4 +- .../coreml/builders/impl/unary_op_builder.cc | 32 +++++++ .../providers/coreml/coreml_basic_test.cc | 87 +++++++++++++++++++ onnxruntime/test/util/test_utils.cc | 5 ++ 5 files changed, 132 insertions(+), 7 deletions(-) diff --git a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc index 9de6e2c20c97a..cc6f2d796c5e7 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc @@ -14,7 +14,7 @@ namespace onnxruntime { namespace coreml { static std::set Float16Ops = { - "Add", + "Add", "Mul", "Sub", "Div", "Pow", "Sqrt", "Reciprocal" }; namespace { @@ -88,7 +88,7 @@ bool BaseOpBuilder::HasSupportedInputs(const Node& node, const OpBuilderInputPar } /* static */ -bool BaseOpBuilder::IsInputDtypeSupport(const Node& node, size_t idx, const OpBuilderInputParams& /*input_params*/, +bool BaseOpBuilder::IsInputDtypeSupport(const Node& node, size_t idx, const OpBuilderInputParams& input_params, const logging::Logger& logger) { if (idx >= node.InputDefs().size()) { LOGS(logger, VERBOSE) << "Input index [" << idx << "] is out of range"; @@ -109,12 +109,13 @@ bool 
BaseOpBuilder::IsInputDtypeSupport(const Node& node, size_t idx, const OpBu return true; } - if (input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16 && Float16Ops.count(node.OpType())) { +#if defined(COREML_ENABLE_MLPROGRAM) + if (input_params.create_mlprogram && input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16 && Float16Ops.count(node.OpType())) { return true; } - +#endif LOGS(logger, VERBOSE) << "[" << node.OpType() << "] Input type: [" << input_type << "] is not currently supported"; - return true; + return false; } bool BaseOpBuilder::HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& input_params, diff --git a/onnxruntime/core/providers/coreml/builders/impl/binary_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/binary_op_builder.cc index 3ecea9c3770fe..bc1eed8c1920a 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/binary_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/binary_op_builder.cc @@ -73,7 +73,7 @@ Status BinaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const } else if (op_type == "Sub") { coreml_op_type = "sub"; } else if (op_type == "Div") { - // we only support fp32 currently. when we add support for integers we need to check the type and use + // we support fp32/fp16 currently. when we add support for integers we need to check the type and use // "floor_div" or "real_div" accordingly coreml_op_type = "real_div"; } else if (op_type == "Pow") { @@ -138,7 +138,7 @@ bool BinaryOpBuilder::HasSupportedInputsImpl(const Node& node, const OpBuilderIn const logging::Logger& logger) const { // Add/Sub/Mul/Div spec says inputs must be of the same type. // Pow spec says inputs can be different types. - // We only support float for all of these inputs. + // We support float/float16 for all of these inputs. 
if (!IsInputDtypeSupport(node, 0, input_params, logger) || ((node.OpType() == "Pow") && !IsInputDtypeSupport(node, 1, input_params, logger))) { return false; diff --git a/onnxruntime/core/providers/coreml/builders/impl/unary_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/unary_op_builder.cc index 3403378d59114..595e08d1d7717 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/unary_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/unary_op_builder.cc @@ -3,6 +3,7 @@ #include "core/providers/common.h" +#include "core/providers/coreml/builders/impl/builder_utils.h" #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" #include "core/providers/coreml/builders/model_builder.h" @@ -14,6 +15,7 @@ namespace coreml { class UnaryOpBuilder : public BaseOpBuilder { Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; + bool SupportsMLProgram() const override { return true; } }; Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, @@ -21,6 +23,35 @@ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const const auto& op_type(node.OpType()); const auto& input_defs(node.InputDefs()); + +#if defined(COREML_ENABLE_MLPROGRAM) + if (model_builder.CreateMLProgram()) { + using namespace CoreML::Specification::MILSpec; + + // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#module-coremltools.converters.mil.mil.ops.defs.iOS15.elementwise_binary + std::string_view coreml_op_type; + if (op_type == "Sqrt") { + coreml_op_type = "sqrt"; + } else if (op_type == "Reciprocal") { + coreml_op_type = "inverse"; + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "UnaryOpBuilder::AddToModelBuilderImpl, unexpected op: ", op_type); + } + + std::unique_ptr op = model_builder.CreateOperation(node, coreml_op_type); + AddOperationInput(*op, "x", input_defs[0]->Name()); + if (op_type == "Reciprocal") { + float epsilon = 1e-4; //epsilon: const T (Optional, default=1e-4) + AddOperationInput(*op, "epsilon", model_builder.AddScalarConstant(op->type(), "epsilon", epsilon)); + } + + AddOperationOutput(*op, *node.OutputDefs()[0]); + + model_builder.AddOperation(std::move(op)); + } else +#endif // defined (COREML_ENABLE_MLPROGRAM) + { std::unique_ptr layer = model_builder.CreateNNLayer(node); if (op_type == "Sqrt") { @@ -36,6 +67,7 @@ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); model_builder.AddLayer(std::move(layer)); + } return Status::OK(); } diff --git a/onnxruntime/test/providers/coreml/coreml_basic_test.cc b/onnxruntime/test/providers/coreml/coreml_basic_test.cc index daa24db134114..c9d8a605678be 100644 --- a/onnxruntime/test/providers/coreml/coreml_basic_test.cc +++ b/onnxruntime/test/providers/coreml/coreml_basic_test.cc @@ -257,6 +257,93 @@ TEST(CoreMLExecutionProviderTest, TestNameSanitization) { // TensorRT does not support Clip opset 11 yet. 
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } + +TEST(CoreMLExecutionProviderTest, TestBinaryFp16) { + auto test_binary_op = [](std::string op){ + OpTester test(op, 11); + + std::vector dims{3, 3}; + std::vector input1 = {-1.0f, 0.0f, 1.0f, + -6.0f, 0.0f, 6.0f, + -5.4f, 2.0f, 6.0f}; + std::vector input1_fp16(9); + ConvertFloatToMLFloat16(input1.data(), input1_fp16.data(), 9); + std::vector input2 = {-1.0f, 0.0f, 1.0f, + -5.0f, 0.0f, 5.0f, + -5.0f, 2.0f, 5.0f}; + std::vector input2_fp16(9); + ConvertFloatToMLFloat16(input2.data(), input2_fp16.data(), 9); + std::vector output(9); + if (op == "Add"){ + for(int i = 0; i < 9; i++){ + output[i] = input1_fp16[i] + input2_fp16[i]; + } + } else if (op == "Sub") { + for(int i = 0; i < 9; i++){ + output[i] = input1_fp16[i] - input2_fp16[i]; + } + } else if (op == "Mul") { + for(int i = 0; i < 9; i++){ + output[i] = input1_fp16[i] * input2_fp16[i]; + } + } else if (op == "Div") { + for(int i = 0; i < 9; i++){ + output[i] = input1_fp16[i] / input2_fp16[i]; + } + } + std::vector output_fp16(9); + ConvertFloatToMLFloat16(output.data(), output_fp16.data(), 9); + + test.AddInput("0", dims, input1_fp16); + test.AddInput("1.min", dims, input2_fp16); + test.AddOutput("3", dims, output_fp16); + + // TensorRT does not support Clip opset 11 yet. + std::vector> coreml_ep; + coreml_ep.emplace_back(MakeCoreMLExecutionProvider(COREML_FLAG_CREATE_MLPROGRAM)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &coreml_ep); + }; + test_binary_op("Add"); + test_binary_op("Sub"); + test_binary_op("Div"); + test_binary_op("Mul"); +} + +TEST(CoreMLExecutionProviderTest, TestUnaryFp16) { + auto test_binary_op = [](std::string op){ + OpTester test(op, 11); + + std::vector dims{3, 3}; + std::vector input1 = {-1.0f, 0.0f, 1.0f, + -6.0f, 0.2f, 6.0f, + -5.4f, 2.0f, 6.0f}; + std::vector input1_fp16(9); + ConvertFloatToMLFloat16(input1.data(), input1_fp16.data(), 9); + + std::vector output(9); + if (op == "Sqrt"){ + for(int i = 0; i < 9; i++){ + output[i] = sqrt(input1_fp16[i]); + } + } else if (op == "Reciprocal") { + for(int i = 0; i < 9; i++){ + output[i] = 1.0f/(1e-4+input1_fp16[i]); + } + } + std::vector output_fp16(9); + ConvertFloatToMLFloat16(output.data(), output_fp16.data(), 9); + + test.AddInput("0", dims, input1_fp16); + test.AddOutput("3", dims, output_fp16); + + // TensorRT does not support Clip opset 11 yet. + std::vector> coreml_ep; + coreml_ep.emplace_back(MakeCoreMLExecutionProvider(COREML_FLAG_CREATE_MLPROGRAM)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &coreml_ep); + }; + test_binary_op("Sqrt"); + test_binary_op("Reciprocal"); +} #endif } // namespace test diff --git a/onnxruntime/test/util/test_utils.cc b/onnxruntime/test/util/test_utils.cc index 6bc0f8d105495..606b8d580fa34 100644 --- a/onnxruntime/test/util/test_utils.cc +++ b/onnxruntime/test/util/test_utils.cc @@ -55,6 +55,11 @@ void VerifyOutput(const std::string& output_name, ::testing::Pointwise(::testing::FloatNear(fp32_abs_err), tensor.DataAsSpan())); break; } + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: { + EXPECT_THAT(expected_tensor.DataAsSpan(), + ::testing::Pointwise(::testing::FloatNear(fp32_abs_err), tensor.DataAsSpan())); + break; + } default: ORT_THROW("Unhandled data type. 
Please add 'case' statement for ", element_type); } From 4e866d1dbab6dc41af3db414b79fc2523c20ca46 Mon Sep 17 00:00:00 2001 From: wejoncy Date: Wed, 11 Sep 2024 03:57:27 -0700 Subject: [PATCH 03/39] format --- .../coreml/builders/impl/base_op_builder.cc | 7 ++-- .../coreml/builders/impl/base_op_builder.h | 2 +- .../coreml/builders/impl/builder_utils.cc | 10 +++--- .../coreml/builders/impl/unary_op_builder.cc | 27 +++++++-------- .../providers/coreml/coreml_basic_test.cc | 34 +++++++++---------- 5 files changed, 38 insertions(+), 42 deletions(-) diff --git a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc index cc6f2d796c5e7..a261dbb63d07d 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc @@ -14,8 +14,7 @@ namespace onnxruntime { namespace coreml { static std::set Float16Ops = { - "Add", "Mul", "Sub", "Div", "Pow", "Sqrt", "Reciprocal" -}; + "Add", "Mul", "Sub", "Div", "Pow", "Sqrt", "Reciprocal"}; namespace { // TODO, move this to shared_library @@ -89,7 +88,7 @@ bool BaseOpBuilder::HasSupportedInputs(const Node& node, const OpBuilderInputPar /* static */ bool BaseOpBuilder::IsInputDtypeSupport(const Node& node, size_t idx, const OpBuilderInputParams& input_params, - const logging::Logger& logger) { + const logging::Logger& logger) { if (idx >= node.InputDefs().size()) { LOGS(logger, VERBOSE) << "Input index [" << idx << "] is out of range"; return false; @@ -105,7 +104,7 @@ bool BaseOpBuilder::IsInputDtypeSupport(const Node& node, size_t idx, const OpBu } // float is supported - if (input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT){ + if (input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { return true; } diff --git a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h index 6bd3c43f373cb..a2cbef6dd57db 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h +++ b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h @@ -34,7 +34,7 @@ class BaseOpBuilder : public IOpBuilder { // currently we only support float static bool IsInputDtypeSupport(const Node& node, size_t idx, const OpBuilderInputParams& input_params, - const logging::Logger& logger); + const logging::Logger& logger); private: virtual bool IsOpSupportedImpl(const Node& /*node*/, const OpBuilderInputParams& /*input_params*/, diff --git a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc index 328f8b3279928..fc6b5792f3649 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc @@ -119,7 +119,7 @@ void CreateCoreMLWeight(CoreML::Specification::WeightParams& weight, gsl::span data) { const char* data_byte_ptr = (const char*)(data.data()); - weight.mutable_float16value()->assign(data_byte_ptr, data_byte_ptr+data.size_bytes()); + weight.mutable_float16value()->assign(data_byte_ptr, data_byte_ptr + data.size_bytes()); } namespace { @@ -141,8 +141,6 @@ void CreateCoreMLWeight(CoreML::Specification::WeightParams& weight, gsl::span(gsl::span data, std::optional> shape); template MILSpec::Value CreateTensorValue(gsl::span data, - std::optional> shape); + std::optional> shape); template MILSpec::Value CreateTensorValue(gsl::span data, - std::optional> 
shape); + std::optional> shape); template MILSpec::Value CreateTensorValue(gsl::span data, - std::optional> shape); + std::optional> shape); template MILSpec::Value CreateScalarTensorValue(const float& data); template MILSpec::Value CreateScalarTensorValue(const int32_t& data); diff --git a/onnxruntime/core/providers/coreml/builders/impl/unary_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/unary_op_builder.cc index 595e08d1d7717..6d46c3789decf 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/unary_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/unary_op_builder.cc @@ -23,7 +23,6 @@ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const const auto& op_type(node.OpType()); const auto& input_defs(node.InputDefs()); - #if defined(COREML_ENABLE_MLPROGRAM) if (model_builder.CreateMLProgram()) { using namespace CoreML::Specification::MILSpec; @@ -42,7 +41,7 @@ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const std::unique_ptr op = model_builder.CreateOperation(node, coreml_op_type); AddOperationInput(*op, "x", input_defs[0]->Name()); if (op_type == "Reciprocal") { - float epsilon = 1e-4; //epsilon: const T (Optional, default=1e-4) + float epsilon = 1e-4; // epsilon: const T (Optional, default=1e-4) AddOperationInput(*op, "epsilon", model_builder.AddScalarConstant(op->type(), "epsilon", epsilon)); } @@ -52,21 +51,21 @@ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const } else #endif // defined (COREML_ENABLE_MLPROGRAM) { - std::unique_ptr layer = model_builder.CreateNNLayer(node); + std::unique_ptr layer = model_builder.CreateNNLayer(node); - if (op_type == "Sqrt") { - layer->mutable_unary()->set_type(COREML_SPEC::UnaryFunctionLayerParams::SQRT); - } else if (op_type == "Reciprocal") { - layer->mutable_unary()->set_type(COREML_SPEC::UnaryFunctionLayerParams::INVERSE); - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "UnaryOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type); - } + if (op_type == "Sqrt") { + layer->mutable_unary()->set_type(COREML_SPEC::UnaryFunctionLayerParams::SQRT); + } else if (op_type == "Reciprocal") { + layer->mutable_unary()->set_type(COREML_SPEC::UnaryFunctionLayerParams::INVERSE); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "UnaryOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type); + } - *layer->mutable_input()->Add() = input_defs[0]->Name(); - *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); + *layer->mutable_input()->Add() = input_defs[0]->Name(); + *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); - model_builder.AddLayer(std::move(layer)); + model_builder.AddLayer(std::move(layer)); } return Status::OK(); } diff --git a/onnxruntime/test/providers/coreml/coreml_basic_test.cc b/onnxruntime/test/providers/coreml/coreml_basic_test.cc index c9d8a605678be..4fdfd2dc1be25 100644 --- a/onnxruntime/test/providers/coreml/coreml_basic_test.cc +++ b/onnxruntime/test/providers/coreml/coreml_basic_test.cc @@ -259,35 +259,35 @@ TEST(CoreMLExecutionProviderTest, TestNameSanitization) { } TEST(CoreMLExecutionProviderTest, TestBinaryFp16) { - auto test_binary_op = [](std::string op){ + auto test_binary_op = [](std::string op) { OpTester test(op, 11); std::vector dims{3, 3}; std::vector input1 = {-1.0f, 0.0f, 1.0f, - -6.0f, 0.0f, 6.0f, - -5.4f, 2.0f, 6.0f}; + -6.0f, 0.0f, 6.0f, + -5.4f, 2.0f, 6.0f}; std::vector input1_fp16(9); ConvertFloatToMLFloat16(input1.data(), 
input1_fp16.data(), 9); std::vector input2 = {-1.0f, 0.0f, 1.0f, - -5.0f, 0.0f, 5.0f, - -5.0f, 2.0f, 5.0f}; + -5.0f, 0.0f, 5.0f, + -5.0f, 2.0f, 5.0f}; std::vector input2_fp16(9); ConvertFloatToMLFloat16(input2.data(), input2_fp16.data(), 9); std::vector output(9); - if (op == "Add"){ - for(int i = 0; i < 9; i++){ + if (op == "Add") { + for (int i = 0; i < 9; i++) { output[i] = input1_fp16[i] + input2_fp16[i]; } } else if (op == "Sub") { - for(int i = 0; i < 9; i++){ + for (int i = 0; i < 9; i++) { output[i] = input1_fp16[i] - input2_fp16[i]; } } else if (op == "Mul") { - for(int i = 0; i < 9; i++){ + for (int i = 0; i < 9; i++) { output[i] = input1_fp16[i] * input2_fp16[i]; } } else if (op == "Div") { - for(int i = 0; i < 9; i++){ + for (int i = 0; i < 9; i++) { output[i] = input1_fp16[i] / input2_fp16[i]; } } @@ -310,24 +310,24 @@ TEST(CoreMLExecutionProviderTest, TestBinaryFp16) { } TEST(CoreMLExecutionProviderTest, TestUnaryFp16) { - auto test_binary_op = [](std::string op){ + auto test_binary_op = [](std::string op) { OpTester test(op, 11); std::vector dims{3, 3}; std::vector input1 = {-1.0f, 0.0f, 1.0f, - -6.0f, 0.2f, 6.0f, - -5.4f, 2.0f, 6.0f}; + -6.0f, 0.2f, 6.0f, + -5.4f, 2.0f, 6.0f}; std::vector input1_fp16(9); ConvertFloatToMLFloat16(input1.data(), input1_fp16.data(), 9); std::vector output(9); - if (op == "Sqrt"){ - for(int i = 0; i < 9; i++){ + if (op == "Sqrt") { + for (int i = 0; i < 9; i++) { output[i] = sqrt(input1_fp16[i]); } } else if (op == "Reciprocal") { - for(int i = 0; i < 9; i++){ - output[i] = 1.0f/(1e-4+input1_fp16[i]); + for (int i = 0; i < 9; i++) { + output[i] = 1.0f / (1e-4 + input1_fp16[i]); } } std::vector output_fp16(9); From 0611bf5dc0355a076766c1ed5ef748a43b5fcafa Mon Sep 17 00:00:00 2001 From: wejoncy Date: Wed, 11 Sep 2024 04:24:31 -0700 Subject: [PATCH 04/39] more ops --- .../core/providers/coreml/builders/impl/base_op_builder.cc | 4 +++- .../core/providers/coreml/builders/impl/unary_op_builder.cc | 4 ++-- onnxruntime/test/providers/coreml/coreml_basic_test.cc | 5 +++++ 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc index a261dbb63d07d..f267dc7551359 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc @@ -14,7 +14,9 @@ namespace onnxruntime { namespace coreml { static std::set Float16Ops = { - "Add", "Mul", "Sub", "Div", "Pow", "Sqrt", "Reciprocal"}; + "Add", "Mul", "Sub", "Div", "Pow", "Sqrt", "Reciprocal", + "Sigmoid", "Tanh", "Relu", "LeakyRelu", "Concat", "GridSample", "GlobalAveragePool", + "GlobalMaxPool", "AveragePool", "MaxPool", "Reshape", "Split", "Transpose"}; namespace { // TODO, move this to shared_library diff --git a/onnxruntime/core/providers/coreml/builders/impl/unary_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/unary_op_builder.cc index 6d46c3789decf..aa3060d62686d 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/unary_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/unary_op_builder.cc @@ -48,8 +48,8 @@ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const AddOperationOutput(*op, *node.OutputDefs()[0]); model_builder.AddOperation(std::move(op)); - } else -#endif // defined (COREML_ENABLE_MLPROGRAM) + } else // NOLINT +#endif // defined (COREML_ENABLE_MLPROGRAM) { std::unique_ptr layer = 
model_builder.CreateNNLayer(node);

diff --git a/onnxruntime/test/providers/coreml/coreml_basic_test.cc b/onnxruntime/test/providers/coreml/coreml_basic_test.cc
index 4fdfd2dc1be25..6da88a24bd450 100644
--- a/onnxruntime/test/providers/coreml/coreml_basic_test.cc
+++ b/onnxruntime/test/providers/coreml/coreml_basic_test.cc
@@ -329,6 +329,10 @@ TEST(CoreMLExecutionProviderTest, TestUnaryFp16) {
       for (int i = 0; i < 9; i++) {
         output[i] = 1.0f / (1e-4 + input1_fp16[i]);
       }
+    } else if (op == "Relu") {
+      for (int i = 0; i < 9; i++) {
+        output[i] = fmax(0.0f, input1_fp16[i]);
+      }
     }
     std::vector output_fp16(9);
     ConvertFloatToMLFloat16(output.data(), output_fp16.data(), 9);
@@ -343,6 +347,7 @@ TEST(CoreMLExecutionProviderTest, TestUnaryFp16) {
   };
   test_binary_op("Sqrt");
   test_binary_op("Reciprocal");
+  test_binary_op("Relu");
 }
 #endif

From 3944fd606ea96254d64f09b40e6b918f07f777f7 Mon Sep 17 00:00:00 2001
From: wejoncy
Date: Wed, 11 Sep 2024 20:24:31 -0700
Subject: [PATCH 05/39] fix

---
 .../core/providers/coreml/builders/impl/base_op_builder.cc  | 4 ++++
 .../core/providers/coreml/builders/impl/base_op_builder.h   | 2 +-
 .../core/providers/coreml/builders/impl/unary_op_builder.cc | 1 -
 3 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc
index f267dc7551359..25d7890faebaf 100644
--- a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc
+++ b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc
@@ -13,6 +13,8 @@ using namespace CoreML::Specification;
 namespace onnxruntime {
 namespace coreml {
+// Once all ops are supported in FP16, we can remove it. Before that, we keep a set of ops to
+// filter supported ones.
static std::set Float16Ops = { "Add", "Mul", "Sub", "Div", "Pow", "Sqrt", "Reciprocal", "Sigmoid", "Tanh", "Relu", "LeakyRelu", "Concat", "GridSample", "GlobalAveragePool", @@ -110,11 +112,13 @@ bool BaseOpBuilder::IsInputDtypeSupport(const Node& node, size_t idx, const OpBu return true; } +// only support MLProgram for FP16 #if defined(COREML_ENABLE_MLPROGRAM) if (input_params.create_mlprogram && input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16 && Float16Ops.count(node.OpType())) { return true; } #endif + LOGS(logger, VERBOSE) << "[" << node.OpType() << "] Input type: [" << input_type << "] is not currently supported"; return false; } diff --git a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h index a2cbef6dd57db..153ae841b238f 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h +++ b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h @@ -32,7 +32,7 @@ class BaseOpBuilder : public IOpBuilder { : allow_empty_tensor_as_input_(allow_empty_tensor_as_input) { } - // currently we only support float + // currently we support float/float16 static bool IsInputDtypeSupport(const Node& node, size_t idx, const OpBuilderInputParams& input_params, const logging::Logger& logger); diff --git a/onnxruntime/core/providers/coreml/builders/impl/unary_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/unary_op_builder.cc index aa3060d62686d..e8a138aa49799 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/unary_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/unary_op_builder.cc @@ -27,7 +27,6 @@ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const if (model_builder.CreateMLProgram()) { using namespace CoreML::Specification::MILSpec; - // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#module-coremltools.converters.mil.mil.ops.defs.iOS15.elementwise_binary std::string_view coreml_op_type; if (op_type == "Sqrt") { coreml_op_type = "sqrt"; From 4f935e765e1a6d5ce1945f3daeab7e1a7bd519dc Mon Sep 17 00:00:00 2001 From: wejoncy Date: Wed, 18 Sep 2024 01:07:49 -0700 Subject: [PATCH 06/39] unify UT --- .../coreml/builders/impl/base_op_builder.cc | 7 +- .../coreml/builders/impl/builder_utils.cc | 18 +++ .../coreml/builders/impl/unary_op_builder.cc | 7 +- .../coreml/builders/model_builder.cc | 8 + .../providers/coreml/builders/model_builder.h | 3 +- .../providers/coreml/coreml_basic_test.cc | 92 ------------ .../cpu/math/element_wise_ops_test.cc | 137 +++++++++++------- .../apple/coreml_supported_mlprogram_ops.md | 2 + 8 files changed, 123 insertions(+), 151 deletions(-) diff --git a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc index 25d7890faebaf..748fe1dad2267 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc @@ -15,9 +15,9 @@ namespace coreml { // Once all ops are supportted FP16, we can remove it. Before that, we keep a set of ops to // filter suppported ones. 
-static std::set Float16Ops = { +static std::set Float16Ops = { "Add", "Mul", "Sub", "Div", "Pow", "Sqrt", "Reciprocal", - "Sigmoid", "Tanh", "Relu", "LeakyRelu", "Concat", "GridSample", "GlobalAveragePool", + "Sigmoid", "Tanh", "Relu", "LeakyRelu", "Concat", "GridSample", "GlobalAveragePool", "Clip", "DepthToSpace", "Resize", "Slice", "GlobalMaxPool", "AveragePool", "MaxPool", "Reshape", "Split", "Transpose"}; namespace { @@ -91,7 +91,8 @@ bool BaseOpBuilder::HasSupportedInputs(const Node& node, const OpBuilderInputPar } /* static */ -bool BaseOpBuilder::IsInputDtypeSupport(const Node& node, size_t idx, const OpBuilderInputParams& input_params, +bool BaseOpBuilder::IsInputDtypeSupport(const Node& node, size_t idx, + [[maybe_unused]] const OpBuilderInputParams& input_params, const logging::Logger& logger) { if (idx >= node.InputDefs().size()) { LOGS(logger, VERBOSE) << "Input index [" << idx << "] is out of range"; diff --git a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc index fc6b5792f3649..a27895b6e37f7 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc @@ -131,6 +131,15 @@ void CreateCoreMLWeightConvertingDataToFloats(CoreML::Specification::WeightParam [](T v) { return narrow(v); }); *weight.mutable_floatvalue() = std::move(weight_floats); } + +template +void CreateCoreMLWeightConvertingDataToFloat16s(CoreML::Specification::WeightParams& weight, gsl::span data) { + std::vector weight_float16s{}; + weight_float16s.reserve(data.size()); + std::transform(data.begin(), data.end(), std::back_inserter(weight_float16s), + [](T v) { return MLFloat16(narrow(v)); }); + CreateCoreMLWeight(weight, weight_float16s); +} } // namespace void CreateCoreMLWeight(CoreML::Specification::WeightParams& weight, gsl::span data) { @@ -203,6 +212,13 @@ void CopyDataToTensorValue(MILSpec::TensorValue& tensor_value, gsl::span< tensor_value.mutable_floats()->mutable_values()->Add(data.begin(), data.end()); } +template <> +void CopyDataToTensorValue(MILSpec::TensorValue& tensor_value, gsl::span data) { + const char* begin = (const char*)(data.data()); + const char* end = (const char*)(data.data()) + data.size() * sizeof(MLFloat16); + tensor_value.mutable_bytes()->mutable_values()->assign(begin, end); +} + template <> void CopyDataToTensorValue(MILSpec::TensorValue& tensor_value, gsl::span data) { tensor_value.mutable_ints()->mutable_values()->Add(data.begin(), data.end()); @@ -300,6 +316,8 @@ template MILSpec::Value CreateTensorValue(gsl::span> shape); template MILSpec::Value CreateTensorValue(gsl::span data, std::optional> shape); +template MILSpec::Value CreateTensorValue(gsl::span data, + std::optional> shape); template MILSpec::Value CreateTensorValue(gsl::span data, std::optional> shape); template MILSpec::Value CreateTensorValue(gsl::span data, diff --git a/onnxruntime/core/providers/coreml/builders/impl/unary_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/unary_op_builder.cc index e8a138aa49799..335ca737081b2 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/unary_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/unary_op_builder.cc @@ -41,7 +41,12 @@ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const AddOperationInput(*op, "x", input_defs[0]->Name()); if (op_type == "Reciprocal") { float epsilon = 1e-4; // epsilon: const T (Optional, default=1e-4) - 
AddOperationInput(*op, "epsilon", model_builder.AddScalarConstant(op->type(), "epsilon", epsilon)); + auto dtype = node.InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + if (dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + AddOperationInput(*op, "epsilon", model_builder.AddScalarConstant(op->type(), "epsilon", epsilon)); + } else if (dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { + AddOperationInput(*op, "epsilon", model_builder.AddScalarConstant(op->type(), "epsilon", MLFloat16(epsilon))); + } } AddOperationOutput(*op, *node.OutputDefs()[0]); diff --git a/onnxruntime/core/providers/coreml/builders/model_builder.cc b/onnxruntime/core/providers/coreml/builders/model_builder.cc index 7ecfad8493ea5..50faebf06875d 100644 --- a/onnxruntime/core/providers/coreml/builders/model_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/model_builder.cc @@ -639,6 +639,14 @@ std::string_view ModelBuilder::AddConstantImpl(std::string_view op_type, std::st return AddTensorValueAsConstantOperation(op_type, value_type, std::move(input_value)); } +template <> +std::string_view ModelBuilder::AddConstantImpl(std::string_view op_type, std::string_view value_type, + gsl::span value, + std::optional> shape) { + auto input_value = CreateTensorValue(value, shape); + return AddTensorValueAsConstantOperation(op_type, value_type, std::move(input_value)); +} + template <> std::string_view ModelBuilder::AddConstantImpl(std::string_view op_type, std::string_view value_type, gsl::span value, diff --git a/onnxruntime/core/providers/coreml/builders/model_builder.h b/onnxruntime/core/providers/coreml/builders/model_builder.h index bb791fb902908..688dccfc35300 100644 --- a/onnxruntime/core/providers/coreml/builders/model_builder.h +++ b/onnxruntime/core/providers/coreml/builders/model_builder.h @@ -107,11 +107,12 @@ class ModelBuilder { std::string_view AddConstant(std::string_view op_type, std::string_view value_type, gsl::span value, std::optional> shape = std::nullopt) { static_assert(std::is_same_v || + std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v, // add specialization in AddConstantImpl for new types if needed - "AddConstant currently supports float, int64_t, std::string and bool."); + "AddConstant currently supports float/MLFloat16, int64_t, std::string and bool."); return AddConstantImpl(op_type, value_type, value, shape); } diff --git a/onnxruntime/test/providers/coreml/coreml_basic_test.cc b/onnxruntime/test/providers/coreml/coreml_basic_test.cc index 6da88a24bd450..daa24db134114 100644 --- a/onnxruntime/test/providers/coreml/coreml_basic_test.cc +++ b/onnxruntime/test/providers/coreml/coreml_basic_test.cc @@ -257,98 +257,6 @@ TEST(CoreMLExecutionProviderTest, TestNameSanitization) { // TensorRT does not support Clip opset 11 yet. 
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } - -TEST(CoreMLExecutionProviderTest, TestBinaryFp16) { - auto test_binary_op = [](std::string op) { - OpTester test(op, 11); - - std::vector dims{3, 3}; - std::vector input1 = {-1.0f, 0.0f, 1.0f, - -6.0f, 0.0f, 6.0f, - -5.4f, 2.0f, 6.0f}; - std::vector input1_fp16(9); - ConvertFloatToMLFloat16(input1.data(), input1_fp16.data(), 9); - std::vector input2 = {-1.0f, 0.0f, 1.0f, - -5.0f, 0.0f, 5.0f, - -5.0f, 2.0f, 5.0f}; - std::vector input2_fp16(9); - ConvertFloatToMLFloat16(input2.data(), input2_fp16.data(), 9); - std::vector output(9); - if (op == "Add") { - for (int i = 0; i < 9; i++) { - output[i] = input1_fp16[i] + input2_fp16[i]; - } - } else if (op == "Sub") { - for (int i = 0; i < 9; i++) { - output[i] = input1_fp16[i] - input2_fp16[i]; - } - } else if (op == "Mul") { - for (int i = 0; i < 9; i++) { - output[i] = input1_fp16[i] * input2_fp16[i]; - } - } else if (op == "Div") { - for (int i = 0; i < 9; i++) { - output[i] = input1_fp16[i] / input2_fp16[i]; - } - } - std::vector output_fp16(9); - ConvertFloatToMLFloat16(output.data(), output_fp16.data(), 9); - - test.AddInput("0", dims, input1_fp16); - test.AddInput("1.min", dims, input2_fp16); - test.AddOutput("3", dims, output_fp16); - - // TensorRT does not support Clip opset 11 yet. - std::vector> coreml_ep; - coreml_ep.emplace_back(MakeCoreMLExecutionProvider(COREML_FLAG_CREATE_MLPROGRAM)); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &coreml_ep); - }; - test_binary_op("Add"); - test_binary_op("Sub"); - test_binary_op("Div"); - test_binary_op("Mul"); -} - -TEST(CoreMLExecutionProviderTest, TestUnaryFp16) { - auto test_binary_op = [](std::string op) { - OpTester test(op, 11); - - std::vector dims{3, 3}; - std::vector input1 = {-1.0f, 0.0f, 1.0f, - -6.0f, 0.2f, 6.0f, - -5.4f, 2.0f, 6.0f}; - std::vector input1_fp16(9); - ConvertFloatToMLFloat16(input1.data(), input1_fp16.data(), 9); - - std::vector output(9); - if (op == "Sqrt") { - for (int i = 0; i < 9; i++) { - output[i] = sqrt(input1_fp16[i]); - } - } else if (op == "Reciprocal") { - for (int i = 0; i < 9; i++) { - output[i] = 1.0f / (1e-4 + input1_fp16[i]); - } - } else if (op == "Relu") { - for (int i = 0; i < 9; i++) { - output[i] = fmax(0.0f, input1_fp16[i]); - } - } - std::vector output_fp16(9); - ConvertFloatToMLFloat16(output.data(), output_fp16.data(), 9); - - test.AddInput("0", dims, input1_fp16); - test.AddOutput("3", dims, output_fp16); - - // TensorRT does not support Clip opset 11 yet. 
- std::vector> coreml_ep; - coreml_ep.emplace_back(MakeCoreMLExecutionProvider(COREML_FLAG_CREATE_MLPROGRAM)); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &coreml_ep); - }; - test_binary_op("Sqrt"); - test_binary_op("Reciprocal"); - test_binary_op("Relu"); -} #endif } // namespace test diff --git a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc index bd3d21d4929f3..659622a70e4cc 100644 --- a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc +++ b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc @@ -22,26 +22,38 @@ std::vector MakeMLFloat16(const std::initializer_list& input) return output; } -#if defined(USE_CUDA) || defined(USE_ROCM) -void TestFloat16(const char* op_name, const std::vector& lhs_dim, - const std::initializer_list& lhs_values, const std::vector& rhs_dim, - const std::initializer_list& rhs_values, const std::vector& out_dim, - const std::initializer_list& out_values) { +void TestBinaryFloat16(const char* op_name, const std::vector& lhs_dim, + const std::initializer_list& lhs_values, const std::vector& rhs_dim, + const std::initializer_list& rhs_values, const std::vector& out_dim, + const std::initializer_list& out_values, bool enable_bf16 = true) { + ORT_UNUSED_PARAMETER(op_name); + ORT_UNUSED_PARAMETER(lhs_dim); + ORT_UNUSED_PARAMETER(lhs_values); + ORT_UNUSED_PARAMETER(rhs_dim); + ORT_UNUSED_PARAMETER(rhs_values); + ORT_UNUSED_PARAMETER(out_dim); + ORT_UNUSED_PARAMETER(out_values); + ORT_UNUSED_PARAMETER(enable_bf16); +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(COREML_ENABLE_MLPROGRAM) { OpTester tester(op_name, 14); tester.AddInput("A", lhs_dim, MakeMLFloat16(lhs_values)); tester.AddInput("B", rhs_dim, MakeMLFloat16(rhs_values)); tester.AddOutput("C", out_dim, MakeMLFloat16(out_values)); std::vector> execution_providers; -#ifdef USE_CUDA +#ifdef COREML_ENABLE_MLPROGRAM + execution_providers.push_back(DefaultCoreMLExecutionProvider(true)); +#elif USE_CUDA execution_providers.push_back(DefaultCudaExecutionProvider()); #elif USE_ROCM execution_providers.push_back(DefaultRocmExecutionProvider()); #endif tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } +#endif - { +#if defined(USE_CUDA) || defined(USE_ROCM) + if (enable_bf16) { OpTester tester(op_name, 14); tester.AddInput("A", lhs_dim, MakeBFloat16(lhs_values)); tester.AddInput("B", rhs_dim, MakeBFloat16(rhs_values)); @@ -54,9 +66,52 @@ void TestFloat16(const char* op_name, const std::vector& lhs_dim, #endif tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } +#endif } + +void TestUnaryFloat16(const char* op_name, const std::vector& lhs_dim, + const std::initializer_list& lhs_values, const std::vector& out_dim, + const std::initializer_list& out_values, int opset = 14) { + ORT_UNUSED_PARAMETER(op_name); + ORT_UNUSED_PARAMETER(lhs_dim); + ORT_UNUSED_PARAMETER(lhs_values); + ORT_UNUSED_PARAMETER(rhs_dim); + ORT_UNUSED_PARAMETER(out_dim); + ORT_UNUSED_PARAMETER(out_values); + ORT_UNUSED_PARAMETER(opset); +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(COREML_ENABLE_MLPROGRAM) + { + OpTester tester(op_name, opset); + tester.AddInput("A", lhs_dim, MakeMLFloat16(lhs_values)); + tester.AddOutput("C", out_dim, MakeMLFloat16(out_values)); + std::vector> execution_providers; +#ifdef COREML_ENABLE_MLPROGRAM + execution_providers.push_back(DefaultCoreMLExecutionProvider(true)); +#elif USE_CUDA + 
execution_providers.push_back(DefaultCudaExecutionProvider()); +#elif USE_ROCM + execution_providers.push_back(DefaultRocmExecutionProvider()); +#endif + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } #endif +#if defined(USE_CUDA) || defined(USE_ROCM) + { + OpTester tester(op_name, opset); + tester.AddInput("A", lhs_dim, MakeBFloat16(lhs_values)); + tester.AddOutput("C", out_dim, MakeBFloat16(out_values)); + std::vector> execution_providers; +#ifdef USE_CUDA + execution_providers.push_back(DefaultCudaExecutionProvider()); +#elif USE_ROCM + execution_providers.push_back(DefaultRocmExecutionProvider()); +#endif + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } +#endif +} + void TestBFloat16(const char* op_name, const std::vector& lhs_dim, const std::initializer_list& lhs_values, const std::vector& rhs_dim, const std::initializer_list& rhs_values, const std::vector& out_dim, @@ -163,9 +218,7 @@ TEST(MathOpTest, Add_float) { test.Run(); #endif -#if defined(USE_CUDA) || defined(USE_ROCM) - TestFloat16("Add", dims, lhs_values, dims, rhs_values, dims, out_values); -#endif + TestBinaryFloat16("Add", dims, lhs_values, dims, rhs_values, dims, out_values); #if defined(USE_DNNL) TestBFloat16("Add", dims, lhs_values, dims, rhs_values, dims, out_values); @@ -202,9 +255,7 @@ TEST(MathOpTest, Add_Broadcast_Axis) { test.AddOutput("C", dims, out_values); test.Run(OpTester::ExpectResult::kExpectSuccess, ""); -#if defined(USE_CUDA) || defined(USE_ROCM) - TestFloat16("Add", dims, lhs_values, {3, 1}, rhs_values, dims, out_values); -#endif + TestBinaryFloat16("Add", dims, lhs_values, {3, 1}, rhs_values, dims, out_values); #if defined(USE_DNNL) TestBFloat16("Add", dims, lhs_values, {3, 1}, rhs_values, dims, out_values); @@ -228,9 +279,7 @@ TEST(MathOpTest, Add_Broadcast_MultidirectionalAB) { {kTensorrtExecutionProvider}); // TensorRT: got C with shape [3, 1] #endif -#if defined(USE_CUDA) || defined(USE_ROCM) - TestFloat16("Add", {3, 1}, lhs_values, {3}, rhs_values, {3, 3}, out_values); -#endif + TestBinaryFloat16("Add", {3, 1}, lhs_values, {3}, rhs_values, {3, 3}, out_values); #if defined(USE_DNNL) TestBFloat16("Add", {3, 1}, lhs_values, {3}, rhs_values, {3, 3}, out_values); @@ -254,9 +303,7 @@ TEST(MathOpTest, Add_Broadcast_MultidirectionalBA) { {kTensorrtExecutionProvider}); // TensorRT: got C with shape [3, 1] #endif -#if defined(USE_CUDA) || defined(USE_ROCM) - TestFloat16("Add", {3}, lhs_values, {3, 1}, rhs_values, {3, 3}, out_values); -#endif + TestBinaryFloat16("Add", {3}, lhs_values, {3, 1}, rhs_values, {3, 3}, out_values); #if defined(USE_DNNL) TestBFloat16("Add", {3}, lhs_values, {3, 1}, rhs_values, {3, 3}, out_values); @@ -527,9 +574,7 @@ TEST(MathOpTest, Sub) { test.AddOutput("C", dims, out_values); test.Run(); -#if defined(USE_CUDA) || defined(USE_ROCM) - TestFloat16("Sub", dims, lhs_values, dims, rhs_values, dims, out_values); -#endif + TestBinaryFloat16("Sub", dims, lhs_values, dims, rhs_values, dims, out_values); #if defined(USE_DNNL) TestBFloat16("Sub", dims, lhs_values, dims, rhs_values, dims, out_values); @@ -584,9 +629,7 @@ TEST(MathOpTest, Mul) { test.Run(); -#if defined(USE_CUDA) || defined(USE_ROCM) - TestFloat16("Mul", dims, lhs_values, dims, rhs_values, dims, out_values); -#endif + TestBinaryFloat16("Mul", dims, lhs_values, dims, rhs_values, dims, out_values); #if defined(USE_DNNL) TestBFloat16("Mul", dims, lhs_values, dims, rhs_values, dims, out_values); @@ -622,9 +665,7 @@ TEST(MathOpTest, 
Div) {
   test.AddOutput("C", dims, out_values);
   test.Run();

-#if defined(USE_CUDA) || defined(USE_ROCM)
-  TestFloat16("Div", dims, lhs_values, dims, rhs_values, dims, out_values);
-#endif
+  TestBinaryFloat16("Div", dims, lhs_values, dims, rhs_values, dims, out_values);

 #if defined(USE_DNNL)
   TestBFloat16("Div", dims, lhs_values, dims, rhs_values, dims, out_values);
@@ -772,13 +813,12 @@ TEST(MathOpTest, Ceil_double) {
 TEST(MathOpTest, Reciprocal) {
   OpTester test("Reciprocal");
   std::vector dims{2, 2};
-  test.AddInput("X", dims,
-                {1.0f, 2.0f,
-                 -1.0f, -2.0f});
-  test.AddOutput("Y", dims,
-                 {1.0f, 0.5f,
-                  -1.0f, -0.5f});
+  std::initializer_list inputs = {1.0f, 2.0f, -1.0f, -2.0f};
+  std::initializer_list outputs = {1.0f, 0.5f, -1.0f, -0.5f};
+  test.AddInput("X", dims, inputs);
+  test.AddOutput("Y", dims, outputs);
   test.Run();
+  TestUnaryFloat16("Reciprocal", dims, inputs, dims, outputs, 12);
 }

 TEST(MathOpTest, Reciprocal_double) {
@@ -795,14 +835,13 @@ TEST(MathOpTest, Reciprocal_double) {
 TEST(MathOpTest, Sqrt_Float) {
   OpTester test("Sqrt");
+  std::initializer_list inputs = {1.0f, 4.0f, 0.0f, 9.0f};
+  std::initializer_list outputs = {1.0f, 2.0f, 0.0f, 3.0f};
   std::vector dims{2, 2};
-  test.AddInput("X", dims,
-                {1.0f, 4.0f,
-                 0.0f, 9.0f});
-  test.AddOutput("Y", dims,
-                 {1.0f, 2.0f,
-                  0.0f, 3.0f});
+  test.AddInput("X", dims, inputs);
+  test.AddOutput("Y", dims, outputs);
   test.Run();
+  TestUnaryFloat16("Sqrt", dims, inputs, dims, outputs);
 }

 #if defined(USE_DNNL) || defined(USE_CUDA)
@@ -1056,24 +1095,13 @@ TEST(MathOpTest, Pow_double_int64) {
   test.Run();
 }

-#if defined(USE_CUDA) || defined(USE_ROCM)
 TEST(MathOpTest, Pow_float16_float16) {
-  OpTester test("Pow", 12);
   std::vector dims{4};
-
-  test.AddInput("X", dims, MakeMLFloat16({2.0f, 2.0f, std::sqrt(2.0f), 1.0f}));
-  test.AddInput("Y", dims, MakeMLFloat16({0.0f, 8.0f, 2.0f, 9.0f}));
-  test.AddOutput("Z", dims, MakeMLFloat16({1.0f, 256.0f, 2.0f, 1.0f}));
-
-  std::vector> execution_providers;
-#ifdef USE_CUDA
-  execution_providers.push_back(DefaultCudaExecutionProvider());
-#elif USE_ROCM
-  execution_providers.push_back(DefaultRocmExecutionProvider());
-#endif
-  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+  TestBinaryFloat16("Pow", dims, {2.0f, 2.0f, std::sqrt(2.0f), 1.0f}, dims, {0.0f, 8.0f, 2.0f, 9.0f},
+                    dims, {1.0f, 256.0f, 2.0f, 1.0f}, false);
 }

+#if defined(USE_CUDA) || defined(USE_ROCM)
 TEST(MathOpTest, Pow_float_float16) {
   OpTester test("Pow", 12);
   std::vector dims{4};
@@ -3660,5 +3688,6 @@ TEST(MathOpTest, BitwiseNot_uint8) {
   test.AddOutput("Y", dims, {254, 251, 250, 252});
   test.Run();
 }
+
 } // namespace test
 } // namespace onnxruntime

diff --git a/tools/ci_build/github/apple/coreml_supported_mlprogram_ops.md b/tools/ci_build/github/apple/coreml_supported_mlprogram_ops.md
index bb4cfb2e09dcc..0b51311e2271f 100644
--- a/tools/ci_build/github/apple/coreml_supported_mlprogram_ops.md
+++ b/tools/ci_build/github/apple/coreml_supported_mlprogram_ops.md
@@ -29,3 +29,5 @@ Keep in sync with doco generated from /docs/execution-providers/CoreML-Execution
 |ai.onnx:Sigmoid||
 |ai.onnx:Tanh||
 |ai.onnx:Transpose||
+|ai.onnx:Sqrt||
+|ai.onnx:Reciprocal|this asks for an `epsilon` (default 1e-4) which ONNX does not provide|

From d76509585dac0527b8a5c26cd0c845ad5ec94c15 Mon Sep 17 00:00:00 2001
From: wejoncy
Date: Fri, 20 Sep 2024 00:53:19 -0700
Subject: [PATCH 07/39] gemm conv support

---
 .../coreml/builders/impl/base_op_builder.cc | 6 +-
 .../coreml/builders/impl/gemm_op_builder.cc | 69 ++++++++++++-------
 .../providers/coreml/builders/model_builder.h | 21 ++++++
 .../cpu/math/element_wise_ops_test.cc | 1 -
 4 files changed, 69 insertions(+), 28 deletions(-)

diff --git a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc
index 748fe1dad2267..68eaa7a34c074 100644
--- a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc
+++ b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc
@@ -17,8 +17,10 @@ namespace coreml {
 // filter supported ones.
 static std::set Float16Ops = {
     "Add", "Mul", "Sub", "Div", "Pow", "Sqrt", "Reciprocal",
-    "Sigmoid", "Tanh", "Relu", "LeakyRelu", "Concat", "GridSample", "GlobalAveragePool", "Clip", "DepthToSpace", "Resize", "Slice",
-    "GlobalMaxPool", "AveragePool", "MaxPool", "Reshape", "Split", "Transpose"};
+    "Sigmoid", "Tanh", "Relu", "LeakyRelu", "Concat", "GridSample", "GlobalAveragePool",
+    "Clip", "DepthToSpace", "Resize", "Slice", "Conv",
+    "ConvTranspose", "GlobalMaxPool",
+    "AveragePool", "MaxPool", "Reshape", "Split", "Transpose"};

 namespace {
 // TODO, move this to shared_library

diff --git a/onnxruntime/core/providers/coreml/builders/impl/gemm_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/gemm_op_builder.cc
index 7338fc18fe779..bb0b88c1f3607 100644
--- a/onnxruntime/core/providers/coreml/builders/impl/gemm_op_builder.cc
+++ b/onnxruntime/core/providers/coreml/builders/impl/gemm_op_builder.cc
@@ -70,19 +70,35 @@ void GemmOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Nod
   }
 }

-// This is an internal function, requires input tensor to be 2d float tensor
+// This is an internal function, requires input tensor to be 2d float/float16 tensor
 // TODO, add support of other data types
-static Status GetTensorFloatDataTransposed(const ONNX_NAMESPACE::TensorProto& tensor,
-                                           std::vector& transposed_data) {
+// Using a template here would inflate the binary size, and this approach doesn't affect runtime performance.
+static Status GetTensorDataTransposed(const ONNX_NAMESPACE::TensorProto& tensor,
+                                      std::vector& transposed_data,
+                                      int dtype = ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
   Initializer unpacked_tensor(tensor);
-  auto src_data = unpacked_tensor.DataAsSpan();
+  const void* src_dataraw = unpacked_tensor.DataAsByteSpan().data();
   const auto& tensor_shape = tensor.dims();
   auto x_t = SafeInt(tensor_shape[0]);
   auto y_t = SafeInt(tensor_shape[1]);
-  transposed_data.resize(x_t * y_t);
-  for (size_t x = 0; x < x_t; x++) {
-    for (size_t y = 0; y < y_t; y++) {
-      transposed_data[y * x_t + x] = src_data[x * y_t + y];
+  size_t bytes_in_type = dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16 ? sizeof(MLFloat16) : sizeof(float);
+  transposed_data.resize(x_t * y_t * bytes_in_type);
+  void* dst_data = transposed_data.data();
+  if (dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
+    MLFloat16* dst_ptr = (MLFloat16*)dst_data;
+    const MLFloat16* src_ptr = (const MLFloat16*)src_dataraw;
+    for (size_t x = 0; x < x_t; x++) {
+      for (size_t y = 0; y < y_t; y++) {
+        dst_ptr[y * x_t + x] = src_ptr[x * y_t + y];
+      }
+    }
+  } else {
+    float* dst_ptr = (float*)dst_data;
+    const float* src_ptr = (const float*)src_dataraw;
+    for (size_t x = 0; x < x_t; x++) {
+      for (size_t y = 0; y < y_t; y++) {
+        dst_ptr[y * x_t + x] = src_ptr[x * y_t + y];
+      }
     }
   }

@@ -121,7 +137,8 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
   // B is {K, N} in ONNX spec by default, or {N, K} in Gemm if transB is true
   const auto K = transB ? b1 : b0;
   const auto N = transB ? b0 : b1;
-
+  // we already checked it and dtype must exist.
+  auto input_dtype = node.InputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
 #if defined(COREML_ENABLE_MLPROGRAM)
   if (model_builder.CreateMLProgram()) {
     using namespace CoreML::Specification::MILSpec;
@@ -137,12 +154,11 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
       AddOperationInput(*gemm_op, "weight", b.Name());
     } else {
       // transpose from {K, N} to {N, K}
-      std::vector weight_nk;
+      std::vector weight_nk;  // use bytes to store the type-erased data, could be any data-type
       std::vector weight_nk_shape = {N, K};
-      ORT_RETURN_IF_ERROR(GetTensorFloatDataTransposed(*b_initializer, weight_nk));
-
+      ORT_RETURN_IF_ERROR(GetTensorDataTransposed(*b_initializer, weight_nk, input_dtype));
       AddOperationInput(*gemm_op, "weight",
-                        model_builder.AddConstant(gemm_op->type(), b.Name() + "_t", weight_nk, weight_nk_shape));
+                        model_builder.AddConstant(gemm_op->type(), b.Name() + "_t", weight_nk, input_dtype, weight_nk_shape));
     }

     if (input_defs.size() == 3) {
@@ -155,17 +171,19 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
         AddOperationInput(*gemm_op, "bias", bias_arg.Name());
       } else {
         Initializer unpacked_tensor(bias);
-        auto bias_data = unpacked_tensor.DataAsSpan();
-        std::string_view bias_data_name;
-        if (bias_data.size() == 1) {
+        std::vector no_typed_bias_data;
+        gsl::span bias_data_span = unpacked_tensor.DataAsByteSpan();
+        size_t bytes_in_type = input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT ? sizeof(float) : sizeof(MLFloat16);
+        // can use data as-is but need to adjust shape (inferred by AddConstant as {bias_data.size()})
+        if (bias_data_span.size() == bytes_in_type) {
           // expand scalar to N
-          std::vector expanded_bias_data(N, bias_data[0]);
-          bias_data_name = model_builder.AddConstant(gemm_op->type(), "bias", expanded_bias_data);
-        } else {
-          // can use data as-is but need to adjust shape (inferred by AddConstant as {bias_data.size()})
-          bias_data_name = model_builder.AddConstant(gemm_op->type(), "bias", bias_data);
+          no_typed_bias_data.resize(N * bytes_in_type, 0);
+          for (int64_t i = 0; i < N; i++) {
+            std::copy_n(bias_data_span.data() + i * bytes_in_type, bytes_in_type, no_typed_bias_data.data() + i * bytes_in_type);
+          }
+          bias_data_span = AsSpan(no_typed_bias_data);
         }
-
+        std::string_view bias_data_name = model_builder.AddConstant(gemm_op->type(), "bias", bias_data_span, input_dtype);
         AddOperationInput(*gemm_op, "bias", bias_data_name);
       }
     }
@@ -201,9 +219,10 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
     if (transB) {
       ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_inner_product->mutable_weights(), *b_initializer));
     } else {
-      std::vector b_transposed;
-      ORT_RETURN_IF_ERROR(GetTensorFloatDataTransposed(*b_initializer, b_transposed));
-      CreateCoreMLWeight(*coreml_inner_product->mutable_weights(), b_transposed);
+      std::vector b_transposed;
+      ORT_RETURN_IF_ERROR(GetTensorDataTransposed(*b_initializer, b_transposed));
+      gsl::span b_transposed_f((const float*)(b_transposed.data()), b_transposed.size() / sizeof(float));
+      CreateCoreMLWeight(*coreml_inner_product->mutable_weights(), b_transposed_f);
     }

     if (is_gemm && input_defs.size() > 2) {

diff --git a/onnxruntime/core/providers/coreml/builders/model_builder.h b/onnxruntime/core/providers/coreml/builders/model_builder.h
index 688dccfc35300..f94521ac91979 100644
--- a/onnxruntime/core/providers/coreml/builders/model_builder.h
+++ b/onnxruntime/core/providers/coreml/builders/model_builder.h
@@ -116,6 +116,27 @@ class ModelBuilder {
     return AddConstantImpl(op_type, value_type, value, shape);
   }

+  std::string_view AddConstant(std::string_view op_type, std::string_view value_type, gsl::span value,
+                               int input_dtype, std::optional> shape = std::nullopt) {
+    switch (input_dtype) {
+      case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: {
+        gsl::span value_v((const float*)(value.data()), value.size() / sizeof(float));
+        return AddConstant(op_type, value_type, value_v, shape);
+      }
+      case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: {
+        gsl::span value_v((const MLFloat16*)(value.data()), value.size() / sizeof(MLFloat16));
+        return AddConstant(op_type, value_type, value_v, shape);
+      }
+      case ONNX_NAMESPACE::TensorProto_DataType_INT64: {
+        gsl::span value_v((const int64_t*)(value.data()), value.size() / sizeof(int64_t));
+        return AddConstant(op_type, value_type, value_v, shape);
+      }
+      default:
+        ORT_ENFORCE(false, "unsupported data type");
+        break;
+    }
+  }
+
   template
   std::string_view AddConstant(std::string_view op_type, std::string_view value_type, const std::vector& value,
                                std::optional> shape = std::nullopt) {

diff --git a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc
index 659622a70e4cc..275194720e850 100644
--- a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc
+++ b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc
@@ -75,7 +75,6 @@ void TestUnaryFloat16(const char* op_name, const std::vector& lhs_dim,
ORT_UNUSED_PARAMETER(op_name); ORT_UNUSED_PARAMETER(lhs_dim); ORT_UNUSED_PARAMETER(lhs_values); - ORT_UNUSED_PARAMETER(rhs_dim); ORT_UNUSED_PARAMETER(out_dim); ORT_UNUSED_PARAMETER(out_values); ORT_UNUSED_PARAMETER(opset); From 812964307b9903ef76393f06c219d29caee32c18 Mon Sep 17 00:00:00 2001 From: wejoncy Date: Fri, 20 Sep 2024 02:05:10 -0700 Subject: [PATCH 08/39] gemm/conv test --- .../coreml/builders/impl/base_op_builder.cc | 2 +- .../coreml/builders/impl/gemm_op_builder.cc | 2 +- .../test/providers/cpu/math/gemm_test.cc | 80 +++++++++++++++---- .../test/providers/cpu/math/matmul_test.cc | 2 +- .../test/providers/cpu/nn/conv_fp16_test.cc | 2 +- 5 files changed, 70 insertions(+), 18 deletions(-) diff --git a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc index 68eaa7a34c074..300e03ed1f39b 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc @@ -19,7 +19,7 @@ static std::set Float16Ops = { "Add", "Mul", "Sub", "Div", "Pow", "Sqrt", "Reciprocal", "Sigmoid", "Tanh", "Relu", "LeakyRelu", "Concat", "GridSample", "GlobalAveragePool", "Clip", "DepthToSpace", "Resize", "Slice", "Conv", - "ConvTranspose", "GlobalMaxPool", + "ConvTranspose", "GlobalMaxPool", "Gemm", "MatMul", "AveragePool", "MaxPool", "Reshape", "Split", "Transpose"}; namespace { diff --git a/onnxruntime/core/providers/coreml/builders/impl/gemm_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/gemm_op_builder.cc index bb0b88c1f3607..958490e484108 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/gemm_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/gemm_op_builder.cc @@ -179,7 +179,7 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N // expand scalar to N no_typed_bias_data.resize(N * bytes_in_type, 0); for (int64_t i = 0; i < N; i++) { - std::copy_n(bias_data_span.data() + i * bytes_in_type, bytes_in_type, no_typed_bias_data.data() + i * bytes_in_type); + std::copy_n(bias_data_span.data(), bytes_in_type, no_typed_bias_data.data() + i * bytes_in_type); } bias_data_span = AsSpan(no_typed_bias_data); } diff --git a/onnxruntime/test/providers/cpu/math/gemm_test.cc b/onnxruntime/test/providers/cpu/math/gemm_test.cc index 7ec84d87b2a8b..fd0b0bd5bc170 100644 --- a/onnxruntime/test/providers/cpu/math/gemm_test.cc +++ b/onnxruntime/test/providers/cpu/math/gemm_test.cc @@ -25,7 +25,7 @@ const constexpr auto run_with_tunable_op = &run_options; } // namespace -// Only CUDA and ROCM kernel has float 16 support +// Only CUDA and ROCM/CoreML kernel has float 16 support TEST(GemmOpTest, GemmNoTrans_f16) { #ifdef USE_CUDA int min_cuda_architecture = 530; @@ -34,12 +34,6 @@ TEST(GemmOpTest, GemmNoTrans_f16) { return; } #endif - OpTester test("Gemm", 13); - - test.AddAttribute("transA", (int64_t)0); - test.AddAttribute("transB", (int64_t)0); - test.AddAttribute("alpha", 1.0f); - test.AddAttribute("beta", 1.0f); std::vector A{1.0f, 2.0f, 3.0f, 4.0f, -1.0f, -2.0f, -3.0f, -4.0f}; @@ -57,13 +51,71 @@ TEST(GemmOpTest, GemmNoTrans_f16) { ConvertFloatToMLFloat16(C.data(), f_C.data(), 6); ConvertFloatToMLFloat16(Y.data(), f_Y.data(), 6); - test.AddInput("A", {2, 4}, f_A); - test.AddInput("B", {4, 3}, f_B); - test.AddInput("C", {2, 3}, f_C); - test.AddOutput("Y", {2, 3}, f_Y); - test.ConfigExcludeEps({kTensorrtExecutionProvider}) // TensorRT: fp16 is not supported - .Config(run_with_tunable_op) 
- .RunWithConfig(); + { + OpTester test("Gemm", 13); + + test.AddAttribute("transA", (int64_t)0); + test.AddAttribute("transB", (int64_t)0); + test.AddAttribute("alpha", 1.0f); + test.AddAttribute("beta", 1.0f); + test.AddInput("A", {2, 4}, f_A); + + test.AddInput("B", {4, 3}, f_B); + test.AddInput("C", {2, 3}, f_C); + test.AddOutput("Y", {2, 3}, f_Y); + test.ConfigExcludeEps({kTensorrtExecutionProvider}) // TensorRT: fp16 is not supported + .Config(run_with_tunable_op) + .RunWithConfig(); + } + { + // CoreML program require B/C are constant + OpTester test("Gemm", 13); + + test.AddAttribute("transA", (int64_t)0); + test.AddAttribute("transB", (int64_t)0); + test.AddAttribute("alpha", 1.0f); + test.AddAttribute("beta", 1.0f); + test.AddInput("A", {2, 4}, f_A); + test.AddInput("B", {4, 3}, f_B, true); + f_C.resize(3); + test.AddInput("C", {3}, f_C, true); + test.AddOutput("Y", {2, 3}, f_Y); + test.ConfigExcludeEps({kTensorrtExecutionProvider}) // TensorRT: fp16 is not supported + .Config(run_with_tunable_op) + .RunWithConfig(); + } + { + OpTester test("Gemm", 13); + + test.AddAttribute("transA", (int64_t)0); + test.AddAttribute("transB", (int64_t)0); + test.AddAttribute("alpha", 1.0f); + test.AddAttribute("beta", 1.0f); + test.AddInput("A", {2, 4}, f_A); + test.AddInput("B", {4, 3}, f_B, true); + f_C.resize(1); + test.AddInput("C", {1}, f_C, true); + test.AddOutput("Y", {2, 3}, f_Y); + test.ConfigExcludeEps({kTensorrtExecutionProvider}) // TensorRT: fp16 is not supported + .Config(run_with_tunable_op) + .RunWithConfig(); + } + { + OpTester test("Gemm", 13); + + test.AddAttribute("transA", (int64_t)0); + test.AddAttribute("transB", (int64_t)1); + test.AddAttribute("alpha", 1.0f); + test.AddAttribute("beta", 1.0f); + test.AddInput("A", {2, 4}, f_A); + test.AddInput("B", {3, 4}, f_B, true); + f_C.resize(1); + test.AddInput("C", {1}, f_C, true); + test.AddOutput("Y", {2, 3}, f_Y); + test.ConfigExcludeEps({kTensorrtExecutionProvider}) // TensorRT: fp16 is not supported + .Config(run_with_tunable_op) + .RunWithConfig(); + } } #if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DNNL) diff --git a/onnxruntime/test/providers/cpu/math/matmul_test.cc b/onnxruntime/test/providers/cpu/math/matmul_test.cc index 90370560859aa..a7d2281ac19f8 100644 --- a/onnxruntime/test/providers/cpu/math/matmul_test.cc +++ b/onnxruntime/test/providers/cpu/math/matmul_test.cc @@ -246,7 +246,7 @@ TEST(MathOpTest, MatMulZeroKInt32Type) { RunMatMulZeroKTest(); } -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(COREML_ENABLE_MLPROGRAM) TEST(MathOpTest, MatMul_Float16) { #ifdef USE_CUDA int min_cuda_architecture = 530; diff --git a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc index 95b274966fbbb..b9e8e4ee13132 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc @@ -3,7 +3,7 @@ #include "core/mlas/inc/mlas.h" -#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || defined(COREML_ENABLE_MLPROGRAM) #include "gtest/gtest.h" #include "test/providers/provider_test_utils.h" From 9d665f3c857a0700a9cae4b173ef6beea7a15b7d Mon Sep 17 00:00:00 2001 From: wejoncy Date: Fri, 20 Sep 2024 02:09:29 -0700 Subject: [PATCH 09/39] address comments --- .../cpu/math/element_wise_ops_test.cc | 29 +++++++++---------- .../apple/coreml_supported_mlprogram_ops.md | 4 +-- 2 files changed, 16 insertions(+), 17 deletions(-) diff --git 
a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc index 275194720e850..06fd837d22046 100644 --- a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc +++ b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc @@ -22,18 +22,14 @@ std::vector MakeMLFloat16(const std::initializer_list& input) return output; } -void TestBinaryFloat16(const char* op_name, const std::vector& lhs_dim, - const std::initializer_list& lhs_values, const std::vector& rhs_dim, - const std::initializer_list& rhs_values, const std::vector& out_dim, - const std::initializer_list& out_values, bool enable_bf16 = true) { - ORT_UNUSED_PARAMETER(op_name); - ORT_UNUSED_PARAMETER(lhs_dim); - ORT_UNUSED_PARAMETER(lhs_values); - ORT_UNUSED_PARAMETER(rhs_dim); - ORT_UNUSED_PARAMETER(rhs_values); - ORT_UNUSED_PARAMETER(out_dim); - ORT_UNUSED_PARAMETER(out_values); - ORT_UNUSED_PARAMETER(enable_bf16); +void TestBinaryFloat16([[maybe_unused]] const char* op_name, + [[maybe_unused]] const std::vector& lhs_dim, + [[maybe_unused]] const std::initializer_list& lhs_values, + [[maybe_unused]] const std::vector& rhs_dim, + [[maybe_unused]] const std::initializer_list& rhs_values, + [[maybe_unused]] const std::vector& out_dim, + [[maybe_unused]] const std::initializer_list& out_values, + [[maybe_unused]] bool enable_bf16 = true) { #if defined(USE_CUDA) || defined(USE_ROCM) || defined(COREML_ENABLE_MLPROGRAM) { OpTester tester(op_name, 14); @@ -69,9 +65,12 @@ void TestBinaryFloat16(const char* op_name, const std::vector& lhs_dim, #endif } -void TestUnaryFloat16(const char* op_name, const std::vector& lhs_dim, - const std::initializer_list& lhs_values, const std::vector& out_dim, - const std::initializer_list& out_values, int opset = 14) { +void TestUnaryFloat16([[maybe_unused]] const char* op_name, + [[maybe_unused]] const std::vector& lhs_dim, + [[maybe_unused]] const std::initializer_list& lhs_values, + [[maybe_unused]] const std::vector& out_dim, + [[maybe_unused]] const std::initializer_list& out_values, + [[maybe_unused]] int opset = 14) { ORT_UNUSED_PARAMETER(op_name); ORT_UNUSED_PARAMETER(lhs_dim); ORT_UNUSED_PARAMETER(lhs_values); diff --git a/tools/ci_build/github/apple/coreml_supported_mlprogram_ops.md b/tools/ci_build/github/apple/coreml_supported_mlprogram_ops.md index 0b51311e2271f..ae0769e7fb93c 100644 --- a/tools/ci_build/github/apple/coreml_supported_mlprogram_ops.md +++ b/tools/ci_build/github/apple/coreml_supported_mlprogram_ops.md @@ -20,6 +20,7 @@ Keep in sync with doco generated from /docs/execution-providers/CoreML-Execution |ai.onnx:MaxPool|Only 2D Pool is supported currently. 3D and 5D support can be added if needed.| |ai.onnx:Mul|| |ai.onnx:Pow|Only supports cases when both inputs are fp32.| +|ai.onnx:Reciprocal|this ask for a `epislon` (default 1e-4) where onnx don't provide| |ai.onnx:Relu|| |ai.onnx:Reshape|| |ai.onnx:Resize|See [resize_op_builder.cc](https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc) implementation. 
There are too many permutations to describe the valid combinations.| @@ -27,7 +28,6 @@ Keep in sync with doco generated from /docs/execution-providers/CoreML-Execution |ai.onnx:Split|If provided, `splits` must be constant.| |ai.onnx:Sub|| |ai.onnx:Sigmoid|| +|ai.onnx:Sqrt|| |ai.onnx:Tanh|| |ai.onnx:Transpose|| -|ai.onnx:Sqrt|| -|ai.onnx:Reciprocal|this ask for a `epislon` (default 1e-4) where onnx don't provide| From dbf25b97f6667b6af140dfb79113cf092bebb479 Mon Sep 17 00:00:00 2001 From: wejoncy Date: Fri, 20 Sep 2024 04:58:55 -0700 Subject: [PATCH 10/39] build issue --- .../coreml/builders/impl/gemm_op_builder.cc | 2 +- .../providers/cpu/math/element_wise_ops_test.cc | 13 ++++--------- onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc | 3 +++ 3 files changed, 8 insertions(+), 10 deletions(-) diff --git a/onnxruntime/core/providers/coreml/builders/impl/gemm_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/gemm_op_builder.cc index 958490e484108..535f546840073 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/gemm_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/gemm_op_builder.cc @@ -138,8 +138,8 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N const auto K = transB ? b1 : b0; const auto N = transB ? b0 : b1; // we already checked it and dtype must be existed. - auto input_dtype = node.InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); #if defined(COREML_ENABLE_MLPROGRAM) + auto input_dtype = node.InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); if (model_builder.CreateMLProgram()) { using namespace CoreML::Specification::MILSpec; diff --git a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc index 06fd837d22046..d030fc2ec85e9 100644 --- a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc +++ b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc @@ -70,13 +70,8 @@ void TestUnaryFloat16([[maybe_unused]] const char* op_name, [[maybe_unused]] const std::initializer_list& lhs_values, [[maybe_unused]] const std::vector& out_dim, [[maybe_unused]] const std::initializer_list& out_values, - [[maybe_unused]] int opset = 14) { - ORT_UNUSED_PARAMETER(op_name); - ORT_UNUSED_PARAMETER(lhs_dim); - ORT_UNUSED_PARAMETER(lhs_values); - ORT_UNUSED_PARAMETER(out_dim); - ORT_UNUSED_PARAMETER(out_values); - ORT_UNUSED_PARAMETER(opset); + [[maybe_unused]] int opset = 14, + [[maybe_unused]] bool run_bf16 = true) { #if defined(USE_CUDA) || defined(USE_ROCM) || defined(COREML_ENABLE_MLPROGRAM) { OpTester tester(op_name, opset); @@ -95,7 +90,7 @@ void TestUnaryFloat16([[maybe_unused]] const char* op_name, #endif #if defined(USE_CUDA) || defined(USE_ROCM) - { + if (run_bf16) { OpTester tester(op_name, opset); tester.AddInput("A", lhs_dim, MakeBFloat16(lhs_values)); tester.AddOutput("C", out_dim, MakeBFloat16(out_values)); @@ -816,7 +811,7 @@ TEST(MathOpTest, Reciprocal) { test.AddInput("X", dims, inputs); test.AddOutput("Y", dims, outputs); test.Run(); - TestUnaryFloat16("Reciprocal", dims, inputs, dims, outputs, 12); + TestUnaryFloat16("Reciprocal", dims, inputs, dims, outputs, 12, false); } TEST(MathOpTest, Reciprocal_double) { diff --git a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc index b9e8e4ee13132..6bb67d7e0185b 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc @@ -45,6 +45,9 @@ void 
TestConvFp16Op(const ConvOpAndTestAttributes& attributes, if (!attributes.activation_parameters.empty()) { tester->AddAttribute("activation_params", attributes.activation_parameters); } +#if !defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) + return; +#endif } else { tester = std::make_unique("Conv", opset); } From 293b9f2ee3ea51859e941d92704c9fe45b659096 Mon Sep 17 00:00:00 2001 From: wejoncy Date: Sat, 21 Sep 2024 01:29:28 -0700 Subject: [PATCH 11/39] fix crash test --- onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc index 6bb67d7e0185b..895a0f3ea1c3b 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc @@ -37,6 +37,12 @@ void TestConvFp16Op(const ConvOpAndTestAttributes& attributes, OpTester::ExpectResult expect_result = OpTester::ExpectResult::kExpectSuccess, const std::string& err_str = "", int opset = 11) { +#if !defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) +// a `return` after tester will make binary crash + if (!attributes.activation.empty()) { + return; + } +#endif std::unique_ptr tester; if (!attributes.activation.empty()) { tester = std::make_unique("NhwcFusedConv", 1, onnxruntime::kMSDomain); @@ -45,9 +51,6 @@ void TestConvFp16Op(const ConvOpAndTestAttributes& attributes, if (!attributes.activation_parameters.empty()) { tester->AddAttribute("activation_params", attributes.activation_parameters); } -#if !defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) - return; -#endif } else { tester = std::make_unique("Conv", opset); } From 4b344a5ca65a902bbfd0b51897d2a45754137be5 Mon Sep 17 00:00:00 2001 From: wejoncy Date: Sat, 21 Sep 2024 04:33:50 -0700 Subject: [PATCH 12/39] lint --- onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc index 895a0f3ea1c3b..77716e5e370c4 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc @@ -38,9 +38,9 @@ void TestConvFp16Op(const ConvOpAndTestAttributes& attributes, const std::string& err_str = "", int opset = 11) { #if !defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) -// a `return` after tester will make binary crash + // a `return` after tester will make binary crash if (!attributes.activation.empty()) { - return; + return; } #endif std::unique_ptr tester; From ca581bc6c121ec8e9fa5968bd2c3d58547f78bba Mon Sep 17 00:00:00 2001 From: wejoncy Date: Wed, 25 Sep 2024 10:51:40 +0800 Subject: [PATCH 13/39] Update onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc Co-authored-by: Scott McKay --- .../core/providers/coreml/builders/impl/builder_utils.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc index a27895b6e37f7..d053fc5b9496d 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc @@ -214,8 +214,8 @@ void CopyDataToTensorValue(MILSpec::TensorValue& tensor_value, gsl::span< template <> void CopyDataToTensorValue(MILSpec::TensorValue& tensor_value, gsl::span data) { - const char* begin = (const char*)(data.data()); - const char* 
end = (const char*)(data.data()) + data.size() * sizeof(MLFloat16); + const char* begin = reinterpret_cast(data.data()); + const char* end = begin + (data.size() * sizeof(MLFloat16)); tensor_value.mutable_bytes()->mutable_values()->assign(begin, end); } From e9b2a427f2b1613a4cbf4a5cb7aaba3bb63238a1 Mon Sep 17 00:00:00 2001 From: wejoncy Date: Wed, 25 Sep 2024 10:51:47 +0800 Subject: [PATCH 14/39] Update onnxruntime/core/providers/coreml/builders/impl/builder_utils.h Co-authored-by: Scott McKay --- onnxruntime/core/providers/coreml/builders/impl/builder_utils.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h index f25936e25a17f..f38afc0ec181d 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h +++ b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h @@ -41,7 +41,7 @@ Status CreateCoreMLWeight(CoreML::Specification::WeightParams& weight, const ONN // Copy the float array to a coreml weight void CreateCoreMLWeight(CoreML::Specification::WeightParams& weight, gsl::span data); -// Copy the float array to a coreml weight +// Copy the MLFloat16 array to a coreml weight void CreateCoreMLWeight(CoreML::Specification::WeightParams& weight, gsl::span data); // Copy the int32_t array to a coreml weight From 8fa2b488a2f1f1d577f994642293ed2430e6a698 Mon Sep 17 00:00:00 2001 From: wejoncy Date: Wed, 25 Sep 2024 10:53:23 +0800 Subject: [PATCH 15/39] Update onnxruntime/core/providers/coreml/builders/impl/gemm_op_builder.cc Co-authored-by: Scott McKay --- .../core/providers/coreml/builders/impl/gemm_op_builder.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/coreml/builders/impl/gemm_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/gemm_op_builder.cc index 535f546840073..c470a69e8b935 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/gemm_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/gemm_op_builder.cc @@ -139,7 +139,7 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N const auto N = transB ? b0 : b1; // we already checked it and dtype must be existed. 
#if defined(COREML_ENABLE_MLPROGRAM) - auto input_dtype = node.InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + auto input_dtype = a.TypeAsProto()->tensor_type().elem_type(); if (model_builder.CreateMLProgram()) { using namespace CoreML::Specification::MILSpec; From 154a39902e788d029b79bca573c8fb2420502234 Mon Sep 17 00:00:00 2001 From: wejoncy Date: Wed, 25 Sep 2024 10:53:49 +0800 Subject: [PATCH 16/39] Update onnxruntime/core/providers/coreml/builders/model_builder.h Co-authored-by: Scott McKay --- onnxruntime/core/providers/coreml/builders/model_builder.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/coreml/builders/model_builder.h b/onnxruntime/core/providers/coreml/builders/model_builder.h index f94521ac91979..1438f1b00dca4 100644 --- a/onnxruntime/core/providers/coreml/builders/model_builder.h +++ b/onnxruntime/core/providers/coreml/builders/model_builder.h @@ -112,7 +112,7 @@ class ModelBuilder { std::is_same_v || std::is_same_v, // add specialization in AddConstantImpl for new types if needed - "AddConstant currently supports float/MLFloat16, int64_t, std::string and bool."); + "AddConstant currently supports float, MLFloat16, int64_t, std::string and bool."); return AddConstantImpl(op_type, value_type, value, shape); } From b7c60786ffa02b232f31d00dd1f861cc1a860389 Mon Sep 17 00:00:00 2001 From: wejoncy Date: Wed, 25 Sep 2024 10:54:40 +0800 Subject: [PATCH 17/39] Update onnxruntime/test/providers/cpu/math/gemm_test.cc Co-authored-by: Scott McKay --- onnxruntime/test/providers/cpu/math/gemm_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/test/providers/cpu/math/gemm_test.cc b/onnxruntime/test/providers/cpu/math/gemm_test.cc index fd0b0bd5bc170..f9f1a829d4396 100644 --- a/onnxruntime/test/providers/cpu/math/gemm_test.cc +++ b/onnxruntime/test/providers/cpu/math/gemm_test.cc @@ -25,7 +25,7 @@ const constexpr auto run_with_tunable_op = &run_options; } // namespace -// Only CUDA and ROCM/CoreML kernel has float 16 support +// Only CUDA, ROCM and CoreML kernels have float 16 support TEST(GemmOpTest, GemmNoTrans_f16) { #ifdef USE_CUDA int min_cuda_architecture = 530; From fe3a3a3a6d05c2615b087f7c1cec9ef333746653 Mon Sep 17 00:00:00 2001 From: wejoncy Date: Tue, 24 Sep 2024 23:47:25 -0700 Subject: [PATCH 18/39] address comments && add tolerance --- .../coreml/builders/impl/base_op_builder.cc | 3 +- .../coreml/builders/impl/gemm_op_builder.cc | 86 ++++++++------- .../providers/coreml/builders/model_builder.h | 21 ---- .../core/providers/coreml/model/model.mm | 32 +++--- .../cpu/math/element_wise_ops_test.cc | 103 +++++++++--------- .../test/providers/cpu/math/gemm_test.cc | 68 ++++++++++-- .../test/providers/cpu/nn/conv_fp16_test.cc | 10 +- 7 files changed, 169 insertions(+), 154 deletions(-) diff --git a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc index 300e03ed1f39b..f185a80de3cbf 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc @@ -117,7 +117,8 @@ bool BaseOpBuilder::IsInputDtypeSupport(const Node& node, size_t idx, // only support MLProgram for FP16 #if defined(COREML_ENABLE_MLPROGRAM) - if (input_params.create_mlprogram && input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16 && Float16Ops.count(node.OpType())) { + if (input_params.create_mlprogram && input_type == 
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16 && + Float16Ops.count(node.OpType())) { return true; } #endif diff --git a/onnxruntime/core/providers/coreml/builders/impl/gemm_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/gemm_op_builder.cc index c470a69e8b935..71a4fe9b12035 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/gemm_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/gemm_op_builder.cc @@ -71,34 +71,19 @@ void GemmOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Nod } // This is an internal function, requires input tensor to be 2d float/float16 tensor -// TODO, add support of other data types -// Template will make the binary size inflation, and which woundn't affect the runtime performance. +template static Status GetTensorDataTransposed(const ONNX_NAMESPACE::TensorProto& tensor, - std::vector& transposed_data, - int dtype = ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + std::vector& transposed_data) { Initializer unpacked_tensor(tensor); - const void* src_dataraw = unpacked_tensor.DataAsByteSpan().data(); + const auto src_data = unpacked_tensor.DataAsSpan(); const auto& tensor_shape = tensor.dims(); auto x_t = SafeInt(tensor_shape[0]); auto y_t = SafeInt(tensor_shape[1]); - size_t bytes_in_type = dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16 ? sizeof(MLFloat16) : sizeof(float); - transposed_data.resize(x_t * y_t * bytes_in_type); - void* dst_data = transposed_data.data(); - if (dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { - MLFloat16* dst_ptr = (MLFloat16*)dst_data; - const MLFloat16* src_ptr = (const MLFloat16*)src_dataraw; - for (size_t x = 0; x < x_t; x++) { - for (size_t y = 0; y < y_t; y++) { - dst_ptr[y * x_t + x] = src_ptr[x * y_t + y]; - } - } - } else { - float* dst_ptr = (float*)dst_data; - const float* src_ptr = (const float*)src_dataraw; - for (size_t x = 0; x < x_t; x++) { - for (size_t y = 0; y < y_t; y++) { - dst_ptr[y * x_t + x] = src_ptr[x * y_t + y]; - } + transposed_data.resize(x_t * y_t); + + for (size_t x = 0; x < x_t; x++) { + for (size_t y = 0; y < y_t; y++) { + transposed_data[y * x_t + x] = src_data[x * y_t + y]; } } @@ -153,12 +138,19 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N if (transB) { AddOperationInput(*gemm_op, "weight", b.Name()); } else { - // transpose from {K, N} to {N, K} - std::vector weight_nk; // use bytes to store the type-erased data, could be any data-type std::vector weight_nk_shape = {N, K}; - ORT_RETURN_IF_ERROR(GetTensorDataTransposed(*b_initializer, weight_nk, input_dtype)); - AddOperationInput(*gemm_op, "weight", - model_builder.AddConstant(gemm_op->type(), b.Name() + "_t", weight_nk, input_dtype, weight_nk_shape)); + // transpose from {K, N} to {N, K} + if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + std::vector weight_nk; // use bytes to store the type-erased data, could be any data-type + ORT_RETURN_IF_ERROR(GetTensorDataTransposed(*b_initializer, weight_nk)); + AddOperationInput(*gemm_op, "weight", + model_builder.AddConstant(gemm_op->type(), b.Name() + "_t", weight_nk, weight_nk_shape)); + } else { // TensorProto_DataType_FLOAT16 + std::vector weight_nk; // use bytes to store the type-erased data, could be any data-type + ORT_RETURN_IF_ERROR(GetTensorDataTransposed(*b_initializer, weight_nk)); + AddOperationInput(*gemm_op, "weight", + model_builder.AddConstant(gemm_op->type(), b.Name() + "_t", weight_nk, weight_nk_shape)); + } } if (input_defs.size() == 3) { @@ -171,19 +163,30 @@ 
Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N AddOperationInput(*gemm_op, "bias", bias_arg.Name()); } else { Initializer unpacked_tensor(bias); - std::vector no_typed_bias_data; - gsl::span bias_data_span = unpacked_tensor.DataAsByteSpan(); - size_t bytes_in_type = input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT ? sizeof(float) : sizeof(MLFloat16); - // can use data as-is but need to adjust shape (inferred by AddConstant as {bias_data.size()}) - if (bias_data_span.size() == bytes_in_type) { - // expand scalar to N - no_typed_bias_data.resize(N * bytes_in_type, 0); - for (int64_t i = 0; i < N; i++) { - std::copy_n(bias_data_span.data(), bytes_in_type, no_typed_bias_data.data() + i * bytes_in_type); + std::string_view bias_data_name; + + if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + auto bias_data = unpacked_tensor.DataAsSpan(); + if (bias_data.size() == 1) { + // expand scalar to N + std::vector expanded_bias_data(N, bias_data[0]); + bias_data_name = model_builder.AddConstant(gemm_op->type(), "bias", expanded_bias_data); + } else { + // can use data as-is but need to adjust shape (inferred by AddConstant as {bias_data.size()}) + bias_data_name = model_builder.AddConstant(gemm_op->type(), "bias", bias_data); + } + } else { // TensorProto_DataType_FLOAT16 + auto bias_data = unpacked_tensor.DataAsSpan(); + if (bias_data.size() == 1) { + // expand scalar to N + std::vector expanded_bias_data(N, bias_data[0]); + bias_data_name = model_builder.AddConstant(gemm_op->type(), "bias", expanded_bias_data); + } else { + // can use data as-is but need to adjust shape (inferred by AddConstant as {bias_data.size()}) + bias_data_name = model_builder.AddConstant(gemm_op->type(), "bias", bias_data); } - bias_data_span = AsSpan(no_typed_bias_data); } - std::string_view bias_data_name = model_builder.AddConstant(gemm_op->type(), "bias", bias_data_span, input_dtype); + AddOperationInput(*gemm_op, "bias", bias_data_name); } } @@ -219,10 +222,9 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N if (transB) { ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_inner_product->mutable_weights(), *b_initializer)); } else { - std::vector b_transposed; + std::vector b_transposed; ORT_RETURN_IF_ERROR(GetTensorDataTransposed(*b_initializer, b_transposed)); - gsl::span b_transposed_f((const float*)(b_transposed.data()), b_transposed.size() / sizeof(float)); - CreateCoreMLWeight(*coreml_inner_product->mutable_weights(), b_transposed_f); + CreateCoreMLWeight(*coreml_inner_product->mutable_weights(), b_transposed); } if (is_gemm && input_defs.size() > 2) { diff --git a/onnxruntime/core/providers/coreml/builders/model_builder.h b/onnxruntime/core/providers/coreml/builders/model_builder.h index 1438f1b00dca4..b3dfec29872a2 100644 --- a/onnxruntime/core/providers/coreml/builders/model_builder.h +++ b/onnxruntime/core/providers/coreml/builders/model_builder.h @@ -116,27 +116,6 @@ class ModelBuilder { return AddConstantImpl(op_type, value_type, value, shape); } - std::string_view AddConstant(std::string_view op_type, std::string_view value_type, gsl::span value, - int input_dtype, std::optional> shape = std::nullopt) { - switch (input_dtype) { - case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: { - gsl::span value_v((const float*)(value.data()), value.size() / sizeof(float)); - return AddConstant(op_type, value_type, value_v, shape); - } - case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: { - gsl::span value_v((const 
MLFloat16*)(value.data()), value.size() / sizeof(MLFloat16)); - return AddConstant(op_type, value_type, value_v, shape); - } - case ONNX_NAMESPACE::TensorProto_DataType_INT64: { - gsl::span value_v((const int64_t*)(value.data()), value.size() / sizeof(int64_t)); - return AddConstant(op_type, value_type, value_v, shape); - } - default: - ORT_ENFORCE(false, "unspported data type"); - break; - } - } - template std::string_view AddConstant(std::string_view op_type, std::string_view value_type, const std::vector& value, std::optional> shape = std::nullopt) { diff --git a/onnxruntime/core/providers/coreml/model/model.mm b/onnxruntime/core/providers/coreml/model/model.mm index 60c93aa601622..97e157d738371 100644 --- a/onnxruntime/core/providers/coreml/model/model.mm +++ b/onnxruntime/core/providers/coreml/model/model.mm @@ -178,6 +178,16 @@ Status CreateInputFeatureProvider(const std::unordered_map +void StrideCopy(const T* src_buffer, T* dst_buffer, size_t block_size, + size_t num_blocks, size_t src_stride, size_t dst_stride) { + for (size_t idx = 0; idx < num_blocks; ++idx) { + std::copy_n(src_buffer, block_size, dst_buffer); + src_buffer += src_stride; + dst_buffer += dst_stride; + } +} + Status CopyMLMultiArrayBuffer(const void* mlmultiarray_buffer, void* tensor_buffer, const MLMultiArray* array, const int64_t num_blocks, const int64_t block_size, const int64_t stride, @@ -200,37 +210,21 @@ Status CopyMLMultiArrayBuffer(const void* mlmultiarray_buffer, void* tensor_buff case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: { const auto* src_buffer = static_cast(mlmultiarray_buffer); auto* dst_buffer = static_cast(tensor_buffer); - const auto block_byte_size = block_size * sizeof(float); + StrideCopy(src_buffer, dst_buffer, block_size, num_blocks, stride, block_size); - for (int64_t idx = 0; idx < num_blocks; ++idx) { - memcpy(dst_buffer, src_buffer, block_byte_size); - src_buffer += stride; - dst_buffer += block_size; - } break; } case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: { const auto* src_buffer = static_cast(mlmultiarray_buffer); auto* dst_buffer = static_cast(tensor_buffer); - const auto block_byte_size = block_size * sizeof(uint16_t); + StrideCopy(src_buffer, dst_buffer, block_size, num_blocks, stride, block_size); - for (int64_t idx = 0; idx < num_blocks; ++idx) { - memcpy(dst_buffer, src_buffer, block_byte_size); - src_buffer += stride; - dst_buffer += block_size; - } break; } case ONNX_NAMESPACE::TensorProto_DataType_INT32: { const auto* src_buffer = static_cast(mlmultiarray_buffer); auto* dst_buffer = static_cast(tensor_buffer); - const auto block_byte_size = block_size * sizeof(int32_t); - - for (int64_t idx = 0; idx < num_blocks; ++idx) { - memcpy(dst_buffer, src_buffer, block_byte_size); - src_buffer += stride; - dst_buffer += block_size; - } + StrideCopy(src_buffer, dst_buffer, block_size, num_blocks, stride, block_size); break; } diff --git a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc index d030fc2ec85e9..83c057f204871 100644 --- a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc +++ b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc @@ -22,87 +22,86 @@ std::vector MakeMLFloat16(const std::initializer_list& input) return output; } -void TestBinaryFloat16([[maybe_unused]] const char* op_name, - [[maybe_unused]] const std::vector& lhs_dim, - [[maybe_unused]] const std::initializer_list& lhs_values, - [[maybe_unused]] const std::vector& rhs_dim, - [[maybe_unused]] const 
std::initializer_list& rhs_values, - [[maybe_unused]] const std::vector& out_dim, - [[maybe_unused]] const std::initializer_list& out_values, - [[maybe_unused]] bool enable_bf16 = true) { -#if defined(USE_CUDA) || defined(USE_ROCM) || defined(COREML_ENABLE_MLPROGRAM) - { - OpTester tester(op_name, 14); - tester.AddInput("A", lhs_dim, MakeMLFloat16(lhs_values)); - tester.AddInput("B", rhs_dim, MakeMLFloat16(rhs_values)); - tester.AddOutput("C", out_dim, MakeMLFloat16(out_values)); - std::vector> execution_providers; +void TestBinaryFloat16(const char* op_name, + const std::vector& lhs_dim, + const std::initializer_list& lhs_values, + const std::vector& rhs_dim, + const std::initializer_list& rhs_values, + const std::vector& out_dim, + const std::initializer_list& out_values, + bool enable_bf16 = true) { + std::vector> execution_providers; #ifdef COREML_ENABLE_MLPROGRAM - execution_providers.push_back(DefaultCoreMLExecutionProvider(true)); + execution_providers.push_back(DefaultCoreMLExecutionProvider(true)); #elif USE_CUDA - execution_providers.push_back(DefaultCudaExecutionProvider()); + execution_providers.push_back(DefaultCudaExecutionProvider()); #elif USE_ROCM - execution_providers.push_back(DefaultRocmExecutionProvider()); + execution_providers.push_back(DefaultRocmExecutionProvider()); #endif + if (execution_providers.size() > 0) { + OpTester tester(op_name, 14); + tester.AddInput("A", lhs_dim, MakeMLFloat16(lhs_values)); + tester.AddInput("B", rhs_dim, MakeMLFloat16(rhs_values)); + tester.AddOutput("C", out_dim, MakeMLFloat16(out_values)); + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } + execution_providers.clear(); + +#ifdef USE_CUDA + execution_providers.push_back(DefaultCudaExecutionProvider()); +#elif USE_ROCM + execution_providers.push_back(DefaultRocmExecutionProvider()); #endif -#if defined(USE_CUDA) || defined(USE_ROCM) - if (enable_bf16) { + if (enable_bf16 && execution_providers.size() > 0) { OpTester tester(op_name, 14); tester.AddInput("A", lhs_dim, MakeBFloat16(lhs_values)); tester.AddInput("B", rhs_dim, MakeBFloat16(rhs_values)); tester.AddOutput("C", out_dim, MakeBFloat16(out_values)); - std::vector> execution_providers; -#ifdef USE_CUDA - execution_providers.push_back(DefaultCudaExecutionProvider()); -#elif USE_ROCM - execution_providers.push_back(DefaultRocmExecutionProvider()); -#endif + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } -#endif } -void TestUnaryFloat16([[maybe_unused]] const char* op_name, - [[maybe_unused]] const std::vector& lhs_dim, - [[maybe_unused]] const std::initializer_list& lhs_values, - [[maybe_unused]] const std::vector& out_dim, - [[maybe_unused]] const std::initializer_list& out_values, - [[maybe_unused]] int opset = 14, - [[maybe_unused]] bool run_bf16 = true) { -#if defined(USE_CUDA) || defined(USE_ROCM) || defined(COREML_ENABLE_MLPROGRAM) - { - OpTester tester(op_name, opset); - tester.AddInput("A", lhs_dim, MakeMLFloat16(lhs_values)); - tester.AddOutput("C", out_dim, MakeMLFloat16(out_values)); - std::vector> execution_providers; +void TestUnaryFloat16(const char* op_name, + const std::vector& lhs_dim, + const std::initializer_list& lhs_values, + const std::vector& out_dim, + const std::initializer_list& out_values, + int opset = 14, + bool run_bf16 = true) { + std::vector> execution_providers; #ifdef COREML_ENABLE_MLPROGRAM - execution_providers.push_back(DefaultCoreMLExecutionProvider(true)); + 
execution_providers.push_back(DefaultCoreMLExecutionProvider(true)); #elif USE_CUDA - execution_providers.push_back(DefaultCudaExecutionProvider()); + execution_providers.push_back(DefaultCudaExecutionProvider()); #elif USE_ROCM - execution_providers.push_back(DefaultRocmExecutionProvider()); + execution_providers.push_back(DefaultRocmExecutionProvider()); #endif + if (execution_providers.size() > 0) { + OpTester tester(op_name, opset); + tester.AddInput("A", lhs_dim, MakeMLFloat16(lhs_values)); + tester.AddOutput("C", out_dim, MakeMLFloat16(out_values)); + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } + + execution_providers.clear(); + +#ifdef USE_CUDA + execution_providers.push_back(DefaultCudaExecutionProvider()); +#elif USE_ROCM + execution_providers.push_back(DefaultRocmExecutionProvider()); #endif -#if defined(USE_CUDA) || defined(USE_ROCM) - if (run_bf16) { + if (run_bf16 && execution_providers.size() > 0) { OpTester tester(op_name, opset); tester.AddInput("A", lhs_dim, MakeBFloat16(lhs_values)); tester.AddOutput("C", out_dim, MakeBFloat16(out_values)); - std::vector> execution_providers; -#ifdef USE_CUDA - execution_providers.push_back(DefaultCudaExecutionProvider()); -#elif USE_ROCM - execution_providers.push_back(DefaultRocmExecutionProvider()); -#endif + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } -#endif } void TestBFloat16(const char* op_name, const std::vector& lhs_dim, diff --git a/onnxruntime/test/providers/cpu/math/gemm_test.cc b/onnxruntime/test/providers/cpu/math/gemm_test.cc index f9f1a829d4396..c21e353ca2fbb 100644 --- a/onnxruntime/test/providers/cpu/math/gemm_test.cc +++ b/onnxruntime/test/providers/cpu/math/gemm_test.cc @@ -37,21 +37,23 @@ TEST(GemmOpTest, GemmNoTrans_f16) { std::vector A{1.0f, 2.0f, 3.0f, 4.0f, -1.0f, -2.0f, -3.0f, -4.0f}; - std::vector B(12, 1.0f); - std::vector C(6, 1.0f); - std::vector Y{11.0f, 11.0f, 11.0f, - -9.0f, -9.0f, -9.0f}; + std::vector B = {0.5f, 2.1f, 1.2f, -0.3f, -1.2f, 0.2f, 1.0f, -2.1f, 1.3f, 4.1f, 1.3f, -8.1f}; + std::vector C = {0.5f, 2.1f, 1.2f, -0.3f, -1.2f, 0.2f}; std::vector f_A(8); std::vector f_B(12); - std::vector f_C(6); - std::vector f_Y(6); ConvertFloatToMLFloat16(A.data(), f_A.data(), 8); ConvertFloatToMLFloat16(B.data(), f_B.data(), 12); - ConvertFloatToMLFloat16(C.data(), f_C.data(), 6); - ConvertFloatToMLFloat16(Y.data(), f_Y.data(), 6); { + // bias has same shape as output + std::vector f_Y(6); + std::vector Y{19.8f, 0.7f, -25.7f, -19.6f, 0.2f, 27.1f}; + ConvertFloatToMLFloat16(Y.data(), f_Y.data(), 6); + + std::vector f_C(6); + ConvertFloatToMLFloat16(C.data(), f_C.data(), 6); + OpTester test("Gemm", 13); test.AddAttribute("transA", (int64_t)0); @@ -59,15 +61,22 @@ TEST(GemmOpTest, GemmNoTrans_f16) { test.AddAttribute("alpha", 1.0f); test.AddAttribute("beta", 1.0f); test.AddInput("A", {2, 4}, f_A); - test.AddInput("B", {4, 3}, f_B); test.AddInput("C", {2, 3}, f_C); test.AddOutput("Y", {2, 3}, f_Y); + test.SetOutputTolerance(0.005f); test.ConfigExcludeEps({kTensorrtExecutionProvider}) // TensorRT: fp16 is not supported .Config(run_with_tunable_op) .RunWithConfig(); } { + // bias has shape {1, output_features} + std::vector f_Y(6); + std::vector Y{19.8f, 0.7f, -25.7f, -18.8f, 3.5f, 28.1f}; + ConvertFloatToMLFloat16(Y.data(), f_Y.data(), 6); + + std::vector f_C(3); + ConvertFloatToMLFloat16(C.data(), f_C.data(), 3); // CoreML program require B/C are constant OpTester test("Gemm", 13); @@ -77,14 +86,21 @@ TEST(GemmOpTest, 
GemmNoTrans_f16) { test.AddAttribute("beta", 1.0f); test.AddInput("A", {2, 4}, f_A); test.AddInput("B", {4, 3}, f_B, true); - f_C.resize(3); test.AddInput("C", {3}, f_C, true); test.AddOutput("Y", {2, 3}, f_Y); + test.SetOutputTolerance(0.005f); test.ConfigExcludeEps({kTensorrtExecutionProvider}) // TensorRT: fp16 is not supported .Config(run_with_tunable_op) .RunWithConfig(); } { + // bias is a scalar + std::vector f_Y(6); + std::vector Y{19.8f, -0.9f, -26.4f, -18.8f, 1.9f, 27.4f}; + ConvertFloatToMLFloat16(Y.data(), f_Y.data(), 6); + + std::vector f_C(1); + ConvertFloatToMLFloat16(C.data(), f_C.data(), 1); OpTester test("Gemm", 13); test.AddAttribute("transA", (int64_t)0); @@ -93,14 +109,42 @@ TEST(GemmOpTest, GemmNoTrans_f16) { test.AddAttribute("beta", 1.0f); test.AddInput("A", {2, 4}, f_A); test.AddInput("B", {4, 3}, f_B, true); - f_C.resize(1); test.AddInput("C", {1}, f_C, true); test.AddOutput("Y", {2, 3}, f_Y); + test.SetOutputTolerance(0.005f); test.ConfigExcludeEps({kTensorrtExecutionProvider}) // TensorRT: fp16 is not supported .Config(run_with_tunable_op) .RunWithConfig(); } +} + +// Only CUDA, ROCM and CoreML kernels have float 16 support +TEST(GemmOpTest, GemmTransB_f16) { +#ifdef USE_CUDA + int min_cuda_architecture = 530; + if (!HasCudaEnvironment(min_cuda_architecture)) { + LOGS_DEFAULT(WARNING) << "Hardware NOT support FP16"; + return; + } +#endif + + std::vector A{1.0f, 2.0f, 3.0f, 4.0f, + -1.0f, -2.0f, -3.0f, -4.0f}; + std::vector B = {0.5f, 2.1f, 1.2f, -0.3f, -1.2f, 0.2f, 1.0f, -2.1f, 1.3f, 4.1f, 1.3f, -8.1f}; + std::vector C = {0.5f, 2.1f, 1.2f, -0.3f, -1.2f, 0.2f}; + + std::vector f_A(8); + std::vector f_B(12); + ConvertFloatToMLFloat16(A.data(), f_A.data(), 8); + ConvertFloatToMLFloat16(B.data(), f_B.data(), 12); { + // bias is a scalar and transB is True + std::vector f_Y(6); + std::vector Y{7.6f, -5.7f, -18.5f, -6.6f, 6.7f, 19.5f}; + ConvertFloatToMLFloat16(Y.data(), f_Y.data(), 6); + + std::vector f_C(1); + ConvertFloatToMLFloat16(C.data(), f_C.data(), 1); OpTester test("Gemm", 13); test.AddAttribute("transA", (int64_t)0); @@ -109,9 +153,9 @@ TEST(GemmOpTest, GemmNoTrans_f16) { test.AddAttribute("beta", 1.0f); test.AddInput("A", {2, 4}, f_A); test.AddInput("B", {3, 4}, f_B, true); - f_C.resize(1); test.AddInput("C", {1}, f_C, true); test.AddOutput("Y", {2, 3}, f_Y); + test.SetOutputTolerance(0.005f); test.ConfigExcludeEps({kTensorrtExecutionProvider}) // TensorRT: fp16 is not supported .Config(run_with_tunable_op) .RunWithConfig(); diff --git a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc index 77716e5e370c4..d5844b5f63a42 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc @@ -37,12 +37,6 @@ void TestConvFp16Op(const ConvOpAndTestAttributes& attributes, OpTester::ExpectResult expect_result = OpTester::ExpectResult::kExpectSuccess, const std::string& err_str = "", int opset = 11) { -#if !defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) - // a `return` after tester will make binary crash - if (!attributes.activation.empty()) { - return; - } -#endif std::unique_ptr tester; if (!attributes.activation.empty()) { tester = std::make_unique("NhwcFusedConv", 1, onnxruntime::kMSDomain); @@ -91,7 +85,9 @@ void TestConvFp16Op(const ConvOpAndTestAttributes& attributes, if (!weight_is_initializer || attributes.auto_pad == "SAME_UPPER" || attributes.auto_pad == "SAME_LOWER") { excluded_providers.insert(kQnnExecutionProvider); } - + if 
(!attributes.activation.empty()) { + excluded_providers.insert(kCoreMLExecutionProvider); + } tester->Run(expect_result, err_str, excluded_providers); } From 43f6e19e8859be17d3dabcd8db2567d3140f0492 Mon Sep 17 00:00:00 2001 From: wejoncy Date: Wed, 25 Sep 2024 20:54:26 -0700 Subject: [PATCH 19/39] add comments to explain convfp16 test --- onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc index d5844b5f63a42..97136ead9abf9 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc @@ -28,6 +28,18 @@ struct ConvOpAndTestAttributes { vector activation_parameters = {}; }; + +/* +Please notice that, we have two predefined macros in the head of the file +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || defined(COREML_ENABLE_MLPROGRAM). +When we have these two macro defines, this UT will turn into green light and work. + +`NhwcFusedConv` in FP16 dtype is a contribe op and not well support by basic CPU ep. +Once your EP can satisfy all the conditions and capture the op, UT will crash as there +is no appropriate ep can handle this node. +As What CoreML did, if attributes has activation fused in, we should exclude CoreML ep +to let the test pass. +*/ void TestConvFp16Op(const ConvOpAndTestAttributes& attributes, const vector>& inputs, const vector>& input_shapes, From 7ef5e1edf7a54bab202e732daaa4bd1eecf91b01 Mon Sep 17 00:00:00 2001 From: wejoncy Date: Wed, 25 Sep 2024 21:32:31 -0700 Subject: [PATCH 20/39] format --- onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc index 97136ead9abf9..b0981ff6ea238 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc @@ -28,7 +28,6 @@ struct ConvOpAndTestAttributes { vector activation_parameters = {}; }; - /* Please notice that, we have two predefined macros in the head of the file #if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || defined(COREML_ENABLE_MLPROGRAM). 
From 97e87ada9c2f23fdc8c5c807b21acdcdf25085c7 Mon Sep 17 00:00:00 2001 From: wejoncy Date: Wed, 25 Sep 2024 22:43:07 -0700 Subject: [PATCH 21/39] add curly brace for code block --- .../cpu/math/element_wise_ops_test.cc | 82 ++++++++++--------- 1 file changed, 44 insertions(+), 38 deletions(-) diff --git a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc index 83c057f204871..ffb7d92a794d4 100644 --- a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc +++ b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc @@ -30,37 +30,40 @@ void TestBinaryFloat16(const char* op_name, const std::vector& out_dim, const std::initializer_list& out_values, bool enable_bf16 = true) { - std::vector> execution_providers; + { + std::vector> execution_providers; #ifdef COREML_ENABLE_MLPROGRAM - execution_providers.push_back(DefaultCoreMLExecutionProvider(true)); + execution_providers.push_back(DefaultCoreMLExecutionProvider(true)); #elif USE_CUDA - execution_providers.push_back(DefaultCudaExecutionProvider()); + execution_providers.push_back(DefaultCudaExecutionProvider()); #elif USE_ROCM - execution_providers.push_back(DefaultRocmExecutionProvider()); + execution_providers.push_back(DefaultRocmExecutionProvider()); #endif - if (execution_providers.size() > 0) { - OpTester tester(op_name, 14); - tester.AddInput("A", lhs_dim, MakeMLFloat16(lhs_values)); - tester.AddInput("B", rhs_dim, MakeMLFloat16(rhs_values)); - tester.AddOutput("C", out_dim, MakeMLFloat16(out_values)); + if (execution_providers.size() > 0) { + OpTester tester(op_name, 14); + tester.AddInput("A", lhs_dim, MakeMLFloat16(lhs_values)); + tester.AddInput("B", rhs_dim, MakeMLFloat16(rhs_values)); + tester.AddOutput("C", out_dim, MakeMLFloat16(out_values)); - tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } } - execution_providers.clear(); - + { + std::vector> execution_providers; #ifdef USE_CUDA - execution_providers.push_back(DefaultCudaExecutionProvider()); + execution_providers.push_back(DefaultCudaExecutionProvider()); #elif USE_ROCM - execution_providers.push_back(DefaultRocmExecutionProvider()); + execution_providers.push_back(DefaultRocmExecutionProvider()); #endif - if (enable_bf16 && execution_providers.size() > 0) { - OpTester tester(op_name, 14); - tester.AddInput("A", lhs_dim, MakeBFloat16(lhs_values)); - tester.AddInput("B", rhs_dim, MakeBFloat16(rhs_values)); - tester.AddOutput("C", out_dim, MakeBFloat16(out_values)); + if (enable_bf16 && execution_providers.size() > 0) { + OpTester tester(op_name, 14); + tester.AddInput("A", lhs_dim, MakeBFloat16(lhs_values)); + tester.AddInput("B", rhs_dim, MakeBFloat16(rhs_values)); + tester.AddOutput("C", out_dim, MakeBFloat16(out_values)); - tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } } } @@ -71,36 +74,39 @@ void TestUnaryFloat16(const char* op_name, const std::initializer_list& out_values, int opset = 14, bool run_bf16 = true) { - std::vector> execution_providers; + { + std::vector> execution_providers; #ifdef COREML_ENABLE_MLPROGRAM - execution_providers.push_back(DefaultCoreMLExecutionProvider(true)); + execution_providers.push_back(DefaultCoreMLExecutionProvider(true)); #elif USE_CUDA - 
execution_providers.push_back(DefaultCudaExecutionProvider()); + execution_providers.push_back(DefaultCudaExecutionProvider()); #elif USE_ROCM - execution_providers.push_back(DefaultRocmExecutionProvider()); + execution_providers.push_back(DefaultRocmExecutionProvider()); #endif - if (execution_providers.size() > 0) { - OpTester tester(op_name, opset); - tester.AddInput("A", lhs_dim, MakeMLFloat16(lhs_values)); - tester.AddOutput("C", out_dim, MakeMLFloat16(out_values)); + if (execution_providers.size() > 0) { + OpTester tester(op_name, opset); + tester.AddInput("A", lhs_dim, MakeMLFloat16(lhs_values)); + tester.AddOutput("C", out_dim, MakeMLFloat16(out_values)); - tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } } - execution_providers.clear(); - + { + std::vector> execution_providers; #ifdef USE_CUDA - execution_providers.push_back(DefaultCudaExecutionProvider()); + execution_providers.push_back(DefaultCudaExecutionProvider()); #elif USE_ROCM - execution_providers.push_back(DefaultRocmExecutionProvider()); + execution_providers.push_back(DefaultRocmExecutionProvider()); #endif - if (run_bf16 && execution_providers.size() > 0) { - OpTester tester(op_name, opset); - tester.AddInput("A", lhs_dim, MakeBFloat16(lhs_values)); - tester.AddOutput("C", out_dim, MakeBFloat16(out_values)); + if (run_bf16 && execution_providers.size() > 0) { + OpTester tester(op_name, opset); + tester.AddInput("A", lhs_dim, MakeBFloat16(lhs_values)); + tester.AddOutput("C", out_dim, MakeBFloat16(out_values)); - tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } } } From 48c98ec89ec7ba015e3a619f60e3438cce43e070 Mon Sep 17 00:00:00 2001 From: wejoncy Date: Wed, 25 Sep 2024 23:58:36 -0700 Subject: [PATCH 22/39] add conv fp16 with intilizer=true --- onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc index b0981ff6ea238..cbdd92ab3096e 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc @@ -96,7 +96,7 @@ void TestConvFp16Op(const ConvOpAndTestAttributes& attributes, if (!weight_is_initializer || attributes.auto_pad == "SAME_UPPER" || attributes.auto_pad == "SAME_LOWER") { excluded_providers.insert(kQnnExecutionProvider); } - if (!attributes.activation.empty()) { + if (!weight_is_initializer || !attributes.activation.empty()) { excluded_providers.insert(kCoreMLExecutionProvider); } tester->Run(expect_result, err_str, excluded_providers); @@ -1160,6 +1160,7 @@ TEST(ConvFp16Test, Pointwise_Relu) { MLFloat16(17.5f), MLFloat16(9.5f)}; TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); + TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, true); } TEST(ConvFp16Test, Conv2D_HardSigmoid) { @@ -1189,6 +1190,7 @@ TEST(ConvFp16Test, Conv2D_HardSigmoid) { MLFloat16(1.0f), MLFloat16(0.0f), MLFloat16(1.0f), MLFloat16(0.0f)}; TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); + TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, true); } TEST(ConvFp16Test, Conv2D_Bias_Z_Relu) { @@ -1218,6 +1220,7 @@ TEST(ConvFp16Test, 
Conv2D_Bias_Z_Relu) { vector Z_shape = {1, 2, 2, 2}; auto expected_vals = {MLFloat16(12.0f), MLFloat16(11.0f), MLFloat16(17.0f), MLFloat16(15.0f), MLFloat16(25.0f), MLFloat16(23.0f), MLFloat16(29.0f), MLFloat16(28.0f)}; TestConvFp16Op(attrs, {X, W, B, Z}, {X_shape, W_shape, B_shape, Z_shape}, expected_vals, Y_shape); + TestConvFp16Op(attrs, {X, W, B, Z}, {X_shape, W_shape, B_shape, Z_shape}, expected_vals, Y_shape, true); } #endif // CONTRIB_OPS From c9e75c944407d9c12e9116c3808ff19d52cd320b Mon Sep 17 00:00:00 2001 From: wejoncy Date: Thu, 26 Sep 2024 01:50:17 -0700 Subject: [PATCH 23/39] qnn convfp16 --- onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc index cbdd92ab3096e..15103ecddafa9 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc @@ -3,7 +3,7 @@ #include "core/mlas/inc/mlas.h" -#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || defined(COREML_ENABLE_MLPROGRAM) +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || defined(COREML_ENABLE_MLPROGRAM) || defined(USE_QNN) #include "gtest/gtest.h" #include "test/providers/provider_test_utils.h" From a8e54856fa36a8d1fa954d3a1f53e3409f0b6a7b Mon Sep 17 00:00:00 2001 From: wejoncy Date: Thu, 26 Sep 2024 07:12:45 -0700 Subject: [PATCH 24/39] fix qnn --- onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc index 15103ecddafa9..285f9ad05fef5 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc @@ -29,12 +29,12 @@ struct ConvOpAndTestAttributes { }; /* -Please notice that, we have two predefined macros in the head of the file -#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || defined(COREML_ENABLE_MLPROGRAM). +Please notice that we have predefined macros at the head of this file: +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || defined(COREML_ENABLE_MLPROGRAM) || defined(USE_QNN). When we have these two macro defines, this UT will turn into green light and work. `NhwcFusedConv` in FP16 dtype is a contribe op and not well support by basic CPU ep. -Once your EP can satisfy all the conditions and capture the op, UT will crash as there +Once your EP can't satisfy all the conditions and capture the op, the UT will crash as there is no appropriate ep can handle this node. As What CoreML did, if attributes has activation fused in, we should exclude CoreML ep to let the test pass.
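To make the rule in the comment above concrete, here is a minimal sketch, not part of the patch, of how the excluded-provider set for an fp16 Conv test could be built. It only restates the conditions that TestConvFp16Op applies in the hunk that follows; ConvOpAndTestAttributes, OpTester and the k*ExecutionProvider constants are the ones already used in this test file, while the helper name Fp16ConvExcludedProviders is illustrative.

// Sketch only: restates the exclusion conditions described above; not part of the patch.
// Assumes the includes and types already present in conv_fp16_test.cc.
static std::unordered_set<std::string> Fp16ConvExcludedProviders(const ConvOpAndTestAttributes& attributes,
                                                                 bool weight_is_initializer) {
  std::unordered_set<std::string> excluded;
  // TensorRT is disabled because weight as input is not supported.
  excluded.insert(kTensorrtExecutionProvider);
  // QNN needs a constant weight, no SAME_UPPER/SAME_LOWER auto_pad and no fused activation.
  if (!weight_is_initializer || attributes.auto_pad == "SAME_UPPER" ||
      attributes.auto_pad == "SAME_LOWER" || !attributes.activation.empty()) {
    excluded.insert(kQnnExecutionProvider);
  }
  // CoreML is skipped when the weight is dynamic or an activation is fused into the node.
  if (!weight_is_initializer || !attributes.activation.empty()) {
    excluded.insert(kCoreMLExecutionProvider);
  }
  return excluded;
}

A caller would pass the returned set as the excluded-providers argument of OpTester::Run, which is exactly what TestConvFp16Op does with its own excluded_providers set in the hunk below.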
@@ -93,7 +93,9 @@ void TestConvFp16Op(const ConvOpAndTestAttributes& attributes, // Disable TensorRT because weight as input is not supported excluded_providers.insert(kTensorrtExecutionProvider); // QNN have issue with dynamic weight, auto pad with SAME_UPPER, SAME_LOWER - if (!weight_is_initializer || attributes.auto_pad == "SAME_UPPER" || attributes.auto_pad == "SAME_LOWER") { + if (!weight_is_initializer || attributes.auto_pad == "SAME_UPPER" || + attributes.auto_pad == "SAME_LOWER" || + !attributes.activation.empty()) { excluded_providers.insert(kQnnExecutionProvider); } if (!weight_is_initializer || !attributes.activation.empty()) { From 749940d1badd8348dbb960e75b4fff15290c302d Mon Sep 17 00:00:00 2001 From: wejoncy Date: Thu, 26 Sep 2024 23:39:06 -0700 Subject: [PATCH 25/39] add UT for the other ops --- .../builders/impl/activation_op_builder.cc | 8 +- .../coreml/builders/impl/clip_op_builder.cc | 24 +- .../builders/impl/gridsample_op_builder.cc | 9 +- .../cpu/activation/activation_op_test.cc | 6 +- .../cpu/activation/activation_op_test.h | 2 - .../test/providers/cpu/math/clip_test.cc | 32 +- .../providers/cpu/tensor/grid_sample_test.cc | 679 +++++++++--------- .../cpu/tensor/grid_sample_test_gen.py | 60 +- .../providers/cpu/tensor/resize_op_test.cc | 27 +- 9 files changed, 453 insertions(+), 394 deletions(-) diff --git a/onnxruntime/core/providers/coreml/builders/impl/activation_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/activation_op_builder.cc index c8670cd546253..5389eb5ab7e95 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/activation_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/activation_op_builder.cc @@ -104,7 +104,13 @@ Status ActivationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, if (add_alpha) { NodeAttrHelper helper(node); const auto alpha = helper.Get("alpha", 0.01f); - AddOperationInput(*op, "alpha", model_builder.AddScalarConstant(op->type(), "alpha", alpha)); + + auto input_dtype = node.InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + AddOperationInput(*op, "alpha", model_builder.AddScalarConstant(op->type(), "alpha", alpha)); + } else { + AddOperationInput(*op, "alpha", model_builder.AddScalarConstant(op->type(), "alpha", MLFloat16(alpha))); + } } AddOperationOutput(*op, *node.OutputDefs()[0]); diff --git a/onnxruntime/core/providers/coreml/builders/impl/clip_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/clip_op_builder.cc index 41f4041ef1181..53a0bd405e8fc 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/clip_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/clip_op_builder.cc @@ -60,6 +60,8 @@ Status ClipOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const auto& output_name = output.Name(); float min, max; ORT_RETURN_IF_NOT(GetClipMinMax(model_builder.GetGraphViewer(), node, min, max, logger), "GetClipMinMax failed"); + // we already checked it and dtype must be existed. + auto input_dtype = node.InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); bool has_min = min != std::numeric_limits::lowest(); bool has_max = max != std::numeric_limits::max(); @@ -94,19 +96,31 @@ Status ClipOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, // if min and max were attributes we need to add initializers. otherwise we use the existing inputs const bool min_max_attribs = node.SinceVersion() < 11; - std::string_view min_name = min_max_attribs ? 
model_builder.AddScalarConstant(clip_op.type(), "min", min) - : node.InputDefs()[1]->Name(); + std::string_view min_name; + if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + min_name = min_max_attribs ? model_builder.AddScalarConstant(clip_op.type(), "min", min) + : node.InputDefs()[1]->Name(); + } else { + min_name = min_max_attribs ? model_builder.AddScalarConstant(clip_op.type(), "min", MLFloat16(min)) + : node.InputDefs()[1]->Name(); + } AddOperationInput(clip_op, "alpha", min_name); if (has_max) { - std::string_view max_name = min_max_attribs ? model_builder.AddScalarConstant(clip_op.type(), "max", max) - : node.InputDefs()[2]->Name(); + std::string_view max_name; + if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + max_name = min_max_attribs ? model_builder.AddScalarConstant(clip_op.type(), "max", max) + : node.InputDefs()[2]->Name(); + } else { + max_name = min_max_attribs ? model_builder.AddScalarConstant(clip_op.type(), "max", MLFloat16(max)) + : node.InputDefs()[2]->Name(); + } AddOperationInput(clip_op, "beta", max_name); } } } AddOperationOutput(*op, output); model_builder.AddOperation(std::move(op)); } else diff --git a/onnxruntime/core/providers/coreml/builders/impl/gridsample_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/gridsample_op_builder.cc index 9caec290ea5a2..6dcf14c16f111 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/gridsample_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/gridsample_op_builder.cc @@ -49,6 +49,9 @@ Status GridSampleOpBuilder::AddToModelBuilderImpl([[maybe_unused]] ModelBuilder& const auto input_defs = node.InputDefs(); const auto output_defs = node.OutputDefs(); + // we already checked that the input dtype exists. 
+ auto input_dtype = input_defs[0]->TypeAsProto()->tensor_type().elem_type(); + NodeAttrHelper helper(node); std::string mode{GetMode(helper)}; // need a std::string for use in AddScalarConstant std::string padding_mode = helper.Get("padding_mode", "zeros"); @@ -65,7 +68,11 @@ Status GridSampleOpBuilder::AddToModelBuilderImpl([[maybe_unused]] ModelBuilder& AddOperationInput(*op, "coordinates", input_defs[1]->Name()); AddOperationInput(*op, "sampling_mode", model_builder.AddScalarConstant(op->type(), "sampling_mode", mode)); AddOperationInput(*op, "padding_mode", model_builder.AddScalarConstant(op->type(), "padding_mode", padding_mode)); - AddOperationInput(*op, "padding_value", model_builder.AddScalarConstant(op->type(), "padding_value", 0.0f)); + if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + AddOperationInput(*op, "padding_value", model_builder.AddScalarConstant(op->type(), "padding_value", 0.0f)); + } else { + AddOperationInput(*op, "padding_value", model_builder.AddScalarConstant(op->type(), "padding_value", MLFloat16(0.0f))); + } AddOperationInput(*op, "coordinates_mode", model_builder.AddScalarConstant(op->type(), "coordinates_mode", coordinates_mode)); AddOperationInput(*op, "align_corners", model_builder.AddScalarConstant(op->type(), "align_corners", align_corners)); diff --git a/onnxruntime/test/providers/cpu/activation/activation_op_test.cc b/onnxruntime/test/providers/cpu/activation/activation_op_test.cc index d2e883331acd4..395f17ad59a9d 100644 --- a/onnxruntime/test/providers/cpu/activation/activation_op_test.cc +++ b/onnxruntime/test/providers/cpu/activation/activation_op_test.cc @@ -125,7 +125,7 @@ TEST_F(ActivationOpTest, Relu) { {}, {}, /*is_tensorrt_supported=*/false, /*opset_version= */ 14); -#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || defined(USE_QNN) || defined(COREML_ENABLE_MLPROGRAM) TestActivationOp( "Relu", input_values_fp16, @@ -139,7 +139,7 @@ TEST_F(ActivationOpTest, Relu) { #endif // MLAS_F16VEC_INTRINSICS_SUPPORTED } -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(COREML_ENABLE_MLPROGRAM) TEST_F(ActivationOpTest, Sigmoid_fp16) { #ifdef USE_CUDA int min_cuda_architecture = 530; @@ -413,7 +413,7 @@ TEST_F(ActivationOpTest, LeakyRelu) { {{"alpha", alpha}}, {}); } -#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || defined(COREML_ENABLE_MLPROGRAM) TEST_F(ActivationOpTest, LeakyRelu_fp16) { OpTester test("LeakyRelu", 11); float alpha = 0.01f; // oneDNN set alpha equal to 0.01 diff --git a/onnxruntime/test/providers/cpu/activation/activation_op_test.h b/onnxruntime/test/providers/cpu/activation/activation_op_test.h index 409409f56c51c..8ca0f6d845a09 100644 --- a/onnxruntime/test/providers/cpu/activation/activation_op_test.h +++ b/onnxruntime/test/providers/cpu/activation/activation_op_test.h @@ -90,7 +90,6 @@ class ActivationOpTest : public ::testing::Test { DBL_MAX, -DBL_MAX, std::numeric_limits::infinity()}}; // max, -max, inf std::vector> input_values_int8{{-1, -5, 0, 1, 5, 100, -100, // normal input values for activation std::numeric_limits::min(), std::numeric_limits::max()}}; // min, max -#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED std::vector> input_values_fp16{{MLFloat16(-1.0f), MLFloat16(-5.f), MLFloat16(), @@ -100,7 +99,6 @@ class ActivationOpTest : public ::testing::Test { MLFloat16(-100.f), MLFloat16(65504.f), MLFloat16(-65504.f)}}; -#endif // MLAS_F16VEC_INTRINSICS_SUPPORTED void SetUp() override { float 
low = -1.0f, high = 1.0f; diff --git a/onnxruntime/test/providers/cpu/math/clip_test.cc b/onnxruntime/test/providers/cpu/math/clip_test.cc index 9948a6cc8a681..382a5143b8969 100644 --- a/onnxruntime/test/providers/cpu/math/clip_test.cc +++ b/onnxruntime/test/providers/cpu/math/clip_test.cc @@ -120,21 +120,25 @@ TEST(MathOpTest, Clip_Default_uint64) { } TEST(MathOpTest, Clip_MLFloat16) { - OpTester test("Clip", 12); - - std::vector dims{3, 3}; - test.AddInput("X", dims, - {MLFloat16(-1.0f), MLFloat16(-2.0f), MLFloat16(-3.0f), - MLFloat16(-4.0f), MLFloat16(0.0f), MLFloat16(2.0f), - MLFloat16(4.0f), MLFloat16(6.0f), MLFloat16(8.0f)}); - test.AddInput("min", {}, {MLFloat16(0.0f)}); - test.AddInput("max", {}, {MLFloat16(6.0f)}); - test.AddOutput("Y", dims, - {MLFloat16(0.0f), MLFloat16(0.0f), MLFloat16(0.0f), - MLFloat16(0.0f), MLFloat16(0.0f), MLFloat16(2.0f), - MLFloat16(4.0f), MLFloat16(6.0f), MLFloat16(6.0f)}); + auto run_test = [](bool min_max_are_initializer) { + OpTester test("Clip", 12); - test.Run(); + std::vector dims{3, 3}; + test.AddInput("X", dims, + {MLFloat16(-1.0f), MLFloat16(-2.0f), MLFloat16(-3.0f), + MLFloat16(-4.0f), MLFloat16(0.0f), MLFloat16(2.0f), + MLFloat16(4.0f), MLFloat16(6.0f), MLFloat16(8.0f)}); + test.AddInput("min", {}, {MLFloat16(0.0f)}, min_max_are_initializer); + test.AddInput("max", {}, {MLFloat16(6.0f)}, min_max_are_initializer); + test.AddOutput("Y", dims, + {MLFloat16(0.0f), MLFloat16(0.0f), MLFloat16(0.0f), + MLFloat16(0.0f), MLFloat16(0.0f), MLFloat16(2.0f), + MLFloat16(4.0f), MLFloat16(6.0f), MLFloat16(6.0f)}); + + test.Run(); + }; + run_test(true); // coreml requires constant max/min + run_test(false); } TEST(MathOpTest, Clip_int32) { diff --git a/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc b/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc index 540dc6dee68fb..05cfb5c13d689 100644 --- a/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc @@ -40,963 +40,970 @@ void RunTests(T& test, std::vector>&& execut // DO NOT edit following tests. 
They are generated by: // onnxruntime/test/providers/cpu/tensor/grid_sample_test_gen.py -TEST(GridSampleTest, test_grid_sample_16_4D_nearest_zeros_align_corners) { +template +class GridSampleTest : public ::testing::Test { +}; + +using GridSampleTestTypes = ::testing::Types; +TYPED_TEST_SUITE(GridSampleTest, GridSampleTestTypes); + +TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_nearest_zeros_align_corners) { OpTester test("GridSample", 16); std::string mode = "nearest"; std::string padding_mode = "zeros"; int64_t align_corners = 1; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{-1.125840f, -1.152360f, -0.250579f, -0.433879f, 0.848710f, 0.692009f, -0.316013f, -2.115219f, 0.468096f, -0.157712f, 1.443660f, 0.266049f, 0.166455f, 0.874382f, -0.143474f, -0.111609f, 0.931827f, 1.259009f, 2.004981f, 0.053737f, 0.618057f, -0.412802f, -0.841065f, -2.316042f}; + std::initializer_list X_data{TypeParam(-1.125840f), TypeParam(-1.152360f), TypeParam(-0.250579f), TypeParam(-0.433879f), TypeParam(0.848710f), TypeParam(0.692009f), TypeParam(-0.316013f), TypeParam(-2.115219f), TypeParam(0.468096f), TypeParam(-0.157712f), TypeParam(1.443660f), TypeParam(0.266049f), TypeParam(0.166455f), TypeParam(0.874382f), TypeParam(-0.143474f), TypeParam(-0.111609f), TypeParam(0.931827f), TypeParam(1.259009f), TypeParam(2.004981f), TypeParam(0.053737f), TypeParam(0.618057f), TypeParam(-0.412802f), TypeParam(-0.841065f), TypeParam(-2.316042f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{0.063110f, -0.615220f, 0.203022f, -1.120434f, -0.867079f, -0.618636f, 0.757125f, 0.703586f, -0.532194f, -0.043299f, 0.767473f, 1.192960f, 0.476259f, 0.162111f, 0.804584f, -0.706563f, 0.223613f, -0.930367f, -0.831703f, -0.619900f, 0.542968f, 0.482592f, -0.710823f, 0.362529f}; + std::initializer_list Grid_data{TypeParam(0.063110f), TypeParam(-0.615220f), TypeParam(0.203022f), TypeParam(-1.120434f), TypeParam(-0.867079f), TypeParam(-0.618636f), TypeParam(0.757125f), TypeParam(0.703586f), TypeParam(-0.532194f), TypeParam(-0.043299f), TypeParam(0.767473f), TypeParam(1.192960f), TypeParam(0.476259f), TypeParam(0.162111f), TypeParam(0.804584f), TypeParam(-0.706563f), TypeParam(0.223613f), TypeParam(-0.930367f), TypeParam(-0.831703f), TypeParam(-0.619900f), TypeParam(0.542968f), TypeParam(0.482592f), TypeParam(-0.710823f), TypeParam(0.362529f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{-1.152360f, -1.152360f, -1.125840f, 0.692009f, -0.250579f, 0.692009f, -2.115219f, -2.115219f, -0.316013f, 0.266049f, 0.468096f, 0.266049f, -0.111609f, 0.874382f, 0.874382f, 0.166455f, -0.111609f, -0.143474f, -0.412802f, 0.053737f, 0.053737f, 2.004981f, -0.412802f, 0.618057f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-1.152360f), TypeParam(-1.152360f), TypeParam(-1.125840f), TypeParam(0.692009f), TypeParam(-0.250579f), TypeParam(0.692009f), TypeParam(-2.115219f), TypeParam(-2.115219f), TypeParam(-0.316013f), TypeParam(0.266049f), TypeParam(0.468096f), TypeParam(0.266049f), TypeParam(-0.111609f), TypeParam(0.874382f), TypeParam(0.874382f), TypeParam(0.166455f), TypeParam(-0.111609f), TypeParam(-0.143474f), TypeParam(-0.412802f), TypeParam(0.053737f), TypeParam(0.053737f), TypeParam(2.004981f), TypeParam(-0.412802f), TypeParam(0.618057f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); 
test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(16)); } -TEST(GridSampleTest, test_grid_sample_16_4D_nearest_zeros_no_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_nearest_zeros_no_align_corners) { OpTester test("GridSample", 16); std::string mode = "nearest"; std::string padding_mode = "zeros"; int64_t align_corners = 0; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{-0.569248f, 0.919971f, 1.110816f, 1.289874f, -1.478174f, 2.567233f, -0.473120f, 0.335551f, -0.003304f, -0.534441f, 1.168688f, 0.394503f, 1.941462f, 0.791498f, -0.020252f, -0.437170f, -1.535287f, -0.412679f, 0.966303f, 1.624783f, -0.365619f, -1.302440f, 0.099403f, 0.441822f}; + std::initializer_list X_data{TypeParam(-0.569248f), TypeParam(0.919971f), TypeParam(1.110816f), TypeParam(1.289874f), TypeParam(-1.478174f), TypeParam(2.567233f), TypeParam(-0.473120f), TypeParam(0.335551f), TypeParam(-0.003304f), TypeParam(-0.534441f), TypeParam(1.168688f), TypeParam(0.394503f), TypeParam(1.941462f), TypeParam(0.791498f), TypeParam(-0.020252f), TypeParam(-0.437170f), TypeParam(-1.535287f), TypeParam(-0.412679f), TypeParam(0.966303f), TypeParam(1.624783f), TypeParam(-0.365619f), TypeParam(-1.302440f), TypeParam(0.099403f), TypeParam(0.441822f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{-1.143118f, -0.021569f, -0.903671f, -0.925628f, -0.066120f, 0.180174f, -0.491436f, 0.712053f, -0.730247f, 1.088844f, 0.822360f, -1.011940f, -0.298661f, 0.054147f, 0.175081f, 0.284609f, 0.470914f, 0.071880f, -0.585515f, 0.567827f, -1.151099f, -0.711248f, -0.300396f, -0.584536f}; + std::initializer_list Grid_data{TypeParam(-1.143118f), TypeParam(-0.021569f), TypeParam(-0.903671f), TypeParam(-0.925628f), TypeParam(-0.066120f), TypeParam(0.180174f), TypeParam(-0.491436f), TypeParam(0.712053f), TypeParam(-0.730247f), TypeParam(1.088844f), TypeParam(0.822360f), TypeParam(-1.011940f), TypeParam(-0.298661f), TypeParam(0.054147f), TypeParam(0.175081f), TypeParam(0.284609f), TypeParam(0.470914f), TypeParam(0.071880f), TypeParam(-0.585515f), TypeParam(0.567827f), TypeParam(-1.151099f), TypeParam(-0.711248f), TypeParam(-0.300396f), TypeParam(-0.584536f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{0.000000f, -0.569248f, 1.110816f, -1.478174f, 0.000000f, 0.000000f, 0.000000f, -0.473120f, -0.003304f, 1.168688f, 0.000000f, 0.000000f, -0.020252f, -0.437170f, -0.437170f, -1.535287f, 0.000000f, 1.941462f, -0.365619f, -1.302440f, -1.302440f, 0.099403f, 0.000000f, 0.966303f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(0.000000f), TypeParam(-0.569248f), TypeParam(1.110816f), TypeParam(-1.478174f), TypeParam(0.000000f), TypeParam(0.000000f), TypeParam(0.000000f), TypeParam(-0.473120f), TypeParam(-0.003304f), TypeParam(1.168688f), TypeParam(0.000000f), TypeParam(0.000000f), TypeParam(-0.020252f), TypeParam(-0.437170f), TypeParam(-0.437170f), TypeParam(-1.535287f), TypeParam(0.000000f), TypeParam(1.941462f), TypeParam(-0.365619f), TypeParam(-1.302440f), TypeParam(-1.302440f), TypeParam(0.099403f), TypeParam(0.000000f), TypeParam(0.966303f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); 
test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(16)); } -TEST(GridSampleTest, test_grid_sample_16_4D_nearest_border_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_nearest_border_align_corners) { OpTester test("GridSample", 16); std::string mode = "nearest"; std::string padding_mode = "border"; int64_t align_corners = 1; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{-0.883376f, -0.418913f, -0.804826f, 0.565610f, 0.610365f, 0.466884f, 1.950657f, -1.063099f, -0.829367f, -1.407257f, 1.626847f, 0.172273f, -1.611502f, -0.479448f, -0.143351f, -0.317295f, 0.573655f, 0.997931f, 0.543609f, 0.078804f, 0.862860f, -0.019490f, 0.991047f, -0.777735f}; + std::initializer_list X_data{TypeParam(-0.883376f), TypeParam(-0.418913f), TypeParam(-0.804826f), TypeParam(0.565610f), TypeParam(0.610365f), TypeParam(0.466884f), TypeParam(1.950657f), TypeParam(-1.063099f), TypeParam(-0.829367f), TypeParam(-1.407257f), TypeParam(1.626847f), TypeParam(0.172273f), TypeParam(-1.611502f), TypeParam(-0.479448f), TypeParam(-0.143351f), TypeParam(-0.317295f), TypeParam(0.573655f), TypeParam(0.997931f), TypeParam(0.543609f), TypeParam(0.078804f), TypeParam(0.862860f), TypeParam(-0.019490f), TypeParam(0.991047f), TypeParam(-0.777735f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{-1.080070f, -0.080985f, 1.055303f, -0.489470f, 1.083604f, 0.434584f, -1.082953f, 0.759237f, -0.138473f, -0.535688f, 0.959584f, -0.969714f, 0.128766f, -0.251242f, 0.856935f, 0.334973f, 0.576606f, 0.423791f, -0.288570f, -0.252367f, -0.988898f, 0.650213f, 0.952774f, 0.821070f}; + std::initializer_list Grid_data{TypeParam(-1.080070f), TypeParam(-0.080985f), TypeParam(1.055303f), TypeParam(-0.489470f), TypeParam(1.083604f), TypeParam(0.434584f), TypeParam(-1.082953f), TypeParam(0.759237f), TypeParam(-0.138473f), TypeParam(-0.535688f), TypeParam(0.959584f), TypeParam(-0.969714f), TypeParam(0.128766f), TypeParam(-0.251242f), TypeParam(0.856935f), TypeParam(0.334973f), TypeParam(0.576606f), TypeParam(0.423791f), TypeParam(-0.288570f), TypeParam(-0.252367f), TypeParam(-0.988898f), TypeParam(0.650213f), TypeParam(0.952774f), TypeParam(0.821070f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{-0.804826f, 0.565610f, 0.565610f, 0.610365f, -0.883376f, -0.418913f, -0.829367f, -1.407257f, -1.407257f, 1.626847f, 1.950657f, -1.063099f, -0.317295f, -0.317295f, -0.317295f, -0.143351f, 0.573655f, 0.997931f, -0.019490f, -0.019490f, -0.019490f, 0.862860f, 0.991047f, -0.777735f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-0.804826f), TypeParam(0.565610f), TypeParam(0.565610f), TypeParam(0.610365f), TypeParam(-0.883376f), TypeParam(-0.418913f), TypeParam(-0.829367f), TypeParam(-1.407257f), TypeParam(-1.407257f), TypeParam(1.626847f), TypeParam(1.950657f), TypeParam(-1.063099f), TypeParam(-0.317295f), TypeParam(-0.317295f), TypeParam(-0.317295f), TypeParam(-0.143351f), TypeParam(0.573655f), TypeParam(0.997931f), TypeParam(-0.019490f), TypeParam(-0.019490f), TypeParam(-0.019490f), TypeParam(0.862860f), TypeParam(0.991047f), TypeParam(-0.777735f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - 
test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(16)); } -TEST(GridSampleTest, test_grid_sample_16_4D_nearest_border_no_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_nearest_border_no_align_corners) { OpTester test("GridSample", 16); std::string mode = "nearest"; std::string padding_mode = "border"; int64_t align_corners = 0; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{-0.559630f, 0.533472f, 0.406887f, 0.394587f, 0.171511f, 0.876045f, -0.287087f, 1.021640f, 0.438649f, -0.010704f, 1.338354f, -0.279405f, -0.551834f, -2.889061f, -1.509981f, 1.024115f, 0.195393f, -0.737109f, 1.700101f, 0.346216f, 0.971125f, 1.450250f, -0.051909f, -0.628431f}; + std::initializer_list X_data{TypeParam(-0.559630f), TypeParam(0.533472f), TypeParam(0.406887f), TypeParam(0.394587f), TypeParam(0.171511f), TypeParam(0.876045f), TypeParam(-0.287087f), TypeParam(1.021640f), TypeParam(0.438649f), TypeParam(-0.010704f), TypeParam(1.338354f), TypeParam(-0.279405f), TypeParam(-0.551834f), TypeParam(-2.889061f), TypeParam(-1.509981f), TypeParam(1.024115f), TypeParam(0.195393f), TypeParam(-0.737109f), TypeParam(1.700101f), TypeParam(0.346216f), TypeParam(0.971125f), TypeParam(1.450250f), TypeParam(-0.051909f), TypeParam(-0.628431f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{0.149807f, 1.074831f, 0.734055f, -0.758657f, 0.538205f, -0.848275f, -0.508590f, 0.352947f, 0.396231f, 0.900274f, -0.386299f, 0.001921f, 0.617788f, -1.160511f, 0.867577f, -0.992307f, 0.016539f, -0.204020f, -0.632008f, 0.158605f, 0.992302f, -0.350783f, -0.712433f, -0.443807f}; + std::initializer_list Grid_data{TypeParam(0.149807f), TypeParam(1.074831f), TypeParam(0.734055f), TypeParam(-0.758657f), TypeParam(0.538205f), TypeParam(-0.848275f), TypeParam(-0.508590f), TypeParam(0.352947f), TypeParam(0.396231f), TypeParam(0.900274f), TypeParam(-0.386299f), TypeParam(0.001921f), TypeParam(0.617788f), TypeParam(-1.160511f), TypeParam(0.867577f), TypeParam(-0.992307f), TypeParam(0.016539f), TypeParam(-0.204020f), TypeParam(-0.632008f), TypeParam(0.158605f), TypeParam(0.992302f), TypeParam(-0.350783f), TypeParam(-0.712433f), TypeParam(-0.443807f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{0.876045f, 0.533472f, 0.533472f, 0.171511f, 0.876045f, 0.406887f, -0.279405f, 1.021640f, 1.021640f, 1.338354f, -0.279405f, 0.438649f, -2.889061f, -2.889061f, 1.024115f, -1.509981f, -2.889061f, -0.551834f, 0.346216f, 0.346216f, 1.450250f, 0.971125f, 0.346216f, 1.700101f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(0.876045f), TypeParam(0.533472f), TypeParam(0.533472f), TypeParam(0.171511f), TypeParam(0.876045f), TypeParam(0.406887f), TypeParam(-0.279405f), TypeParam(1.021640f), TypeParam(1.021640f), TypeParam(1.338354f), TypeParam(-0.279405f), TypeParam(0.438649f), TypeParam(-2.889061f), TypeParam(-2.889061f), TypeParam(1.024115f), TypeParam(-1.509981f), TypeParam(-2.889061f), TypeParam(-0.551834f), TypeParam(0.346216f), TypeParam(0.346216f), TypeParam(1.450250f), TypeParam(0.971125f), TypeParam(0.346216f), TypeParam(1.700101f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", 
Y_shape, Y_data); RunTests(test, GetExecutionProviders(16)); } -TEST(GridSampleTest, test_grid_sample_16_4D_nearest_reflection_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_nearest_reflection_align_corners) { OpTester test("GridSample", 16); std::string mode = "nearest"; std::string padding_mode = "reflection"; int64_t align_corners = 1; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{-0.039373f, -0.801472f, -0.495544f, -0.361514f, 0.585113f, -1.156007f, -0.143365f, -0.194741f, -0.906885f, -0.591838f, 0.150785f, -1.041149f, -0.720534f, -2.214754f, -0.683730f, 0.516358f, 0.792848f, 0.083228f, 0.422800f, -1.868747f, -1.105713f, 0.143731f, 0.583597f, 1.348155f}; + std::initializer_list X_data{TypeParam(-0.039373f), TypeParam(-0.801472f), TypeParam(-0.495544f), TypeParam(-0.361514f), TypeParam(0.585113f), TypeParam(-1.156007f), TypeParam(-0.143365f), TypeParam(-0.194741f), TypeParam(-0.906885f), TypeParam(-0.591838f), TypeParam(0.150785f), TypeParam(-1.041149f), TypeParam(-0.720534f), TypeParam(-2.214754f), TypeParam(-0.683730f), TypeParam(0.516358f), TypeParam(0.792848f), TypeParam(0.083228f), TypeParam(0.422800f), TypeParam(-1.868747f), TypeParam(-1.105713f), TypeParam(0.143731f), TypeParam(0.583597f), TypeParam(1.348155f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{0.829854f, -0.893309f, 0.491599f, -0.403504f, -0.578962f, 0.215574f, -0.623348f, 0.276486f, 0.235657f, -0.890987f, 0.199798f, 0.511115f, 0.474997f, -0.151054f, -0.983745f, -0.184985f, 0.416769f, -0.437853f, 0.455497f, 0.799155f, -0.626582f, 0.011834f, 0.496199f, 0.094053f}; + std::initializer_list Grid_data{TypeParam(0.829854f), TypeParam(-0.893309f), TypeParam(0.491599f), TypeParam(-0.403504f), TypeParam(-0.578962f), TypeParam(0.215574f), TypeParam(-0.623348f), TypeParam(0.276486f), TypeParam(0.235657f), TypeParam(-0.890987f), TypeParam(0.199798f), TypeParam(0.511115f), TypeParam(0.474997f), TypeParam(-0.151054f), TypeParam(-0.983745f), TypeParam(-0.184985f), TypeParam(0.416769f), TypeParam(-0.437853f), TypeParam(0.455497f), TypeParam(0.799155f), TypeParam(-0.626582f), TypeParam(0.011834f), TypeParam(0.496199f), TypeParam(0.094053f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{-0.801472f, -0.361514f, -0.495544f, -0.495544f, -0.801472f, -1.156007f, -0.194741f, -0.591838f, -0.906885f, -0.906885f, -0.194741f, -1.041149f, 0.516358f, -0.683730f, 0.516358f, 0.083228f, -0.683730f, 0.516358f, 0.143731f, -1.105713f, 0.143731f, 1.348155f, -1.105713f, 0.143731f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-0.801472f), TypeParam(-0.361514f), TypeParam(-0.495544f), TypeParam(-0.495544f), TypeParam(-0.801472f), TypeParam(-1.156007f), TypeParam(-0.194741f), TypeParam(-0.591838f), TypeParam(-0.906885f), TypeParam(-0.906885f), TypeParam(-0.194741f), TypeParam(-1.041149f), TypeParam(0.516358f), TypeParam(-0.683730f), TypeParam(0.516358f), TypeParam(0.083228f), TypeParam(-0.683730f), TypeParam(0.516358f), TypeParam(0.143731f), TypeParam(-1.105713f), TypeParam(0.143731f), TypeParam(1.348155f), TypeParam(-1.105713f), TypeParam(0.143731f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, 
GetExecutionProviders(16)); } -TEST(GridSampleTest, test_grid_sample_16_4D_nearest_reflection_no_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_nearest_reflection_no_align_corners) { OpTester test("GridSample", 16); std::string mode = "nearest"; std::string padding_mode = "reflection"; int64_t align_corners = 0; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{-0.129230f, -0.054595f, 0.408347f, 1.126366f, 1.935057f, 1.007685f, 1.004642f, -0.433520f, -0.562711f, -0.832754f, -1.395545f, -0.399295f, -0.309940f, -0.056062f, 0.517413f, -1.596237f, 0.356960f, -2.297482f, -0.871083f, -1.674028f, 0.563055f, -1.435067f, 0.719400f, -1.370747f}; + std::initializer_list X_data{TypeParam(-0.129230f), TypeParam(-0.054595f), TypeParam(0.408347f), TypeParam(1.126366f), TypeParam(1.935057f), TypeParam(1.007685f), TypeParam(1.004642f), TypeParam(-0.433520f), TypeParam(-0.562711f), TypeParam(-0.832754f), TypeParam(-1.395545f), TypeParam(-0.399295f), TypeParam(-0.309940f), TypeParam(-0.056062f), TypeParam(0.517413f), TypeParam(-1.596237f), TypeParam(0.356960f), TypeParam(-2.297482f), TypeParam(-0.871083f), TypeParam(-1.674028f), TypeParam(0.563055f), TypeParam(-1.435067f), TypeParam(0.719400f), TypeParam(-1.370747f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{-0.811910f, -1.183845f, -0.963667f, 0.947364f, 0.649243f, 1.125859f, 0.961345f, -1.071655f, -0.818917f, -0.193899f, -0.779319f, 0.833276f, -0.907209f, -0.585482f, -1.159310f, -0.681295f, 0.986973f, 0.982512f, 0.859005f, 0.926553f, 1.067024f, -0.307276f, 0.528003f, 1.069117f}; + std::initializer_list Grid_data{TypeParam(-0.811910f), TypeParam(-1.183845f), TypeParam(-0.963667f), TypeParam(0.947364f), TypeParam(0.649243f), TypeParam(1.125859f), TypeParam(0.961345f), TypeParam(-1.071655f), TypeParam(-0.818917f), TypeParam(-0.193899f), TypeParam(-0.779319f), TypeParam(0.833276f), TypeParam(-0.907209f), TypeParam(-0.585482f), TypeParam(-1.159310f), TypeParam(-0.681295f), TypeParam(0.986973f), TypeParam(0.982512f), TypeParam(0.859005f), TypeParam(0.926553f), TypeParam(1.067024f), TypeParam(-0.307276f), TypeParam(0.528003f), TypeParam(1.069117f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{-0.129230f, 1.935057f, 1.007685f, -0.054595f, 0.408347f, 1.935057f, 1.004642f, -1.395545f, -0.399295f, -0.433520f, -0.562711f, -1.395545f, -0.309940f, -0.309940f, -2.297482f, -2.297482f, -1.596237f, -2.297482f, -0.871083f, -0.871083f, -1.370747f, -1.370747f, -1.435067f, -1.370747f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-0.129230f), TypeParam(1.935057f), TypeParam(1.007685f), TypeParam(-0.054595f), TypeParam(0.408347f), TypeParam(1.935057f), TypeParam(1.004642f), TypeParam(-1.395545f), TypeParam(-0.399295f), TypeParam(-0.433520f), TypeParam(-0.562711f), TypeParam(-1.395545f), TypeParam(-0.309940f), TypeParam(-0.309940f), TypeParam(-2.297482f), TypeParam(-2.297482f), TypeParam(-1.596237f), TypeParam(-2.297482f), TypeParam(-0.871083f), TypeParam(-0.871083f), TypeParam(-1.370747f), TypeParam(-1.370747f), TypeParam(-1.435067f), TypeParam(-1.370747f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, 
GetExecutionProviders(16)); } -TEST(GridSampleTest, test_grid_sample_16_4D_bilinear_zeros_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bilinear_zeros_align_corners) { OpTester test("GridSample", 16); std::string mode = "bilinear"; std::string padding_mode = "zeros"; int64_t align_corners = 1; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{0.294201f, 0.797322f, 1.264215f, 0.935492f, 0.545464f, -1.537389f, 0.312439f, 0.740060f, -0.575326f, -1.432532f, -0.666175f, 1.017438f, -2.241368f, 0.437349f, -0.555362f, -0.057943f, 0.658583f, 0.992938f, -0.206548f, -0.244841f, -0.380599f, 1.131112f, -0.090205f, -0.897900f}; + std::initializer_list X_data{TypeParam(0.294201f), TypeParam(0.797322f), TypeParam(1.264215f), TypeParam(0.935492f), TypeParam(0.545464f), TypeParam(-1.537389f), TypeParam(0.312439f), TypeParam(0.740060f), TypeParam(-0.575326f), TypeParam(-1.432532f), TypeParam(-0.666175f), TypeParam(1.017438f), TypeParam(-2.241368f), TypeParam(0.437349f), TypeParam(-0.555362f), TypeParam(-0.057943f), TypeParam(0.658583f), TypeParam(0.992938f), TypeParam(-0.206548f), TypeParam(-0.244841f), TypeParam(-0.380599f), TypeParam(1.131112f), TypeParam(-0.090205f), TypeParam(-0.897900f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{0.595248f, -1.096726f, -0.214731f, -0.891773f, -0.512023f, 0.432352f, -0.852156f, 0.446072f, 1.018534f, 0.078706f, -0.799785f, -0.429942f, 0.262037f, -0.914782f, 0.596172f, -1.089444f, -1.153552f, -1.165993f, -0.243436f, 0.806920f, -1.135775f, 0.997425f, -0.480027f, 0.351461f}; + std::initializer_list Grid_data{TypeParam(0.595248f), TypeParam(-1.096726f), TypeParam(-0.214731f), TypeParam(-0.891773f), TypeParam(-0.512023f), TypeParam(0.432352f), TypeParam(-0.852156f), TypeParam(0.446072f), TypeParam(1.018534f), TypeParam(0.078706f), TypeParam(-0.799785f), TypeParam(-0.429942f), TypeParam(0.262037f), TypeParam(-0.914782f), TypeParam(0.596172f), TypeParam(-1.089444f), TypeParam(-1.153552f), TypeParam(-1.165993f), TypeParam(-0.243436f), TypeParam(0.806920f), TypeParam(-1.135775f), TypeParam(0.997425f), TypeParam(-0.480027f), TypeParam(0.351461f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{0.628229f, 0.561377f, 0.688215f, 0.861459f, 0.733996f, 0.850061f, 0.590307f, 0.329661f, -0.555725f, -0.595435f, -1.228216f, -0.224152f, -0.524667f, -0.094262f, -1.725798f, 0.562584f, 0.610959f, -0.014286f, -0.162194f, -0.215901f, -0.159037f, -0.282404f, -0.084779f, -0.097448f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(0.628229f), TypeParam(0.561377f), TypeParam(0.688215f), TypeParam(0.861459f), TypeParam(0.733996f), TypeParam(0.850061f), TypeParam(0.590307f), TypeParam(0.329661f), TypeParam(-0.555725f), TypeParam(-0.595435f), TypeParam(-1.228216f), TypeParam(-0.224152f), TypeParam(-0.524667f), TypeParam(-0.094262f), TypeParam(-1.725798f), TypeParam(0.562584f), TypeParam(0.610959f), TypeParam(-0.014286f), TypeParam(-0.162194f), TypeParam(-0.215901f), TypeParam(-0.159037f), TypeParam(-0.282404f), TypeParam(-0.084779f), TypeParam(-0.097448f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(16)); } 
-TEST(GridSampleTest, test_grid_sample_16_4D_bilinear_zeros_no_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bilinear_zeros_no_align_corners) { OpTester test("GridSample", 16); std::string mode = "bilinear"; std::string padding_mode = "zeros"; int64_t align_corners = 0; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{-1.199109f, -0.025686f, 1.802375f, -1.059653f, 3.402826f, -0.568670f, -0.475489f, 1.743163f, 1.060884f, -0.015953f, 1.275653f, 0.009457f, -0.369450f, 1.218198f, 0.255044f, 0.273993f, 1.404381f, 1.082878f, 0.788966f, -0.137615f, 0.122478f, -1.076701f, -0.650897f, -1.619658f}; + std::initializer_list X_data{TypeParam(-1.199109f), TypeParam(-0.025686f), TypeParam(1.802375f), TypeParam(-1.059653f), TypeParam(3.402826f), TypeParam(-0.568670f), TypeParam(-0.475489f), TypeParam(1.743163f), TypeParam(1.060884f), TypeParam(-0.015953f), TypeParam(1.275653f), TypeParam(0.009457f), TypeParam(-0.369450f), TypeParam(1.218198f), TypeParam(0.255044f), TypeParam(0.273993f), TypeParam(1.404381f), TypeParam(1.082878f), TypeParam(0.788966f), TypeParam(-0.137615f), TypeParam(0.122478f), TypeParam(-1.076701f), TypeParam(-0.650897f), TypeParam(-1.619658f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{0.038587f, -0.371014f, -0.260918f, 0.159481f, 0.594851f, -0.840708f, 1.007133f, -0.130476f, -1.005535f, -0.649269f, 1.061781f, 1.097433f, -1.111536f, 0.846358f, 0.601391f, 0.710302f, 1.015835f, -0.646740f, 0.378931f, 0.491080f, -0.354592f, 0.401584f, -0.345256f, 0.741914f}; + std::initializer_list Grid_data{TypeParam(0.038587f), TypeParam(-0.371014f), TypeParam(-0.260918f), TypeParam(0.159481f), TypeParam(0.594851f), TypeParam(-0.840708f), TypeParam(1.007133f), TypeParam(-0.130476f), TypeParam(-1.005535f), TypeParam(-0.649269f), TypeParam(1.061781f), TypeParam(1.097433f), TypeParam(-1.111536f), TypeParam(0.846358f), TypeParam(0.601391f), TypeParam(0.710302f), TypeParam(1.015835f), TypeParam(-0.646740f), TypeParam(0.378931f), TypeParam(0.491080f), TypeParam(-0.354592f), TypeParam(0.401584f), TypeParam(-0.345256f), TypeParam(0.741914f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{-0.199899f, 1.437523f, -0.017180f, -0.422530f, -0.554188f, -0.088180f, 0.613663f, 0.843979f, 1.165913f, 0.161823f, -0.215288f, 0.001466f, 0.398506f, 0.909392f, 0.576145f, 0.897902f, 0.920312f, 1.201733f, -0.184698f, -1.360176f, -0.080218f, -1.352020f, -0.497572f, -0.710420f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-0.199899f), TypeParam(1.437523f), TypeParam(-0.017180f), TypeParam(-0.422530f), TypeParam(-0.554188f), TypeParam(-0.088180f), TypeParam(0.613663f), TypeParam(0.843979f), TypeParam(1.165913f), TypeParam(0.161823f), TypeParam(-0.215288f), TypeParam(0.001466f), TypeParam(0.398506f), TypeParam(0.909392f), TypeParam(0.576145f), TypeParam(0.897902f), TypeParam(0.920312f), TypeParam(1.201733f), TypeParam(-0.184698f), TypeParam(-1.360176f), TypeParam(-0.080218f), TypeParam(-1.352020f), TypeParam(-0.497572f), TypeParam(-0.710420f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(16)); } -TEST(GridSampleTest, 
test_grid_sample_16_4D_bilinear_border_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bilinear_border_align_corners) { OpTester test("GridSample", 16); std::string mode = "bilinear"; std::string padding_mode = "border"; int64_t align_corners = 1; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{-0.546073f, -0.630178f, -0.634650f, 0.974665f, 0.209843f, 0.029890f, 1.709235f, -0.725759f, -0.876951f, 0.522287f, 0.462005f, -1.329269f, -0.295974f, 1.371414f, 0.973846f, 0.765543f, -0.403897f, -0.326279f, 0.748218f, -0.195299f, 0.676756f, -0.080633f, 0.158123f, 0.099984f}; + std::initializer_list X_data{TypeParam(-0.546073f), TypeParam(-0.630178f), TypeParam(-0.634650f), TypeParam(0.974665f), TypeParam(0.209843f), TypeParam(0.029890f), TypeParam(1.709235f), TypeParam(-0.725759f), TypeParam(-0.876951f), TypeParam(0.522287f), TypeParam(0.462005f), TypeParam(-1.329269f), TypeParam(-0.295974f), TypeParam(1.371414f), TypeParam(0.973846f), TypeParam(0.765543f), TypeParam(-0.403897f), TypeParam(-0.326279f), TypeParam(0.748218f), TypeParam(-0.195299f), TypeParam(0.676756f), TypeParam(-0.080633f), TypeParam(0.158123f), TypeParam(0.099984f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{1.182462f, -0.759228f, 0.230068f, -0.103567f, -0.252788f, -0.268017f, 0.762529f, 0.057356f, -1.168338f, -0.708432f, -0.409080f, 0.603860f, -0.776560f, 1.131504f, -0.267275f, -0.215474f, 0.940270f, 0.603129f, 1.017745f, 0.694133f, -0.364025f, -0.796167f, -0.089284f, 0.993165f}; + std::initializer_list Grid_data{TypeParam(1.182462f), TypeParam(-0.759228f), TypeParam(0.230068f), TypeParam(-0.103567f), TypeParam(-0.252788f), TypeParam(-0.268017f), TypeParam(0.762529f), TypeParam(0.057356f), TypeParam(-1.168338f), TypeParam(-0.708432f), TypeParam(-0.409080f), TypeParam(0.603860f), TypeParam(-0.776560f), TypeParam(1.131504f), TypeParam(-0.267275f), TypeParam(-0.215474f), TypeParam(0.940270f), TypeParam(0.603129f), TypeParam(1.017745f), TypeParam(0.694133f), TypeParam(-0.364025f), TypeParam(-0.796167f), TypeParam(-0.089284f), TypeParam(0.993165f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{-0.243777f, 0.256440f, -0.179228f, 0.741578f, -0.571899f, 0.031558f, -0.425264f, 0.007242f, -0.044977f, 0.271677f, 0.955187f, -0.224230f, -0.395226f, 0.771988f, 0.108104f, 0.007673f, 0.371491f, -0.360026f, 0.151628f, 0.399982f, 0.038327f, 0.044739f, 0.445689f, 0.133017f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-0.243777f), TypeParam(0.256440f), TypeParam(-0.179228f), TypeParam(0.741578f), TypeParam(-0.571899f), TypeParam(0.031558f), TypeParam(-0.425264f), TypeParam(0.007242f), TypeParam(-0.044977f), TypeParam(0.271677f), TypeParam(0.955187f), TypeParam(-0.224230f), TypeParam(-0.395226f), TypeParam(0.771988f), TypeParam(0.108104f), TypeParam(0.007673f), TypeParam(0.371491f), TypeParam(-0.360026f), TypeParam(0.151628f), TypeParam(0.399982f), TypeParam(0.038327f), TypeParam(0.044739f), TypeParam(0.445689f), TypeParam(0.133017f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(16)); } -TEST(GridSampleTest, test_grid_sample_16_4D_bilinear_border_no_align_corners) { 
+TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bilinear_border_no_align_corners) { OpTester test("GridSample", 16); std::string mode = "bilinear"; std::string padding_mode = "border"; int64_t align_corners = 0; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{-0.873307f, 0.004261f, -1.257887f, -1.084466f, 0.752979f, 0.323648f, -0.275010f, 1.305612f, -0.009480f, -0.831312f, -0.556290f, 2.070567f, 0.710039f, -0.146461f, -0.746745f, 0.725842f, 0.403461f, 0.234374f, 0.173281f, 1.724145f, -0.408946f, 0.782749f, -1.520847f, -0.314686f}; + std::initializer_list X_data{TypeParam(-0.873307f), TypeParam(0.004261f), TypeParam(-1.257887f), TypeParam(-1.084466f), TypeParam(0.752979f), TypeParam(0.323648f), TypeParam(-0.275010f), TypeParam(1.305612f), TypeParam(-0.009480f), TypeParam(-0.831312f), TypeParam(-0.556290f), TypeParam(2.070567f), TypeParam(0.710039f), TypeParam(-0.146461f), TypeParam(-0.746745f), TypeParam(0.725842f), TypeParam(0.403461f), TypeParam(0.234374f), TypeParam(0.173281f), TypeParam(1.724145f), TypeParam(-0.408946f), TypeParam(0.782749f), TypeParam(-1.520847f), TypeParam(-0.314686f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{0.605180f, 0.169896f, 1.021029f, 0.161312f, -0.555188f, 1.135200f, 0.284017f, -1.170817f, -0.341630f, -0.817401f, 1.052104f, -0.198175f, -1.093830f, -0.075436f, 0.753615f, 0.311761f, 0.379445f, 0.111448f, 0.447382f, -0.292382f, -0.477360f, -1.121650f, -0.904004f, 0.520083f}; + std::initializer_list Grid_data{TypeParam(0.605180f), TypeParam(0.169896f), TypeParam(1.021029f), TypeParam(0.161312f), TypeParam(-0.555188f), TypeParam(1.135200f), TypeParam(0.284017f), TypeParam(-1.170817f), TypeParam(-0.341630f), TypeParam(-0.817401f), TypeParam(1.052104f), TypeParam(-0.198175f), TypeParam(-1.093830f), TypeParam(-0.075436f), TypeParam(0.753615f), TypeParam(0.311761f), TypeParam(0.379445f), TypeParam(0.111448f), TypeParam(0.447382f), TypeParam(-0.292382f), TypeParam(-0.477360f), TypeParam(-1.121650f), TypeParam(-0.904004f), TypeParam(0.520083f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{-0.725617f, -0.743749f, 0.752979f, -0.185279f, -0.734326f, -0.760828f, -0.091786f, -0.129152f, -0.556290f, 0.964224f, -0.024687f, -0.196084f, -0.581904f, 0.496011f, 0.499240f, 0.319537f, 0.690648f, 0.150559f, -0.343065f, 0.269544f, 0.455333f, 1.124628f, 0.208392f, -1.276367f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-0.725617f), TypeParam(-0.743749f), TypeParam(0.752979f), TypeParam(-0.185279f), TypeParam(-0.734326f), TypeParam(-0.760828f), TypeParam(-0.091786f), TypeParam(-0.129152f), TypeParam(-0.556290f), TypeParam(0.964224f), TypeParam(-0.024687f), TypeParam(-0.196084f), TypeParam(-0.581904f), TypeParam(0.496011f), TypeParam(0.499240f), TypeParam(0.319537f), TypeParam(0.690648f), TypeParam(0.150559f), TypeParam(-0.343065f), TypeParam(0.269544f), TypeParam(0.455333f), TypeParam(1.124628f), TypeParam(0.208392f), TypeParam(-1.276367f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(16)); } -TEST(GridSampleTest, test_grid_sample_16_4D_bilinear_reflection_align_corners) { +TYPED_TEST(GridSampleTest, 
test_grid_sample_16_4D_bilinear_reflection_align_corners) { OpTester test("GridSample", 16); std::string mode = "bilinear"; std::string padding_mode = "reflection"; int64_t align_corners = 1; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{0.540757f, -0.947807f, 0.202144f, -0.350748f, 0.545005f, 1.541211f, 0.600239f, -0.338015f, -1.080823f, -1.391537f, -0.352570f, 1.560770f, -0.822488f, -2.140920f, 0.099553f, -0.697505f, 0.665352f, -2.256198f, -1.002236f, -1.395144f, 0.415783f, 0.268104f, -0.151752f, 0.794042f}; + std::initializer_list X_data{TypeParam(0.540757f), TypeParam(-0.947807f), TypeParam(0.202144f), TypeParam(-0.350748f), TypeParam(0.545005f), TypeParam(1.541211f), TypeParam(0.600239f), TypeParam(-0.338015f), TypeParam(-1.080823f), TypeParam(-1.391537f), TypeParam(-0.352570f), TypeParam(1.560770f), TypeParam(-0.822488f), TypeParam(-2.140920f), TypeParam(0.099553f), TypeParam(-0.697505f), TypeParam(0.665352f), TypeParam(-2.256198f), TypeParam(-1.002236f), TypeParam(-1.395144f), TypeParam(0.415783f), TypeParam(0.268104f), TypeParam(-0.151752f), TypeParam(0.794042f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{1.051960f, -0.798975f, -0.129852f, -0.064453f, 0.535452f, 0.820411f, -0.190205f, -0.994177f, 0.594591f, 0.358958f, 0.482039f, -0.740241f, 0.772315f, 1.136586f, 0.104126f, -1.120858f, 0.842388f, -0.889742f, 0.275846f, 0.174381f, -0.561644f, 0.417835f, -1.073319f, 0.273311f}; + std::initializer_list Grid_data{TypeParam(1.051960f), TypeParam(-0.798975f), TypeParam(-0.129852f), TypeParam(-0.064453f), TypeParam(0.535452f), TypeParam(0.820411f), TypeParam(-0.190205f), TypeParam(-0.994177f), TypeParam(0.594591f), TypeParam(0.358958f), TypeParam(0.482039f), TypeParam(-0.740241f), TypeParam(0.772315f), TypeParam(1.136586f), TypeParam(0.104126f), TypeParam(-1.120858f), TypeParam(0.842388f), TypeParam(-0.889742f), TypeParam(0.275846f), TypeParam(0.174381f), TypeParam(-0.561644f), TypeParam(0.417835f), TypeParam(-1.073319f), TypeParam(0.273311f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{-0.793997f, -0.042818f, 1.034663f, -0.061725f, 0.327743f, -0.470152f, -0.528701f, -1.125254f, 0.678924f, 0.212033f, -0.430627f, -0.410903f, -1.743740f, -1.404122f, -1.882401f, -0.546577f, -0.033295f, 0.203686f, 0.631537f, -1.031405f, -1.182924f, 0.344248f, 0.246420f, 0.266212f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-0.793997f), TypeParam(-0.042818f), TypeParam(1.034663f), TypeParam(-0.061725f), TypeParam(0.327743f), TypeParam(-0.470152f), TypeParam(-0.528701f), TypeParam(-1.125254f), TypeParam(0.678924f), TypeParam(0.212033f), TypeParam(-0.430627f), TypeParam(-0.410903f), TypeParam(-1.743740f), TypeParam(-1.404122f), TypeParam(-1.882401f), TypeParam(-0.546577f), TypeParam(-0.033295f), TypeParam(0.203686f), TypeParam(0.631537f), TypeParam(-1.031405f), TypeParam(-1.182924f), TypeParam(0.344248f), TypeParam(0.246420f), TypeParam(0.266212f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(16)); } -TEST(GridSampleTest, test_grid_sample_16_4D_bilinear_reflection_no_align_corners) { +TYPED_TEST(GridSampleTest, 
test_grid_sample_16_4D_bilinear_reflection_no_align_corners) { OpTester test("GridSample", 16); std::string mode = "bilinear"; std::string padding_mode = "reflection"; int64_t align_corners = 0; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{0.584178f, 1.050431f, 1.285579f, -1.616520f, -0.768962f, -1.220462f, 0.573128f, 0.699197f, -1.654887f, 0.493267f, -0.615042f, 1.311865f, 0.788249f, -1.232951f, 0.454381f, -1.436621f, 0.711631f, 0.554599f, -0.807529f, 1.680131f, 0.597634f, -0.238890f, -0.345997f, 1.770104f}; + std::initializer_list X_data{TypeParam(0.584178f), TypeParam(1.050431f), TypeParam(1.285579f), TypeParam(-1.616520f), TypeParam(-0.768962f), TypeParam(-1.220462f), TypeParam(0.573128f), TypeParam(0.699197f), TypeParam(-1.654887f), TypeParam(0.493267f), TypeParam(-0.615042f), TypeParam(1.311865f), TypeParam(0.788249f), TypeParam(-1.232951f), TypeParam(0.454381f), TypeParam(-1.436621f), TypeParam(0.711631f), TypeParam(0.554599f), TypeParam(-0.807529f), TypeParam(1.680131f), TypeParam(0.597634f), TypeParam(-0.238890f), TypeParam(-0.345997f), TypeParam(1.770104f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{0.564800f, 1.031186f, 0.795913f, -0.629473f, -0.131544f, -0.377622f, -0.964948f, 0.000496f, 0.902922f, 1.011019f, 0.111961f, 0.272548f, -0.519506f, 0.905811f, -0.499330f, -0.833583f, 0.184792f, 0.719262f, -1.081910f, 1.084761f, 0.431677f, -0.840735f, -0.258489f, 1.041096f}; + std::initializer_list Grid_data{TypeParam(0.564800f), TypeParam(1.031186f), TypeParam(0.795913f), TypeParam(-0.629473f), TypeParam(-0.131544f), TypeParam(-0.377622f), TypeParam(-0.964948f), TypeParam(0.000496f), TypeParam(0.902922f), TypeParam(1.011019f), TypeParam(0.111961f), TypeParam(0.272548f), TypeParam(-0.519506f), TypeParam(0.905811f), TypeParam(-0.499330f), TypeParam(-0.833583f), TypeParam(0.184792f), TypeParam(0.719262f), TypeParam(-1.081910f), TypeParam(1.084761f), TypeParam(0.431677f), TypeParam(-0.840735f), TypeParam(-0.258489f), TypeParam(1.041096f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{-1.220462f, 0.901641f, 0.521980f, 1.284051f, -1.220462f, -0.717235f, 1.311865f, 0.687708f, -0.023386f, -1.654114f, 1.311865f, 0.029458f, 0.711631f, 0.786895f, 0.604097f, 0.711631f, -1.094857f, 0.673706f, -0.345997f, -0.805863f, 1.103092f, -0.345997f, 1.510167f, 0.165064f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-1.220462f), TypeParam(0.901641f), TypeParam(0.521980f), TypeParam(1.284051f), TypeParam(-1.220462f), TypeParam(-0.717235f), TypeParam(1.311865f), TypeParam(0.687708f), TypeParam(-0.023386f), TypeParam(-1.654114f), TypeParam(1.311865f), TypeParam(0.029458f), TypeParam(0.711631f), TypeParam(0.786895f), TypeParam(0.604097f), TypeParam(0.711631f), TypeParam(-1.094857f), TypeParam(0.673706f), TypeParam(-0.345997f), TypeParam(-0.805863f), TypeParam(1.103092f), TypeParam(-0.345997f), TypeParam(1.510167f), TypeParam(0.165064f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(16)); } -TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_zeros_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_zeros_align_corners) { 
OpTester test("GridSample", 16); std::string mode = "bicubic"; std::string padding_mode = "zeros"; int64_t align_corners = 1; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{0.497417f, 0.268522f, 1.476879f, 0.354795f, 1.624709f, 0.593423f, -1.725412f, -0.622016f, -0.466707f, -0.319962f, 0.701868f, 0.494252f, -0.630165f, 0.548236f, 1.042740f, 0.253800f, -2.667303f, 1.379165f, -0.519418f, 0.672783f, -0.005627f, -0.180192f, -0.018395f, 0.998084f}; + std::initializer_list X_data{TypeParam(0.497417f), TypeParam(0.268522f), TypeParam(1.476879f), TypeParam(0.354795f), TypeParam(1.624709f), TypeParam(0.593423f), TypeParam(-1.725412f), TypeParam(-0.622016f), TypeParam(-0.466707f), TypeParam(-0.319962f), TypeParam(0.701868f), TypeParam(0.494252f), TypeParam(-0.630165f), TypeParam(0.548236f), TypeParam(1.042740f), TypeParam(0.253800f), TypeParam(-2.667303f), TypeParam(1.379165f), TypeParam(-0.519418f), TypeParam(0.672783f), TypeParam(-0.005627f), TypeParam(-0.180192f), TypeParam(-0.018395f), TypeParam(0.998084f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{0.213755f, 0.141747f, -0.562622f, -0.414594f, 0.325025f, -0.834438f, 0.197995f, 0.519270f, -0.472884f, 0.996769f, -0.078973f, 0.544455f, 1.188368f, -0.366802f, 0.652090f, -0.343235f, -0.175288f, -0.203365f, -0.007455f, -0.453322f, 0.281264f, 0.045216f, 0.760668f, -0.242886f}; + std::initializer_list Grid_data{TypeParam(0.213755f), TypeParam(0.141747f), TypeParam(-0.562622f), TypeParam(-0.414594f), TypeParam(0.325025f), TypeParam(-0.834438f), TypeParam(0.197995f), TypeParam(0.519270f), TypeParam(-0.472884f), TypeParam(0.996769f), TypeParam(-0.078973f), TypeParam(0.544455f), TypeParam(1.188368f), TypeParam(-0.366802f), TypeParam(0.652090f), TypeParam(-0.343235f), TypeParam(-0.175288f), TypeParam(-0.203365f), TypeParam(-0.007455f), TypeParam(-0.453322f), TypeParam(0.281264f), TypeParam(0.045216f), TypeParam(0.760668f), TypeParam(-0.242886f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{1.007407f, 1.068583f, 0.492134f, 1.222040f, 1.576835f, 1.464183f, -0.238652f, -1.242164f, -1.156880f, 0.279082f, 0.744912f, 0.338287f, 0.215322f, 0.388598f, 0.866571f, 0.556826f, 0.608617f, 0.326312f, 0.044527f, -0.028766f, -0.136528f, -0.084880f, -0.121429f, -0.105516f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(1.007407f), TypeParam(1.068583f), TypeParam(0.492134f), TypeParam(1.222040f), TypeParam(1.576835f), TypeParam(1.464183f), TypeParam(-0.238652f), TypeParam(-1.242164f), TypeParam(-1.156880f), TypeParam(0.279082f), TypeParam(0.744912f), TypeParam(0.338287f), TypeParam(0.215322f), TypeParam(0.388598f), TypeParam(0.866571f), TypeParam(0.556826f), TypeParam(0.608617f), TypeParam(0.326312f), TypeParam(0.044527f), TypeParam(-0.028766f), TypeParam(-0.136528f), TypeParam(-0.084880f), TypeParam(-0.121429f), TypeParam(-0.105516f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(16)); } -TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_zeros_no_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_zeros_no_align_corners) { OpTester test("GridSample", 16); std::string mode = "bicubic"; 
std::string padding_mode = "zeros"; int64_t align_corners = 0; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{-1.065470f, 0.402578f, -0.405242f, -0.583366f, -0.258523f, -0.605559f, -0.188242f, 0.959607f, 1.189619f, -0.179522f, -1.823240f, -0.051351f, -1.636092f, -2.510569f, -1.238273f, -0.929619f, -0.058536f, 0.772879f, 0.468944f, 0.259886f, 0.757624f, -2.041813f, -0.552378f, 0.626977f}; + std::initializer_list X_data{TypeParam(-1.065470f), TypeParam(0.402578f), TypeParam(-0.405242f), TypeParam(-0.583366f), TypeParam(-0.258523f), TypeParam(-0.605559f), TypeParam(-0.188242f), TypeParam(0.959607f), TypeParam(1.189619f), TypeParam(-0.179522f), TypeParam(-1.823240f), TypeParam(-0.051351f), TypeParam(-1.636092f), TypeParam(-2.510569f), TypeParam(-1.238273f), TypeParam(-0.929619f), TypeParam(-0.058536f), TypeParam(0.772879f), TypeParam(0.468944f), TypeParam(0.259886f), TypeParam(0.757624f), TypeParam(-2.041813f), TypeParam(-0.552378f), TypeParam(0.626977f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{-1.199809f, 0.061445f, -0.035546f, 0.180524f, 0.919500f, 1.166411f, -0.711939f, -0.074825f, -0.480808f, -1.105975f, -0.873191f, 1.126273f, 0.699673f, 0.644581f, 0.666892f, -0.953375f, 0.126023f, 1.116858f, -0.669703f, 1.067513f, 0.315406f, 0.844252f, -0.514065f, 0.553221f}; + std::initializer_list Grid_data{TypeParam(-1.199809f), TypeParam(0.061445f), TypeParam(-0.035546f), TypeParam(0.180524f), TypeParam(0.919500f), TypeParam(1.166411f), TypeParam(-0.711939f), TypeParam(-0.074825f), TypeParam(-0.480808f), TypeParam(-1.105975f), TypeParam(-0.873191f), TypeParam(1.126273f), TypeParam(0.699673f), TypeParam(0.644581f), TypeParam(0.666892f), TypeParam(-0.953375f), TypeParam(0.126023f), TypeParam(1.116858f), TypeParam(-0.669703f), TypeParam(1.067513f), TypeParam(0.315406f), TypeParam(0.844252f), TypeParam(-0.514065f), TypeParam(0.553221f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{-0.086429f, -0.590424f, -0.090572f, -0.393926f, -0.379182f, -0.031455f, 0.347836f, 0.182097f, 0.050161f, 1.154870f, -0.134312f, -0.509844f, 0.697346f, -1.440179f, 0.264668f, 0.021389f, 0.729883f, -0.236038f, 0.576661f, 0.348301f, 0.149351f, -0.327477f, 0.607344f, -0.405680f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-0.086429f), TypeParam(-0.590424f), TypeParam(-0.090572f), TypeParam(-0.393926f), TypeParam(-0.379182f), TypeParam(-0.031455f), TypeParam(0.347836f), TypeParam(0.182097f), TypeParam(0.050161f), TypeParam(1.154870f), TypeParam(-0.134312f), TypeParam(-0.509844f), TypeParam(0.697346f), TypeParam(-1.440179f), TypeParam(0.264668f), TypeParam(0.021389f), TypeParam(0.729883f), TypeParam(-0.236038f), TypeParam(0.576661f), TypeParam(0.348301f), TypeParam(0.149351f), TypeParam(-0.327477f), TypeParam(0.607344f), TypeParam(-0.405680f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(16)); } -TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_border_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_border_align_corners) { OpTester test("GridSample", 16); std::string mode = "bicubic"; std::string padding_mode = "border"; int64_t 
align_corners = 1; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{0.203585f, -1.032829f, 1.130481f, -0.570301f, -2.100938f, 0.389922f, 0.087343f, -0.857360f, 1.193520f, -0.019760f, 0.280285f, 1.811013f, 1.838673f, 0.164184f, 1.436009f, 0.167011f, -1.139939f, -0.029833f, -0.009878f, 0.079750f, 0.216590f, -0.265852f, -0.528116f, -0.451915f}; + std::initializer_list X_data{TypeParam(0.203585f), TypeParam(-1.032829f), TypeParam(1.130481f), TypeParam(-0.570301f), TypeParam(-2.100938f), TypeParam(0.389922f), TypeParam(0.087343f), TypeParam(-0.857360f), TypeParam(1.193520f), TypeParam(-0.019760f), TypeParam(0.280285f), TypeParam(1.811013f), TypeParam(1.838673f), TypeParam(0.164184f), TypeParam(1.436009f), TypeParam(0.167011f), TypeParam(-1.139939f), TypeParam(-0.029833f), TypeParam(-0.009878f), TypeParam(0.079750f), TypeParam(0.216590f), TypeParam(-0.265852f), TypeParam(-0.528116f), TypeParam(-0.451915f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{0.797796f, -1.010726f, 0.868577f, -1.132977f, 0.268082f, -0.786042f, -0.476635f, 0.212483f, -0.471816f, -0.189867f, -1.137389f, -1.131448f, 0.464836f, -0.507934f, -0.730068f, -0.473499f, -0.981082f, -0.959280f, 0.718047f, 0.609891f, 0.159844f, -0.655512f, 0.399241f, 0.053910f}; + std::initializer_list Grid_data{TypeParam(0.797796f), TypeParam(-1.010726f), TypeParam(0.868577f), TypeParam(-1.132977f), TypeParam(0.268082f), TypeParam(-0.786042f), TypeParam(-0.476635f), TypeParam(0.212483f), TypeParam(-0.471816f), TypeParam(-0.189867f), TypeParam(-1.137389f), TypeParam(-1.131448f), TypeParam(0.464836f), TypeParam(-0.507934f), TypeParam(-0.730068f), TypeParam(-0.473499f), TypeParam(-0.981082f), TypeParam(-0.959280f), TypeParam(0.718047f), TypeParam(0.609891f), TypeParam(0.159844f), TypeParam(-0.655512f), TypeParam(0.399241f), TypeParam(0.053910f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{-0.934180f, -1.004565f, -0.467118f, 0.384839f, 0.792549f, 0.188357f, -0.785741f, -0.871727f, -0.372851f, 0.958270f, 0.751528f, 0.046397f, 0.598629f, 1.686400f, 1.817043f, 0.015806f, 0.866266f, 0.480930f, -0.013358f, 0.152904f, -0.001292f, -0.385043f, 0.030959f, -0.152332f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-0.934180f), TypeParam(-1.004565f), TypeParam(-0.467118f), TypeParam(0.384839f), TypeParam(0.792549f), TypeParam(0.188357f), TypeParam(-0.785741f), TypeParam(-0.871727f), TypeParam(-0.372851f), TypeParam(0.958270f), TypeParam(0.751528f), TypeParam(0.046397f), TypeParam(0.598629f), TypeParam(1.686400f), TypeParam(1.817043f), TypeParam(0.015806f), TypeParam(0.866266f), TypeParam(0.480930f), TypeParam(-0.013358f), TypeParam(0.152904f), TypeParam(-0.001292f), TypeParam(-0.385043f), TypeParam(0.030959f), TypeParam(-0.152332f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(16)); } -TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_border_no_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_border_no_align_corners) { OpTester test("GridSample", 16); std::string mode = "bicubic"; std::string padding_mode = "border"; int64_t align_corners = 0; std::initializer_list X_shape{2, 
2, 3, 2}; - std::initializer_list X_data{-0.427361f, 0.814325f, -1.412076f, -0.099774f, 0.074936f, 0.590322f, 0.398556f, -0.635891f, -1.081747f, -0.330179f, 0.271759f, -1.089819f, -0.746656f, -0.942538f, -1.251568f, -1.730282f, -0.722323f, 0.525964f, -0.436259f, -0.188952f, -0.499550f, 1.502071f, -0.014112f, 1.194050f}; + std::initializer_list X_data{TypeParam(-0.427361f), TypeParam(0.814325f), TypeParam(-1.412076f), TypeParam(-0.099774f), TypeParam(0.074936f), TypeParam(0.590322f), TypeParam(0.398556f), TypeParam(-0.635891f), TypeParam(-1.081747f), TypeParam(-0.330179f), TypeParam(0.271759f), TypeParam(-1.089819f), TypeParam(-0.746656f), TypeParam(-0.942538f), TypeParam(-1.251568f), TypeParam(-1.730282f), TypeParam(-0.722323f), TypeParam(0.525964f), TypeParam(-0.436259f), TypeParam(-0.188952f), TypeParam(-0.499550f), TypeParam(1.502071f), TypeParam(-0.014112f), TypeParam(1.194050f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{-0.102021f, -0.935855f, -0.007380f, -0.996053f, -0.258157f, 0.695455f, -0.834420f, -0.808862f, -0.293012f, -0.328961f, 0.203145f, 0.199219f, 0.608516f, -0.826657f, -0.084685f, 0.671149f, 1.037966f, -0.087535f, -0.694344f, 0.344955f, 0.683373f, -0.749700f, -0.696352f, 0.530398f}; + std::initializer_list Grid_data{TypeParam(-0.102021f), TypeParam(-0.935855f), TypeParam(-0.007380f), TypeParam(-0.996053f), TypeParam(-0.258157f), TypeParam(0.695455f), TypeParam(-0.834420f), TypeParam(-0.808862f), TypeParam(-0.293012f), TypeParam(-0.328961f), TypeParam(0.203145f), TypeParam(0.199219f), TypeParam(0.608516f), TypeParam(-0.826657f), TypeParam(-0.084685f), TypeParam(0.671149f), TypeParam(1.037966f), TypeParam(-0.087535f), TypeParam(-0.694344f), TypeParam(0.344955f), TypeParam(0.683373f), TypeParam(-0.749700f), TypeParam(-0.696352f), TypeParam(0.530398f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{0.154701f, 0.273277f, 0.226316f, -0.467055f, -0.820643f, -0.311691f, 0.084699f, -0.052970f, 0.001158f, 0.679701f, -0.467804f, -0.607116f, -0.871407f, -0.210613f, -1.860685f, -1.059387f, -0.902250f, -0.918798f, -0.360562f, 0.476049f, 1.499304f, -0.418396f, -0.298854f, -0.235927f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(0.154701f), TypeParam(0.273277f), TypeParam(0.226316f), TypeParam(-0.467055f), TypeParam(-0.820643f), TypeParam(-0.311691f), TypeParam(0.084699f), TypeParam(-0.052970f), TypeParam(0.001158f), TypeParam(0.679701f), TypeParam(-0.467804f), TypeParam(-0.607116f), TypeParam(-0.871407f), TypeParam(-0.210613f), TypeParam(-1.860685f), TypeParam(-1.059387f), TypeParam(-0.902250f), TypeParam(-0.918798f), TypeParam(-0.360562f), TypeParam(0.476049f), TypeParam(1.499304f), TypeParam(-0.418396f), TypeParam(-0.298854f), TypeParam(-0.235927f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(16)); } -TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_reflection_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_reflection_align_corners) { OpTester test("GridSample", 16); std::string mode = "bicubic"; std::string padding_mode = "reflection"; int64_t align_corners = 1; std::initializer_list X_shape{2, 2, 3, 2}; - 
std::initializer_list X_data{-1.084082f, -0.128738f, -0.681077f, -1.309896f, 0.660269f, -1.412063f, 1.834581f, 0.456195f, 0.162801f, -0.638266f, 0.897973f, -0.383653f, 0.297945f, 1.809414f, -0.091298f, 1.092744f, -0.102453f, -1.726535f, -0.484632f, 0.712097f, 1.820312f, -0.852073f, -0.341399f, -0.138106f}; + std::initializer_list X_data{TypeParam(-1.084082f), TypeParam(-0.128738f), TypeParam(-0.681077f), TypeParam(-1.309896f), TypeParam(0.660269f), TypeParam(-1.412063f), TypeParam(1.834581f), TypeParam(0.456195f), TypeParam(0.162801f), TypeParam(-0.638266f), TypeParam(0.897973f), TypeParam(-0.383653f), TypeParam(0.297945f), TypeParam(1.809414f), TypeParam(-0.091298f), TypeParam(1.092744f), TypeParam(-0.102453f), TypeParam(-1.726535f), TypeParam(-0.484632f), TypeParam(0.712097f), TypeParam(1.820312f), TypeParam(-0.852073f), TypeParam(-0.341399f), TypeParam(-0.138106f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{-0.501236f, -0.770480f, -0.140656f, -1.129896f, 0.470370f, 0.885106f, 0.288068f, -0.118568f, 0.594968f, -0.761702f, 1.173892f, -1.193212f, -1.149534f, -0.283562f, 0.980213f, 0.120151f, 0.460855f, -0.879608f, 0.437623f, -0.134092f, 0.480988f, 0.847491f, 0.521616f, -0.102077f}; + std::initializer_list Grid_data{TypeParam(-0.501236f), TypeParam(-0.770480f), TypeParam(-0.140656f), TypeParam(-1.129896f), TypeParam(0.470370f), TypeParam(0.885106f), TypeParam(0.288068f), TypeParam(-0.118568f), TypeParam(0.594968f), TypeParam(-0.761702f), TypeParam(1.173892f), TypeParam(-1.193212f), TypeParam(-1.149534f), TypeParam(-0.283562f), TypeParam(0.980213f), TypeParam(0.120151f), TypeParam(0.460855f), TypeParam(-0.879608f), TypeParam(0.437623f), TypeParam(-0.134092f), TypeParam(0.480988f), TypeParam(0.847491f), TypeParam(0.521616f), TypeParam(-0.102077f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{-0.953278f, -0.722872f, -1.065112f, -1.071529f, -0.344328f, -0.233562f, 1.436462f, 1.232983f, -0.181487f, -0.297043f, 0.464837f, 0.396673f, 0.053896f, 0.733510f, 1.541248f, 1.117701f, -1.352406f, 1.131762f, 1.324986f, -0.882173f, 0.469635f, -0.247133f, -0.196824f, -0.393592f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-0.953278f), TypeParam(-0.722872f), TypeParam(-1.065112f), TypeParam(-1.071529f), TypeParam(-0.344328f), TypeParam(-0.233562f), TypeParam(1.436462f), TypeParam(1.232983f), TypeParam(-0.181487f), TypeParam(-0.297043f), TypeParam(0.464837f), TypeParam(0.396673f), TypeParam(0.053896f), TypeParam(0.733510f), TypeParam(1.541248f), TypeParam(1.117701f), TypeParam(-1.352406f), TypeParam(1.131762f), TypeParam(1.324986f), TypeParam(-0.882173f), TypeParam(0.469635f), TypeParam(-0.247133f), TypeParam(-0.196824f), TypeParam(-0.393592f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(16)); } -TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_reflection_no_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_reflection_no_align_corners) { OpTester test("GridSample", 16); std::string mode = "bicubic"; std::string padding_mode = "reflection"; int64_t align_corners = 0; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list 
X_data{-1.122981f, 0.620969f, -0.876394f, -1.774003f, -0.810376f, -1.475962f, 0.667025f, 0.668804f, -0.748346f, 1.422400f, 0.138469f, -0.165945f, 1.266886f, -0.496157f, 0.158060f, 0.488900f, 0.414476f, 0.419527f, 0.238000f, -0.034674f, 0.229435f, 0.234530f, 0.320846f, 0.703888f}; + std::initializer_list X_data{TypeParam(-1.122981f), TypeParam(0.620969f), TypeParam(-0.876394f), TypeParam(-1.774003f), TypeParam(-0.810376f), TypeParam(-1.475962f), TypeParam(0.667025f), TypeParam(0.668804f), TypeParam(-0.748346f), TypeParam(1.422400f), TypeParam(0.138469f), TypeParam(-0.165945f), TypeParam(1.266886f), TypeParam(-0.496157f), TypeParam(0.158060f), TypeParam(0.488900f), TypeParam(0.414476f), TypeParam(0.419527f), TypeParam(0.238000f), TypeParam(-0.034674f), TypeParam(0.229435f), TypeParam(0.234530f), TypeParam(0.320846f), TypeParam(0.703888f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{0.471637f, -0.923628f, -0.909401f, 0.684338f, 0.224360f, 1.092855f, -0.320755f, -0.579618f, -0.111056f, 0.006071f, 0.915173f, -1.195296f, -0.085441f, 0.530823f, -0.660820f, -0.609769f, 0.579921f, -1.149822f, 0.284347f, -0.929024f, 0.596474f, -1.026049f, 0.737766f, -1.135959f}; + std::initializer_list Grid_data{TypeParam(0.471637f), TypeParam(-0.923628f), TypeParam(-0.909401f), TypeParam(0.684338f), TypeParam(0.224360f), TypeParam(1.092855f), TypeParam(-0.320755f), TypeParam(-0.579618f), TypeParam(-0.111056f), TypeParam(0.006071f), TypeParam(0.915173f), TypeParam(-1.195296f), TypeParam(-0.085441f), TypeParam(0.530823f), TypeParam(-0.660820f), TypeParam(-0.609769f), TypeParam(0.579921f), TypeParam(-1.149822f), TypeParam(0.284347f), TypeParam(-0.929024f), TypeParam(0.596474f), TypeParam(-1.026049f), TypeParam(0.737766f), TypeParam(-1.135959f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{0.998063f, -0.689213f, -1.266024f, -0.870706f, -1.217616f, 1.292693f, 0.543307f, 0.219521f, -0.255151f, 0.543599f, 0.062982f, 0.527696f, 0.387590f, 1.352544f, -0.758053f, -0.262859f, -0.820496f, -0.934255f, 0.434353f, 0.262797f, -0.092283f, -0.021089f, -0.106052f, -0.119717f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(0.998063f), TypeParam(-0.689213f), TypeParam(-1.266024f), TypeParam(-0.870706f), TypeParam(-1.217616f), TypeParam(1.292693f), TypeParam(0.543307f), TypeParam(0.219521f), TypeParam(-0.255151f), TypeParam(0.543599f), TypeParam(0.062982f), TypeParam(0.527696f), TypeParam(0.387590f), TypeParam(1.352544f), TypeParam(-0.758053f), TypeParam(-0.262859f), TypeParam(-0.820496f), TypeParam(-0.934255f), TypeParam(0.434353f), TypeParam(0.262797f), TypeParam(-0.092283f), TypeParam(-0.021089f), TypeParam(-0.106052f), TypeParam(-0.119717f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(16)); } -TEST(GridSampleTest, test_grid_sample_20_4D_nearest_zeros_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_nearest_zeros_align_corners) { OpTester test("GridSample", 20); std::string mode = "nearest"; std::string padding_mode = "zeros"; int64_t align_corners = 1; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{0.404710f, -0.654932f, 0.052124f, 0.340055f, 
-0.212416f, 1.562917f, -0.907159f, -1.566185f, 0.596746f, 1.002548f, -0.820504f, 0.509186f, 0.951389f, 0.773736f, -2.144711f, 0.044147f, 1.290612f, 0.664926f, 0.530731f, -0.423196f, -0.388699f, 0.333224f, 0.293744f, -0.157543f}; + std::initializer_list X_data{TypeParam(0.404710f), TypeParam(-0.654932f), TypeParam(0.052124f), TypeParam(0.340055f), TypeParam(-0.212416f), TypeParam(1.562917f), TypeParam(-0.907159f), TypeParam(-1.566185f), TypeParam(0.596746f), TypeParam(1.002548f), TypeParam(-0.820504f), TypeParam(0.509186f), TypeParam(0.951389f), TypeParam(0.773736f), TypeParam(-2.144711f), TypeParam(0.044147f), TypeParam(1.290612f), TypeParam(0.664926f), TypeParam(0.530731f), TypeParam(-0.423196f), TypeParam(-0.388699f), TypeParam(0.333224f), TypeParam(0.293744f), TypeParam(-0.157543f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{0.528957f, 0.982925f, -0.033286f, -0.806271f, 0.793837f, -0.411498f, 0.621343f, -0.295724f, 0.510113f, 1.079311f, 1.115827f, -1.092078f, -0.793776f, -0.496160f, -0.765241f, 1.151400f, -0.105983f, -0.796009f, -0.533987f, -0.662838f, 0.489587f, -1.046701f, -1.118884f, -1.182913f}; + std::initializer_list Grid_data{TypeParam(0.528957f), TypeParam(0.982925f), TypeParam(-0.033286f), TypeParam(-0.806271f), TypeParam(0.793837f), TypeParam(-0.411498f), TypeParam(0.621343f), TypeParam(-0.295724f), TypeParam(0.510113f), TypeParam(1.079311f), TypeParam(1.115827f), TypeParam(-1.092078f), TypeParam(-0.793776f), TypeParam(-0.496160f), TypeParam(-0.765241f), TypeParam(1.151400f), TypeParam(-0.105983f), TypeParam(-0.796009f), TypeParam(-0.533987f), TypeParam(-0.662838f), TypeParam(0.489587f), TypeParam(-1.046701f), TypeParam(-1.118884f), TypeParam(-1.182913f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{1.562917f, 0.404710f, 0.340055f, 0.340055f, 1.562917f, -0.654932f, 0.509186f, -0.907159f, 1.002548f, 1.002548f, 0.509186f, -1.566185f, -2.144711f, 1.290612f, 0.951389f, 0.951389f, 0.773736f, 0.951389f, -0.388699f, 0.293744f, 0.530731f, 0.530731f, -0.423196f, 0.530731f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(1.562917f), TypeParam(0.404710f), TypeParam(0.340055f), TypeParam(0.340055f), TypeParam(1.562917f), TypeParam(-0.654932f), TypeParam(0.509186f), TypeParam(-0.907159f), TypeParam(1.002548f), TypeParam(1.002548f), TypeParam(0.509186f), TypeParam(-1.566185f), TypeParam(-2.144711f), TypeParam(1.290612f), TypeParam(0.951389f), TypeParam(0.951389f), TypeParam(0.773736f), TypeParam(0.951389f), TypeParam(-0.388699f), TypeParam(0.293744f), TypeParam(0.530731f), TypeParam(0.530731f), TypeParam(-0.423196f), TypeParam(0.530731f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_5D_nearest_zeros_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_nearest_zeros_align_corners) { OpTester test("GridSample", 20); std::string mode = "nearest"; std::string padding_mode = "zeros"; int64_t align_corners = 1; std::initializer_list X_shape{2, 2, 3, 3, 2}; - std::initializer_list X_data{-1.495959f, 0.018231f, 0.345600f, 0.031206f, 0.400390f, 0.425763f, 0.839517f, 1.238945f, 0.523906f, -1.658372f, 
0.548335f, -1.398321f, -1.976414f, 1.232491f, -0.545575f, -0.069414f, 0.732245f, -0.150333f, -0.707132f, 0.467497f, 0.278677f, 1.335679f, 1.155313f, -0.056298f, 0.430615f, -0.932645f, -1.505319f, 0.103317f, 1.521579f, 0.365497f, 1.428928f, 0.364333f, 1.683777f, 1.010632f, 0.621895f, 2.284701f, 1.574905f, -0.310514f, 1.495724f, 1.003370f, -1.437482f, 0.043097f, -1.645546f, -1.464643f, 0.350139f, -0.105905f, -0.740495f, 1.157691f, 1.443377f, 0.198399f, -1.105180f, -2.037115f, 2.128767f, -0.204457f, 0.468464f, 1.203629f, -0.362309f, -0.130520f, 1.532353f, 1.547599f, -0.831847f, -1.008509f, 0.023218f, 0.342626f, -0.882915f, 0.560640f, -1.142297f, 1.119107f, 0.385787f, -0.068515f, -0.529550f, -0.233903f}; + std::initializer_list X_data{TypeParam(-1.495959f), TypeParam(0.018231f), TypeParam(0.345600f), TypeParam(0.031206f), TypeParam(0.400390f), TypeParam(0.425763f), TypeParam(0.839517f), TypeParam(1.238945f), TypeParam(0.523906f), TypeParam(-1.658372f), TypeParam(0.548335f), TypeParam(-1.398321f), TypeParam(-1.976414f), TypeParam(1.232491f), TypeParam(-0.545575f), TypeParam(-0.069414f), TypeParam(0.732245f), TypeParam(-0.150333f), TypeParam(-0.707132f), TypeParam(0.467497f), TypeParam(0.278677f), TypeParam(1.335679f), TypeParam(1.155313f), TypeParam(-0.056298f), TypeParam(0.430615f), TypeParam(-0.932645f), TypeParam(-1.505319f), TypeParam(0.103317f), TypeParam(1.521579f), TypeParam(0.365497f), TypeParam(1.428928f), TypeParam(0.364333f), TypeParam(1.683777f), TypeParam(1.010632f), TypeParam(0.621895f), TypeParam(2.284701f), TypeParam(1.574905f), TypeParam(-0.310514f), TypeParam(1.495724f), TypeParam(1.003370f), TypeParam(-1.437482f), TypeParam(0.043097f), TypeParam(-1.645546f), TypeParam(-1.464643f), TypeParam(0.350139f), TypeParam(-0.105905f), TypeParam(-0.740495f), TypeParam(1.157691f), TypeParam(1.443377f), TypeParam(0.198399f), TypeParam(-1.105180f), TypeParam(-2.037115f), TypeParam(2.128767f), TypeParam(-0.204457f), TypeParam(0.468464f), TypeParam(1.203629f), TypeParam(-0.362309f), TypeParam(-0.130520f), TypeParam(1.532353f), TypeParam(1.547599f), TypeParam(-0.831847f), TypeParam(-1.008509f), TypeParam(0.023218f), TypeParam(0.342626f), TypeParam(-0.882915f), TypeParam(0.560640f), TypeParam(-1.142297f), TypeParam(1.119107f), TypeParam(0.385787f), TypeParam(-0.068515f), TypeParam(-0.529550f), TypeParam(-0.233903f)}; std::initializer_list Grid_shape{2, 3, 3, 2, 3}; - std::initializer_list Grid_data{0.812645f, 0.528235f, -0.550793f, -0.856977f, -1.073535f, 0.059526f, 1.163856f, -0.227931f, -0.050518f, -0.872033f, 0.368412f, 0.760780f, -1.183099f, -0.844947f, 0.888849f, 0.284117f, -0.074815f, 0.214510f, -0.182450f, -0.838758f, -1.121316f, 0.789250f, -0.142724f, -0.445665f, -0.309738f, -0.654508f, -0.355420f, -1.030097f, 0.898012f, 0.490011f, -0.605186f, -0.409576f, 0.538365f, -0.444367f, 0.316432f, 0.330410f, -0.755392f, 0.300602f, 0.073421f, 1.048061f, -0.434184f, -0.308482f, 1.033921f, -0.979923f, 0.086698f, 1.156203f, -0.538042f, 1.150419f, 1.064809f, 1.116408f, -0.114508f, 1.085560f, -0.522863f, -0.410766f, 0.453879f, 0.253497f, 0.661531f, 1.140383f, -0.751187f, 0.636872f, 0.401477f, 0.633082f, 0.569007f, -0.448884f, -0.948427f, 0.960462f, -0.684283f, 0.767193f, -1.143172f, -0.207603f, 0.012719f, 0.207628f, 0.096998f, 0.378128f, -0.133613f, 0.293885f, 1.187501f, -0.776462f, -0.065516f, -0.458068f, 1.052916f, 1.027248f, -0.032723f, -0.415959f, -0.741439f, 0.858648f, -0.082636f, 1.130172f, 0.684314f, 1.050365f, 0.949108f, -0.779811f, 0.351243f, -0.497591f, 0.602104f, -0.107892f, 0.103884f, 
-0.829931f, -1.072471f, 0.451888f, 0.278862f, 0.104235f, 0.815033f, -0.501089f, 0.425977f, -0.660914f, 0.248640f, -0.273958f}; + std::initializer_list Grid_data{TypeParam(0.812645f), TypeParam(0.528235f), TypeParam(-0.550793f), TypeParam(-0.856977f), TypeParam(-1.073535f), TypeParam(0.059526f), TypeParam(1.163856f), TypeParam(-0.227931f), TypeParam(-0.050518f), TypeParam(-0.872033f), TypeParam(0.368412f), TypeParam(0.760780f), TypeParam(-1.183099f), TypeParam(-0.844947f), TypeParam(0.888849f), TypeParam(0.284117f), TypeParam(-0.074815f), TypeParam(0.214510f), TypeParam(-0.182450f), TypeParam(-0.838758f), TypeParam(-1.121316f), TypeParam(0.789250f), TypeParam(-0.142724f), TypeParam(-0.445665f), TypeParam(-0.309738f), TypeParam(-0.654508f), TypeParam(-0.355420f), TypeParam(-1.030097f), TypeParam(0.898012f), TypeParam(0.490011f), TypeParam(-0.605186f), TypeParam(-0.409576f), TypeParam(0.538365f), TypeParam(-0.444367f), TypeParam(0.316432f), TypeParam(0.330410f), TypeParam(-0.755392f), TypeParam(0.300602f), TypeParam(0.073421f), TypeParam(1.048061f), TypeParam(-0.434184f), TypeParam(-0.308482f), TypeParam(1.033921f), TypeParam(-0.979923f), TypeParam(0.086698f), TypeParam(1.156203f), TypeParam(-0.538042f), TypeParam(1.150419f), TypeParam(1.064809f), TypeParam(1.116408f), TypeParam(-0.114508f), TypeParam(1.085560f), TypeParam(-0.522863f), TypeParam(-0.410766f), TypeParam(0.453879f), TypeParam(0.253497f), TypeParam(0.661531f), TypeParam(1.140383f), TypeParam(-0.751187f), TypeParam(0.636872f), TypeParam(0.401477f), TypeParam(0.633082f), TypeParam(0.569007f), TypeParam(-0.448884f), TypeParam(-0.948427f), TypeParam(0.960462f), TypeParam(-0.684283f), TypeParam(0.767193f), TypeParam(-1.143172f), TypeParam(-0.207603f), TypeParam(0.012719f), TypeParam(0.207628f), TypeParam(0.096998f), TypeParam(0.378128f), TypeParam(-0.133613f), TypeParam(0.293885f), TypeParam(1.187501f), TypeParam(-0.776462f), TypeParam(-0.065516f), TypeParam(-0.458068f), TypeParam(1.052916f), TypeParam(1.027248f), TypeParam(-0.032723f), TypeParam(-0.415959f), TypeParam(-0.741439f), TypeParam(0.858648f), TypeParam(-0.082636f), TypeParam(1.130172f), TypeParam(0.684314f), TypeParam(1.050365f), TypeParam(0.949108f), TypeParam(-0.779811f), TypeParam(0.351243f), TypeParam(-0.497591f), TypeParam(0.602104f), TypeParam(-0.107892f), TypeParam(0.103884f), TypeParam(-0.829931f), TypeParam(-1.072471f), TypeParam(0.451888f), TypeParam(0.278862f), TypeParam(0.104235f), TypeParam(0.815033f), TypeParam(-0.501089f), TypeParam(0.425977f), TypeParam(-0.660914f), TypeParam(0.248640f), TypeParam(-0.273958f)}; std::initializer_list Y_shape{2, 2, 3, 3, 2}; - std::initializer_list Y_data{0.425763f, 0.839517f, -1.658372f, -0.545575f, -1.976414f, -1.658372f, -1.495959f, -1.658372f, 0.839517f, 0.548335f, -0.545575f, 0.523906f, 0.523906f, -1.658372f, 1.238945f, 1.232491f, -1.398321f, 1.238945f, -0.056298f, 0.430615f, 0.103317f, 1.683777f, 1.428928f, 0.103317f, -0.707132f, 0.103317f, 0.430615f, 1.521579f, 1.683777f, -1.505319f, -1.505319f, 0.103317f, -0.932645f, 0.364333f, 0.365497f, -0.932645f, -2.037115f, 0.198399f, -0.204457f, 1.443377f, -1.437482f, 0.350139f, -0.105905f, 0.043097f, -1.105180f, -0.105905f, -0.740495f, -0.204457f, -1.464643f, -0.740495f, -0.310514f, -0.105905f, -1.464643f, 0.350139f, -0.068515f, 1.119107f, -0.233903f, -1.142297f, 1.532353f, 0.023218f, 0.342626f, 1.547599f, 0.385787f, 0.342626f, -0.882915f, -0.233903f, -1.008509f, -0.882915f, 1.203629f, 0.342626f, -1.008509f, 0.023218f}; - test.AddInput("X", X_shape, X_data); - 
test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(0.425763f), TypeParam(0.839517f), TypeParam(-1.658372f), TypeParam(-0.545575f), TypeParam(-1.976414f), TypeParam(-1.658372f), TypeParam(-1.495959f), TypeParam(-1.658372f), TypeParam(0.839517f), TypeParam(0.548335f), TypeParam(-0.545575f), TypeParam(0.523906f), TypeParam(0.523906f), TypeParam(-1.658372f), TypeParam(1.238945f), TypeParam(1.232491f), TypeParam(-1.398321f), TypeParam(1.238945f), TypeParam(-0.056298f), TypeParam(0.430615f), TypeParam(0.103317f), TypeParam(1.683777f), TypeParam(1.428928f), TypeParam(0.103317f), TypeParam(-0.707132f), TypeParam(0.103317f), TypeParam(0.430615f), TypeParam(1.521579f), TypeParam(1.683777f), TypeParam(-1.505319f), TypeParam(-1.505319f), TypeParam(0.103317f), TypeParam(-0.932645f), TypeParam(0.364333f), TypeParam(0.365497f), TypeParam(-0.932645f), TypeParam(-2.037115f), TypeParam(0.198399f), TypeParam(-0.204457f), TypeParam(1.443377f), TypeParam(-1.437482f), TypeParam(0.350139f), TypeParam(-0.105905f), TypeParam(0.043097f), TypeParam(-1.105180f), TypeParam(-0.105905f), TypeParam(-0.740495f), TypeParam(-0.204457f), TypeParam(-1.464643f), TypeParam(-0.740495f), TypeParam(-0.310514f), TypeParam(-0.105905f), TypeParam(-1.464643f), TypeParam(0.350139f), TypeParam(-0.068515f), TypeParam(1.119107f), TypeParam(-0.233903f), TypeParam(-1.142297f), TypeParam(1.532353f), TypeParam(0.023218f), TypeParam(0.342626f), TypeParam(1.547599f), TypeParam(0.385787f), TypeParam(0.342626f), TypeParam(-0.882915f), TypeParam(-0.233903f), TypeParam(-1.008509f), TypeParam(-0.882915f), TypeParam(1.203629f), TypeParam(0.342626f), TypeParam(-1.008509f), TypeParam(0.023218f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_4D_nearest_zeros_no_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_nearest_zeros_no_align_corners) { OpTester test("GridSample", 20); std::string mode = "nearest"; std::string padding_mode = "zeros"; int64_t align_corners = 0; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{-1.948141f, 1.836740f, -0.418393f, -0.125621f, 1.779137f, -0.028049f, 0.367697f, -0.388847f, -0.939514f, -0.129193f, -0.101240f, -3.087570f, -0.778617f, 1.026859f, 0.624162f, 0.291416f, 0.580998f, -0.185200f, 0.333020f, 0.415896f, 0.011702f, 0.014502f, -0.722870f, -0.201041f}; + std::initializer_list X_data{TypeParam(-1.948141f), TypeParam(1.836740f), TypeParam(-0.418393f), TypeParam(-0.125621f), TypeParam(1.779137f), TypeParam(-0.028049f), TypeParam(0.367697f), TypeParam(-0.388847f), TypeParam(-0.939514f), TypeParam(-0.129193f), TypeParam(-0.101240f), TypeParam(-3.087570f), TypeParam(-0.778617f), TypeParam(1.026859f), TypeParam(0.624162f), TypeParam(0.291416f), TypeParam(0.580998f), TypeParam(-0.185200f), TypeParam(0.333020f), TypeParam(0.415896f), TypeParam(0.011702f), TypeParam(0.014502f), TypeParam(-0.722870f), TypeParam(-0.201041f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{0.818167f, -0.394078f, 0.627076f, -1.124307f, -0.296864f, -0.244061f, -0.423780f, 0.504000f, -0.546789f, -0.139085f, -0.346504f, -1.126900f, -0.198169f, -1.016972f, 0.699725f, 0.641356f, 1.124151f, -0.402963f, 
0.061023f, 0.235069f, 1.197862f, 1.099936f, -0.621047f, -1.021083f}; + std::initializer_list Grid_data{TypeParam(0.818167f), TypeParam(-0.394078f), TypeParam(0.627076f), TypeParam(-1.124307f), TypeParam(-0.296864f), TypeParam(-0.244061f), TypeParam(-0.423780f), TypeParam(0.504000f), TypeParam(-0.546789f), TypeParam(-0.139085f), TypeParam(-0.346504f), TypeParam(-1.126900f), TypeParam(-0.198169f), TypeParam(-1.016972f), TypeParam(0.699725f), TypeParam(0.641356f), TypeParam(1.124151f), TypeParam(-0.402963f), TypeParam(0.061023f), TypeParam(0.235069f), TypeParam(1.197862f), TypeParam(1.099936f), TypeParam(-0.621047f), TypeParam(-1.021083f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{1.836740f, 0.000000f, -0.418393f, 1.779137f, -0.418393f, 0.000000f, -0.388847f, 0.000000f, -0.939514f, -0.101240f, -0.939514f, 0.000000f, 0.000000f, -0.185200f, 0.000000f, 0.291416f, 0.000000f, 0.000000f, 0.000000f, -0.201041f, 0.000000f, 0.014502f, 0.000000f, 0.000000f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(1.836740f), TypeParam(0.000000f), TypeParam(-0.418393f), TypeParam(1.779137f), TypeParam(-0.418393f), TypeParam(0.000000f), TypeParam(-0.388847f), TypeParam(0.000000f), TypeParam(-0.939514f), TypeParam(-0.101240f), TypeParam(-0.939514f), TypeParam(0.000000f), TypeParam(0.000000f), TypeParam(-0.185200f), TypeParam(0.000000f), TypeParam(0.291416f), TypeParam(0.000000f), TypeParam(0.000000f), TypeParam(0.000000f), TypeParam(-0.201041f), TypeParam(0.000000f), TypeParam(0.014502f), TypeParam(0.000000f), TypeParam(0.000000f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_5D_nearest_zeros_no_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_nearest_zeros_no_align_corners) { OpTester test("GridSample", 20); std::string mode = "nearest"; std::string padding_mode = "zeros"; int64_t align_corners = 0; std::initializer_list X_shape{2, 2, 3, 3, 2}; - std::initializer_list X_data{0.317302f, 0.629807f, -0.470444f, 0.215051f, 2.234212f, -1.940229f, 0.577203f, -0.166697f, -0.023467f, -0.451050f, -2.199999f, 1.469197f, -1.758133f, -0.570410f, -1.040355f, -0.627640f, 1.398573f, 0.275127f, -0.333592f, -0.677762f, -0.247167f, -0.290725f, -0.986956f, 0.173983f, -0.971920f, 0.225261f, -0.626680f, 1.660835f, 0.972993f, 0.223424f, 2.283593f, -1.145964f, -0.851223f, -2.052948f, -1.351783f, -0.028922f, 0.394421f, 0.057878f, -0.668671f, -0.088841f, 0.560186f, -0.105506f, 0.277478f, 1.047901f, -0.564728f, -0.287761f, 0.653621f, 0.259766f, 1.629452f, -2.337903f, -0.276703f, 0.258084f, -0.552200f, -0.464470f, -0.412042f, -1.047346f, 0.169468f, 1.334588f, 0.580615f, 1.217562f, -2.487876f, -1.218598f, -0.256617f, 1.397251f, 0.694875f, 0.732315f, 0.574448f, 0.673838f, -1.870634f, -0.855206f, 1.068415f, 0.096061f}; + std::initializer_list X_data{TypeParam(0.317302f), TypeParam(0.629807f), TypeParam(-0.470444f), TypeParam(0.215051f), TypeParam(2.234212f), TypeParam(-1.940229f), TypeParam(0.577203f), TypeParam(-0.166697f), TypeParam(-0.023467f), TypeParam(-0.451050f), TypeParam(-2.199999f), TypeParam(1.469197f), TypeParam(-1.758133f), TypeParam(-0.570410f), 
TypeParam(-1.040355f), TypeParam(-0.627640f), TypeParam(1.398573f), TypeParam(0.275127f), TypeParam(-0.333592f), TypeParam(-0.677762f), TypeParam(-0.247167f), TypeParam(-0.290725f), TypeParam(-0.986956f), TypeParam(0.173983f), TypeParam(-0.971920f), TypeParam(0.225261f), TypeParam(-0.626680f), TypeParam(1.660835f), TypeParam(0.972993f), TypeParam(0.223424f), TypeParam(2.283593f), TypeParam(-1.145964f), TypeParam(-0.851223f), TypeParam(-2.052948f), TypeParam(-1.351783f), TypeParam(-0.028922f), TypeParam(0.394421f), TypeParam(0.057878f), TypeParam(-0.668671f), TypeParam(-0.088841f), TypeParam(0.560186f), TypeParam(-0.105506f), TypeParam(0.277478f), TypeParam(1.047901f), TypeParam(-0.564728f), TypeParam(-0.287761f), TypeParam(0.653621f), TypeParam(0.259766f), TypeParam(1.629452f), TypeParam(-2.337903f), TypeParam(-0.276703f), TypeParam(0.258084f), TypeParam(-0.552200f), TypeParam(-0.464470f), TypeParam(-0.412042f), TypeParam(-1.047346f), TypeParam(0.169468f), TypeParam(1.334588f), TypeParam(0.580615f), TypeParam(1.217562f), TypeParam(-2.487876f), TypeParam(-1.218598f), TypeParam(-0.256617f), TypeParam(1.397251f), TypeParam(0.694875f), TypeParam(0.732315f), TypeParam(0.574448f), TypeParam(0.673838f), TypeParam(-1.870634f), TypeParam(-0.855206f), TypeParam(1.068415f), TypeParam(0.096061f)}; std::initializer_list Grid_shape{2, 3, 3, 2, 3}; - std::initializer_list Grid_data{0.650046f, -0.680891f, -0.200337f, -1.006178f, -0.676990f, 0.500592f, -1.118072f, -0.684288f, 0.899676f, -0.615418f, -0.499387f, -0.336929f, 0.512951f, -0.787164f, 0.120318f, 0.490083f, -0.087112f, 0.216982f, -0.915417f, 0.542519f, 0.448475f, -0.150519f, -0.992244f, 0.479971f, 0.783050f, -0.209890f, 0.565605f, 0.444791f, -0.479961f, -0.083304f, 1.194526f, 0.005665f, -0.955336f, -0.087514f, 0.596991f, -0.391708f, -0.628420f, 0.988534f, 0.634814f, -0.203871f, 0.061307f, -0.126915f, 0.278599f, 0.042647f, -0.726162f, 0.222329f, 0.031386f, 0.077584f, -0.457305f, 0.307467f, -0.970375f, 0.358708f, 0.650272f, -0.132064f, -0.932160f, -0.004362f, 0.001704f, -1.037046f, -0.848754f, 1.109926f, 0.897382f, 0.665044f, 0.831311f, 0.461956f, 0.675346f, 0.794786f, -0.280329f, -0.152546f, 0.855656f, -0.000432f, -0.780824f, -0.930479f, 0.671131f, 0.993983f, 0.931935f, 0.199703f, 0.828337f, -1.101760f, -0.864556f, -1.154677f, 0.966824f, -0.010858f, -0.552558f, 0.406048f, -0.449199f, -0.769613f, 0.462838f, 0.219719f, -0.859342f, -0.790394f, 0.562644f, 0.912452f, 0.097688f, -0.602742f, 0.579449f, 0.209287f, -1.050575f, -0.777654f, 0.262652f, 0.742529f, -0.385517f, 0.580240f, -0.743175f, 1.148320f, 0.855053f, 0.224769f, 0.533871f, 0.417788f}; + std::initializer_list Grid_data{TypeParam(0.650046f), TypeParam(-0.680891f), TypeParam(-0.200337f), TypeParam(-1.006178f), TypeParam(-0.676990f), TypeParam(0.500592f), TypeParam(-1.118072f), TypeParam(-0.684288f), TypeParam(0.899676f), TypeParam(-0.615418f), TypeParam(-0.499387f), TypeParam(-0.336929f), TypeParam(0.512951f), TypeParam(-0.787164f), TypeParam(0.120318f), TypeParam(0.490083f), TypeParam(-0.087112f), TypeParam(0.216982f), TypeParam(-0.915417f), TypeParam(0.542519f), TypeParam(0.448475f), TypeParam(-0.150519f), TypeParam(-0.992244f), TypeParam(0.479971f), TypeParam(0.783050f), TypeParam(-0.209890f), TypeParam(0.565605f), TypeParam(0.444791f), TypeParam(-0.479961f), TypeParam(-0.083304f), TypeParam(1.194526f), TypeParam(0.005665f), TypeParam(-0.955336f), TypeParam(-0.087514f), TypeParam(0.596991f), TypeParam(-0.391708f), TypeParam(-0.628420f), TypeParam(0.988534f), TypeParam(0.634814f), 
TypeParam(-0.203871f), TypeParam(0.061307f), TypeParam(-0.126915f), TypeParam(0.278599f), TypeParam(0.042647f), TypeParam(-0.726162f), TypeParam(0.222329f), TypeParam(0.031386f), TypeParam(0.077584f), TypeParam(-0.457305f), TypeParam(0.307467f), TypeParam(-0.970375f), TypeParam(0.358708f), TypeParam(0.650272f), TypeParam(-0.132064f), TypeParam(-0.932160f), TypeParam(-0.004362f), TypeParam(0.001704f), TypeParam(-1.037046f), TypeParam(-0.848754f), TypeParam(1.109926f), TypeParam(0.897382f), TypeParam(0.665044f), TypeParam(0.831311f), TypeParam(0.461956f), TypeParam(0.675346f), TypeParam(0.794786f), TypeParam(-0.280329f), TypeParam(-0.152546f), TypeParam(0.855656f), TypeParam(-0.000432f), TypeParam(-0.780824f), TypeParam(-0.930479f), TypeParam(0.671131f), TypeParam(0.993983f), TypeParam(0.931935f), TypeParam(0.199703f), TypeParam(0.828337f), TypeParam(-1.101760f), TypeParam(-0.864556f), TypeParam(-1.154677f), TypeParam(0.966824f), TypeParam(-0.010858f), TypeParam(-0.552558f), TypeParam(0.406048f), TypeParam(-0.449199f), TypeParam(-0.769613f), TypeParam(0.462838f), TypeParam(0.219719f), TypeParam(-0.859342f), TypeParam(-0.790394f), TypeParam(0.562644f), TypeParam(0.912452f), TypeParam(0.097688f), TypeParam(-0.602742f), TypeParam(0.579449f), TypeParam(0.209287f), TypeParam(-1.050575f), TypeParam(-0.777654f), TypeParam(0.262652f), TypeParam(0.742529f), TypeParam(-0.385517f), TypeParam(0.580240f), TypeParam(-0.743175f), TypeParam(1.148320f), TypeParam(0.855053f), TypeParam(0.224769f), TypeParam(0.533871f), TypeParam(0.417788f)}; std::initializer_list Y_shape{2, 2, 3, 3, 2}; - std::initializer_list Y_data{-0.166697f, 0.000000f, 0.000000f, 0.317302f, -0.166697f, -0.451050f, 1.398573f, -1.758133f, -0.627640f, -0.166697f, 0.000000f, 2.234212f, 1.398573f, -0.023467f, 0.215051f, -0.451050f, -0.470444f, 1.469197f, 0.225261f, 0.000000f, 0.000000f, -0.333592f, 0.225261f, 1.660835f, -1.351783f, 2.283593f, -2.052948f, 0.225261f, 0.000000f, -0.986956f, -1.351783f, -0.626680f, -0.290725f, 1.660835f, -0.247167f, 0.223424f, -0.564728f, 0.000000f, -0.464470f, -0.464470f, -0.276703f, 0.394421f, -0.464470f, 0.000000f, 0.000000f, 1.629452f, 1.629452f, 0.057878f, 0.259766f, 0.653621f, 0.000000f, -2.337903f, 0.000000f, -0.464470f, -0.256617f, 0.000000f, 0.096061f, 0.096061f, -1.870634f, -0.412042f, 0.096061f, 0.000000f, 0.000000f, 0.574448f, 0.574448f, -1.047346f, 0.732315f, 0.694875f, 0.000000f, 0.673838f, 0.000000f, 0.096061f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-0.166697f), TypeParam(0.000000f), TypeParam(0.000000f), TypeParam(0.317302f), TypeParam(-0.166697f), TypeParam(-0.451050f), TypeParam(1.398573f), TypeParam(-1.758133f), TypeParam(-0.627640f), TypeParam(-0.166697f), TypeParam(0.000000f), TypeParam(2.234212f), TypeParam(1.398573f), TypeParam(-0.023467f), TypeParam(0.215051f), TypeParam(-0.451050f), TypeParam(-0.470444f), TypeParam(1.469197f), TypeParam(0.225261f), TypeParam(0.000000f), TypeParam(0.000000f), TypeParam(-0.333592f), TypeParam(0.225261f), TypeParam(1.660835f), TypeParam(-1.351783f), TypeParam(2.283593f), TypeParam(-2.052948f), TypeParam(0.225261f), TypeParam(0.000000f), TypeParam(-0.986956f), TypeParam(-1.351783f), TypeParam(-0.626680f), TypeParam(-0.290725f), TypeParam(1.660835f), TypeParam(-0.247167f), TypeParam(0.223424f), TypeParam(-0.564728f), TypeParam(0.000000f), TypeParam(-0.464470f), TypeParam(-0.464470f), TypeParam(-0.276703f), TypeParam(0.394421f), TypeParam(-0.464470f), TypeParam(0.000000f), 
TypeParam(0.000000f), TypeParam(1.629452f), TypeParam(1.629452f), TypeParam(0.057878f), TypeParam(0.259766f), TypeParam(0.653621f), TypeParam(0.000000f), TypeParam(-2.337903f), TypeParam(0.000000f), TypeParam(-0.464470f), TypeParam(-0.256617f), TypeParam(0.000000f), TypeParam(0.096061f), TypeParam(0.096061f), TypeParam(-1.870634f), TypeParam(-0.412042f), TypeParam(0.096061f), TypeParam(0.000000f), TypeParam(0.000000f), TypeParam(0.574448f), TypeParam(0.574448f), TypeParam(-1.047346f), TypeParam(0.732315f), TypeParam(0.694875f), TypeParam(0.000000f), TypeParam(0.673838f), TypeParam(0.000000f), TypeParam(0.096061f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_4D_nearest_border_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_nearest_border_align_corners) { OpTester test("GridSample", 20); std::string mode = "nearest"; std::string padding_mode = "border"; int64_t align_corners = 1; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{0.660065f, 0.995767f, -0.226389f, 0.590604f, -2.628610f, 0.444899f, 0.023282f, 0.024018f, -0.584701f, 1.988638f, -0.023379f, 0.711650f, -1.062933f, -0.064113f, 1.178346f, -0.652373f, 1.259795f, 1.508661f, -0.079368f, 0.819443f, 0.836356f, -0.362184f, -1.153828f, -0.561180f}; + std::initializer_list X_data{TypeParam(0.660065f), TypeParam(0.995767f), TypeParam(-0.226389f), TypeParam(0.590604f), TypeParam(-2.628610f), TypeParam(0.444899f), TypeParam(0.023282f), TypeParam(0.024018f), TypeParam(-0.584701f), TypeParam(1.988638f), TypeParam(-0.023379f), TypeParam(0.711650f), TypeParam(-1.062933f), TypeParam(-0.064113f), TypeParam(1.178346f), TypeParam(-0.652373f), TypeParam(1.259795f), TypeParam(1.508661f), TypeParam(-0.079368f), TypeParam(0.819443f), TypeParam(0.836356f), TypeParam(-0.362184f), TypeParam(-1.153828f), TypeParam(-0.561180f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{-0.447651f, -0.521958f, 0.673539f, 0.222645f, 1.010165f, 0.451903f, 0.966699f, -0.966970f, 0.964714f, -0.551345f, -0.321222f, 0.007182f, -0.225038f, 0.237367f, 1.069316f, -0.716982f, 0.370785f, -0.964445f, 0.188419f, 0.988574f, 0.809140f, 1.027635f, 0.649589f, -0.099282f}; + std::initializer_list Grid_data{TypeParam(-0.447651f), TypeParam(-0.521958f), TypeParam(0.673539f), TypeParam(0.222645f), TypeParam(1.010165f), TypeParam(0.451903f), TypeParam(0.966699f), TypeParam(-0.966970f), TypeParam(0.964714f), TypeParam(-0.551345f), TypeParam(-0.321222f), TypeParam(0.007182f), TypeParam(-0.225038f), TypeParam(0.237367f), TypeParam(1.069316f), TypeParam(-0.716982f), TypeParam(0.370785f), TypeParam(-0.964445f), TypeParam(0.188419f), TypeParam(0.988574f), TypeParam(0.809140f), TypeParam(1.027635f), TypeParam(0.649589f), TypeParam(-0.099282f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{0.660065f, 0.590604f, 0.590604f, 0.995767f, 0.995767f, -0.226389f, 0.023282f, 1.988638f, 1.988638f, 0.024018f, 0.024018f, -0.584701f, 1.178346f, -0.064113f, -0.064113f, 1.508661f, 1.508661f, -0.652373f, 0.836356f, 0.819443f, 0.819443f, -0.561180f, -0.561180f, -0.362184f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + 
std::initializer_list Y_data{TypeParam(0.660065f), TypeParam(0.590604f), TypeParam(0.590604f), TypeParam(0.995767f), TypeParam(0.995767f), TypeParam(-0.226389f), TypeParam(0.023282f), TypeParam(1.988638f), TypeParam(1.988638f), TypeParam(0.024018f), TypeParam(0.024018f), TypeParam(-0.584701f), TypeParam(1.178346f), TypeParam(-0.064113f), TypeParam(-0.064113f), TypeParam(1.508661f), TypeParam(1.508661f), TypeParam(-0.652373f), TypeParam(0.836356f), TypeParam(0.819443f), TypeParam(0.819443f), TypeParam(-0.561180f), TypeParam(-0.561180f), TypeParam(-0.362184f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_5D_nearest_border_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_nearest_border_align_corners) { OpTester test("GridSample", 20); std::string mode = "nearest"; std::string padding_mode = "border"; int64_t align_corners = 1; std::initializer_list X_shape{2, 2, 3, 3, 2}; - std::initializer_list X_data{-0.920922f, -0.560469f, -2.244605f, -0.061799f, 0.523656f, 0.110097f, -0.944521f, 0.818932f, 1.069286f, 0.611457f, -0.355875f, 1.664810f, 0.116694f, 2.318200f, 0.681699f, -0.792880f, -0.025672f, -0.592222f, 0.229768f, -0.521888f, 0.570937f, -0.029345f, -0.873323f, 1.721509f, 2.011626f, -0.310838f, 1.121670f, 0.778967f, -0.450894f, 1.030269f, 0.166967f, -0.244737f, 0.227200f, -0.416612f, -0.276513f, 0.714623f, 0.908783f, -1.393580f, -0.983675f, -0.366833f, 1.473970f, 0.624368f, -0.607720f, -0.523833f, -0.124702f, -0.766457f, -0.131027f, 2.227047f, 1.399269f, 0.053366f, -0.295771f, -0.283811f, 0.019280f, -0.104450f, -0.574185f, -2.130628f, 0.617878f, -1.728151f, -0.272528f, 1.299354f, -1.109310f, -1.881107f, -1.300843f, -0.765376f, -0.477722f, -1.230664f, -0.495792f, 1.061688f, 1.244247f, -0.550821f, -0.520524f, 1.541448f}; + std::initializer_list X_data{TypeParam(-0.920922f), TypeParam(-0.560469f), TypeParam(-2.244605f), TypeParam(-0.061799f), TypeParam(0.523656f), TypeParam(0.110097f), TypeParam(-0.944521f), TypeParam(0.818932f), TypeParam(1.069286f), TypeParam(0.611457f), TypeParam(-0.355875f), TypeParam(1.664810f), TypeParam(0.116694f), TypeParam(2.318200f), TypeParam(0.681699f), TypeParam(-0.792880f), TypeParam(-0.025672f), TypeParam(-0.592222f), TypeParam(0.229768f), TypeParam(-0.521888f), TypeParam(0.570937f), TypeParam(-0.029345f), TypeParam(-0.873323f), TypeParam(1.721509f), TypeParam(2.011626f), TypeParam(-0.310838f), TypeParam(1.121670f), TypeParam(0.778967f), TypeParam(-0.450894f), TypeParam(1.030269f), TypeParam(0.166967f), TypeParam(-0.244737f), TypeParam(0.227200f), TypeParam(-0.416612f), TypeParam(-0.276513f), TypeParam(0.714623f), TypeParam(0.908783f), TypeParam(-1.393580f), TypeParam(-0.983675f), TypeParam(-0.366833f), TypeParam(1.473970f), TypeParam(0.624368f), TypeParam(-0.607720f), TypeParam(-0.523833f), TypeParam(-0.124702f), TypeParam(-0.766457f), TypeParam(-0.131027f), TypeParam(2.227047f), TypeParam(1.399269f), TypeParam(0.053366f), TypeParam(-0.295771f), TypeParam(-0.283811f), TypeParam(0.019280f), TypeParam(-0.104450f), TypeParam(-0.574185f), TypeParam(-2.130628f), TypeParam(0.617878f), TypeParam(-1.728151f), TypeParam(-0.272528f), TypeParam(1.299354f), TypeParam(-1.109310f), TypeParam(-1.881107f), 
TypeParam(-1.300843f), TypeParam(-0.765376f), TypeParam(-0.477722f), TypeParam(-1.230664f), TypeParam(-0.495792f), TypeParam(1.061688f), TypeParam(1.244247f), TypeParam(-0.550821f), TypeParam(-0.520524f), TypeParam(1.541448f)}; std::initializer_list Grid_shape{2, 3, 3, 2, 3}; - std::initializer_list Grid_data{-1.189605f, -0.312072f, 0.459409f, 1.033285f, -1.083635f, 0.572921f, -1.138649f, -1.147562f, -0.751493f, -0.158500f, 0.335153f, -0.912613f, 0.924528f, 1.085165f, 0.073832f, 0.976781f, -0.543258f, -0.474714f, -0.154854f, 0.131118f, -0.837104f, -0.960885f, 0.474040f, 0.345992f, 1.173923f, -0.489256f, 0.423768f, -0.484246f, 0.592379f, -0.066474f, 0.889570f, 0.666682f, 0.998817f, 0.616675f, 0.045084f, 1.034127f, -0.704858f, 1.131824f, 1.172625f, 1.146321f, -0.560545f, -0.635830f, 0.075922f, 0.373677f, 0.601953f, 0.488043f, 1.021787f, -0.300648f, -0.393688f, 0.402240f, 0.334401f, -0.699993f, 0.116070f, -0.911100f, -0.352043f, -0.470968f, 1.051900f, -1.080208f, -0.708510f, -1.174356f, 0.302647f, -0.923627f, 0.388249f, -0.833533f, -0.768697f, -0.613051f, 0.180083f, 1.102657f, 1.124055f, -0.090660f, -1.175396f, -0.396450f, -0.457333f, -0.255235f, 0.458506f, 0.603882f, 0.532050f, 0.342802f, -0.485794f, -0.012730f, 0.152721f, -0.612948f, -0.107348f, -0.149795f, -1.133775f, 0.813507f, -0.121323f, -1.037352f, 0.949408f, -0.645689f, 0.424853f, 1.190055f, 0.055551f, 0.345244f, 0.476794f, 0.906949f, -0.368187f, -0.675263f, -0.093908f, 0.938461f, 0.103178f, 0.833774f, -0.008922f, 0.368184f, 0.041727f, 0.032575f, -1.141943f, -1.049081f}; + std::initializer_list Grid_data{TypeParam(-1.189605f), TypeParam(-0.312072f), TypeParam(0.459409f), TypeParam(1.033285f), TypeParam(-1.083635f), TypeParam(0.572921f), TypeParam(-1.138649f), TypeParam(-1.147562f), TypeParam(-0.751493f), TypeParam(-0.158500f), TypeParam(0.335153f), TypeParam(-0.912613f), TypeParam(0.924528f), TypeParam(1.085165f), TypeParam(0.073832f), TypeParam(0.976781f), TypeParam(-0.543258f), TypeParam(-0.474714f), TypeParam(-0.154854f), TypeParam(0.131118f), TypeParam(-0.837104f), TypeParam(-0.960885f), TypeParam(0.474040f), TypeParam(0.345992f), TypeParam(1.173923f), TypeParam(-0.489256f), TypeParam(0.423768f), TypeParam(-0.484246f), TypeParam(0.592379f), TypeParam(-0.066474f), TypeParam(0.889570f), TypeParam(0.666682f), TypeParam(0.998817f), TypeParam(0.616675f), TypeParam(0.045084f), TypeParam(1.034127f), TypeParam(-0.704858f), TypeParam(1.131824f), TypeParam(1.172625f), TypeParam(1.146321f), TypeParam(-0.560545f), TypeParam(-0.635830f), TypeParam(0.075922f), TypeParam(0.373677f), TypeParam(0.601953f), TypeParam(0.488043f), TypeParam(1.021787f), TypeParam(-0.300648f), TypeParam(-0.393688f), TypeParam(0.402240f), TypeParam(0.334401f), TypeParam(-0.699993f), TypeParam(0.116070f), TypeParam(-0.911100f), TypeParam(-0.352043f), TypeParam(-0.470968f), TypeParam(1.051900f), TypeParam(-1.080208f), TypeParam(-0.708510f), TypeParam(-1.174356f), TypeParam(0.302647f), TypeParam(-0.923627f), TypeParam(0.388249f), TypeParam(-0.833533f), TypeParam(-0.768697f), TypeParam(-0.613051f), TypeParam(0.180083f), TypeParam(1.102657f), TypeParam(1.124055f), TypeParam(-0.090660f), TypeParam(-1.175396f), TypeParam(-0.396450f), TypeParam(-0.457333f), TypeParam(-0.255235f), TypeParam(0.458506f), TypeParam(0.603882f), TypeParam(0.532050f), TypeParam(0.342802f), TypeParam(-0.485794f), TypeParam(-0.012730f), TypeParam(0.152721f), TypeParam(-0.612948f), TypeParam(-0.107348f), TypeParam(-0.149795f), TypeParam(-1.133775f), TypeParam(0.813507f), TypeParam(-0.121323f), 
TypeParam(-1.037352f), TypeParam(0.949408f), TypeParam(-0.645689f), TypeParam(0.424853f), TypeParam(1.190055f), TypeParam(0.055551f), TypeParam(0.345244f), TypeParam(0.476794f), TypeParam(0.906949f), TypeParam(-0.368187f), TypeParam(-0.675263f), TypeParam(-0.093908f), TypeParam(0.938461f), TypeParam(0.103178f), TypeParam(0.833774f), TypeParam(-0.008922f), TypeParam(0.368184f), TypeParam(0.041727f), TypeParam(0.032575f), TypeParam(-1.141943f), TypeParam(-1.049081f)}; std::initializer_list Y_shape{2, 2, 3, 3, 2}; - std::initializer_list Y_data{1.069286f, 2.318200f, -0.920922f, -2.244605f, 1.664810f, 0.818932f, -2.244605f, 1.069286f, 0.611457f, -0.355875f, -0.592222f, -0.792880f, -0.025672f, -0.560469f, -0.792880f, 1.664810f, 1.069286f, -2.244605f, 1.121670f, -0.244737f, 0.229768f, 0.570937f, 1.030269f, -0.310838f, 0.570937f, 1.121670f, 0.778967f, -0.450894f, 0.714623f, -0.416612f, -0.276513f, -0.521888f, -0.416612f, 1.030269f, 1.121670f, 0.570937f, -0.295771f, 0.908783f, -0.523833f, 0.908783f, -0.104450f, -0.607720f, -0.124702f, 2.227047f, -0.124702f, -0.124702f, -0.131027f, 1.473970f, 2.227047f, -0.283811f, -0.607720f, -0.283811f, -0.124702f, -1.393580f, 1.244247f, -0.574185f, -1.881107f, -0.574185f, 1.541448f, -1.109310f, -1.300843f, -1.230664f, -1.300843f, -1.300843f, -0.477722f, -0.272528f, -1.230664f, -0.550821f, -1.109310f, -0.550821f, -1.300843f, -2.130628f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(1.069286f), TypeParam(2.318200f), TypeParam(-0.920922f), TypeParam(-2.244605f), TypeParam(1.664810f), TypeParam(0.818932f), TypeParam(-2.244605f), TypeParam(1.069286f), TypeParam(0.611457f), TypeParam(-0.355875f), TypeParam(-0.592222f), TypeParam(-0.792880f), TypeParam(-0.025672f), TypeParam(-0.560469f), TypeParam(-0.792880f), TypeParam(1.664810f), TypeParam(1.069286f), TypeParam(-2.244605f), TypeParam(1.121670f), TypeParam(-0.244737f), TypeParam(0.229768f), TypeParam(0.570937f), TypeParam(1.030269f), TypeParam(-0.310838f), TypeParam(0.570937f), TypeParam(1.121670f), TypeParam(0.778967f), TypeParam(-0.450894f), TypeParam(0.714623f), TypeParam(-0.416612f), TypeParam(-0.276513f), TypeParam(-0.521888f), TypeParam(-0.416612f), TypeParam(1.030269f), TypeParam(1.121670f), TypeParam(0.570937f), TypeParam(-0.295771f), TypeParam(0.908783f), TypeParam(-0.523833f), TypeParam(0.908783f), TypeParam(-0.104450f), TypeParam(-0.607720f), TypeParam(-0.124702f), TypeParam(2.227047f), TypeParam(-0.124702f), TypeParam(-0.124702f), TypeParam(-0.131027f), TypeParam(1.473970f), TypeParam(2.227047f), TypeParam(-0.283811f), TypeParam(-0.607720f), TypeParam(-0.283811f), TypeParam(-0.124702f), TypeParam(-1.393580f), TypeParam(1.244247f), TypeParam(-0.574185f), TypeParam(-1.881107f), TypeParam(-0.574185f), TypeParam(1.541448f), TypeParam(-1.109310f), TypeParam(-1.300843f), TypeParam(-1.230664f), TypeParam(-1.300843f), TypeParam(-1.300843f), TypeParam(-0.477722f), TypeParam(-0.272528f), TypeParam(-1.230664f), TypeParam(-0.550821f), TypeParam(-1.109310f), TypeParam(-0.550821f), TypeParam(-1.300843f), TypeParam(-2.130628f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, 
test_grid_sample_20_4D_nearest_border_no_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_nearest_border_no_align_corners) { OpTester test("GridSample", 20); std::string mode = "nearest"; std::string padding_mode = "border"; int64_t align_corners = 0; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{0.950589f, -1.656624f, 0.767704f, -0.650720f, -1.404308f, -0.531582f, -0.280854f, 0.344309f, -0.959146f, -0.115645f, 0.515696f, -0.114243f, 1.971614f, 0.274268f, 0.543080f, -1.758563f, 1.771011f, 0.934901f, 0.695798f, 1.905137f, 1.598307f, 1.108385f, 0.156008f, 1.290824f}; + std::initializer_list X_data{TypeParam(0.950589f), TypeParam(-1.656624f), TypeParam(0.767704f), TypeParam(-0.650720f), TypeParam(-1.404308f), TypeParam(-0.531582f), TypeParam(-0.280854f), TypeParam(0.344309f), TypeParam(-0.959146f), TypeParam(-0.115645f), TypeParam(0.515696f), TypeParam(-0.114243f), TypeParam(1.971614f), TypeParam(0.274268f), TypeParam(0.543080f), TypeParam(-1.758563f), TypeParam(1.771011f), TypeParam(0.934901f), TypeParam(0.695798f), TypeParam(1.905137f), TypeParam(1.598307f), TypeParam(1.108385f), TypeParam(0.156008f), TypeParam(1.290824f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{0.482490f, -0.910951f, -0.001676f, -0.442514f, 0.580438f, 1.039346f, -0.159076f, -0.603960f, -0.922037f, -0.705026f, 0.346468f, 0.275332f, 0.646235f, -0.178307f, 0.616600f, -1.069108f, 0.322583f, 1.164952f, -1.187638f, -0.622953f, 0.768203f, -0.187618f, -0.639652f, 0.732078f}; + std::initializer_list Grid_data{TypeParam(0.482490f), TypeParam(-0.910951f), TypeParam(-0.001676f), TypeParam(-0.442514f), TypeParam(0.580438f), TypeParam(1.039346f), TypeParam(-0.159076f), TypeParam(-0.603960f), TypeParam(-0.922037f), TypeParam(-0.705026f), TypeParam(0.346468f), TypeParam(0.275332f), TypeParam(0.646235f), TypeParam(-0.178307f), TypeParam(0.616600f), TypeParam(-1.069108f), TypeParam(0.322583f), TypeParam(1.164952f), TypeParam(-1.187638f), TypeParam(-0.622953f), TypeParam(0.768203f), TypeParam(-0.187618f), TypeParam(-0.639652f), TypeParam(0.732078f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{-1.656624f, 0.950589f, -0.531582f, 0.950589f, 0.950589f, -0.650720f, 0.344309f, -0.280854f, -0.114243f, -0.280854f, -0.280854f, -0.115645f, -1.758563f, 0.274268f, 0.934901f, 1.971614f, -1.758563f, 1.771011f, 1.108385f, 1.905137f, 1.290824f, 0.695798f, 1.108385f, 0.156008f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-1.656624f), TypeParam(0.950589f), TypeParam(-0.531582f), TypeParam(0.950589f), TypeParam(0.950589f), TypeParam(-0.650720f), TypeParam(0.344309f), TypeParam(-0.280854f), TypeParam(-0.114243f), TypeParam(-0.280854f), TypeParam(-0.280854f), TypeParam(-0.115645f), TypeParam(-1.758563f), TypeParam(0.274268f), TypeParam(0.934901f), TypeParam(1.971614f), TypeParam(-1.758563f), TypeParam(1.771011f), TypeParam(1.108385f), TypeParam(1.905137f), TypeParam(1.290824f), TypeParam(0.695798f), TypeParam(1.108385f), TypeParam(0.156008f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_5D_nearest_border_no_align_corners) { 
+TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_nearest_border_no_align_corners) { OpTester test("GridSample", 20); std::string mode = "nearest"; std::string padding_mode = "border"; int64_t align_corners = 0; std::initializer_list X_shape{2, 2, 3, 3, 2}; - std::initializer_list X_data{0.465448f, -0.337086f, -0.870849f, -0.389573f, -0.083941f, 1.306894f, 0.719508f, -0.203690f, -1.143864f, 1.163003f, 0.312170f, -2.008687f, 1.731257f, -0.270431f, 1.095352f, -1.673520f, 0.492743f, 0.521962f, -1.938783f, -0.186813f, -0.836257f, -1.835450f, 0.476500f, -0.123386f, 0.246604f, 1.374159f, -0.158435f, 1.268192f, -0.704226f, -0.195314f, -0.277259f, 0.582961f, -0.340940f, 0.192264f, 0.463124f, -2.719402f, -0.593470f, -1.165777f, 0.566071f, 1.622836f, -0.886798f, 1.874877f, -0.849095f, 0.550185f, 0.604298f, 0.073976f, -0.800372f, -0.097283f, -1.576251f, -0.633278f, -1.776745f, -0.827586f, 0.665697f, 0.884698f, 0.467112f, -0.645219f, -0.510110f, 0.032418f, -1.056009f, -0.206175f, -0.173385f, 0.947787f, 1.937234f, 0.615880f, -0.311580f, 0.770921f, -0.841602f, 1.796220f, 0.479491f, 1.609346f, 1.113868f, -0.453360f}; + std::initializer_list X_data{TypeParam(0.465448f), TypeParam(-0.337086f), TypeParam(-0.870849f), TypeParam(-0.389573f), TypeParam(-0.083941f), TypeParam(1.306894f), TypeParam(0.719508f), TypeParam(-0.203690f), TypeParam(-1.143864f), TypeParam(1.163003f), TypeParam(0.312170f), TypeParam(-2.008687f), TypeParam(1.731257f), TypeParam(-0.270431f), TypeParam(1.095352f), TypeParam(-1.673520f), TypeParam(0.492743f), TypeParam(0.521962f), TypeParam(-1.938783f), TypeParam(-0.186813f), TypeParam(-0.836257f), TypeParam(-1.835450f), TypeParam(0.476500f), TypeParam(-0.123386f), TypeParam(0.246604f), TypeParam(1.374159f), TypeParam(-0.158435f), TypeParam(1.268192f), TypeParam(-0.704226f), TypeParam(-0.195314f), TypeParam(-0.277259f), TypeParam(0.582961f), TypeParam(-0.340940f), TypeParam(0.192264f), TypeParam(0.463124f), TypeParam(-2.719402f), TypeParam(-0.593470f), TypeParam(-1.165777f), TypeParam(0.566071f), TypeParam(1.622836f), TypeParam(-0.886798f), TypeParam(1.874877f), TypeParam(-0.849095f), TypeParam(0.550185f), TypeParam(0.604298f), TypeParam(0.073976f), TypeParam(-0.800372f), TypeParam(-0.097283f), TypeParam(-1.576251f), TypeParam(-0.633278f), TypeParam(-1.776745f), TypeParam(-0.827586f), TypeParam(0.665697f), TypeParam(0.884698f), TypeParam(0.467112f), TypeParam(-0.645219f), TypeParam(-0.510110f), TypeParam(0.032418f), TypeParam(-1.056009f), TypeParam(-0.206175f), TypeParam(-0.173385f), TypeParam(0.947787f), TypeParam(1.937234f), TypeParam(0.615880f), TypeParam(-0.311580f), TypeParam(0.770921f), TypeParam(-0.841602f), TypeParam(1.796220f), TypeParam(0.479491f), TypeParam(1.609346f), TypeParam(1.113868f), TypeParam(-0.453360f)}; std::initializer_list Grid_shape{2, 3, 3, 2, 3}; - std::initializer_list Grid_data{-0.151540f, -0.033291f, -0.597203f, 0.836404f, -0.686848f, -0.485355f, -0.936738f, -1.009057f, 1.065352f, -0.926635f, -0.165670f, -0.347352f, 0.439545f, 0.320963f, -0.919909f, 1.077689f, -1.195359f, 0.118687f, -0.100253f, -0.278089f, 0.817760f, 1.013180f, 0.156316f, -0.423839f, 0.892139f, 0.753924f, 0.215530f, -0.328214f, 0.050592f, 1.069553f, 0.130134f, -0.236478f, -1.015986f, -0.643059f, 0.866682f, -0.042256f, -0.079912f, 0.467233f, -0.789513f, -0.081063f, -0.337505f, 0.627865f, 0.976589f, 0.753489f, 0.894667f, -1.072442f, -0.426020f, 0.142099f, -1.019226f, 0.325527f, -0.786578f, 0.514215f, 0.971223f, -1.026539f, 1.005531f, 0.559922f, -0.791906f, 1.148613f, -1.039306f, -0.807864f, 
-0.596935f, -0.060766f, 0.215484f, -0.352165f, -1.137417f, -0.138518f, 0.910459f, 0.923925f, 0.600710f, 0.174227f, 0.298169f, -0.925092f, 0.485927f, -1.194283f, -0.495564f, -0.315357f, 0.881199f, -0.034981f, -0.546611f, 0.209651f, -0.995724f, -0.317709f, 0.332343f, -0.079474f, -0.126024f, 0.733410f, -0.911554f, -0.605911f, 1.161566f, 0.238787f, -0.194293f, 0.621583f, 0.721901f, -0.200521f, -0.499850f, -0.196149f, 0.435730f, -0.153196f, 0.698401f, -0.978582f, -0.588758f, 0.914808f, 0.157427f, 0.241646f, 0.394674f, -0.283552f, -0.479889f, 0.344261f}; + std::initializer_list Grid_data{TypeParam(-0.151540f), TypeParam(-0.033291f), TypeParam(-0.597203f), TypeParam(0.836404f), TypeParam(-0.686848f), TypeParam(-0.485355f), TypeParam(-0.936738f), TypeParam(-1.009057f), TypeParam(1.065352f), TypeParam(-0.926635f), TypeParam(-0.165670f), TypeParam(-0.347352f), TypeParam(0.439545f), TypeParam(0.320963f), TypeParam(-0.919909f), TypeParam(1.077689f), TypeParam(-1.195359f), TypeParam(0.118687f), TypeParam(-0.100253f), TypeParam(-0.278089f), TypeParam(0.817760f), TypeParam(1.013180f), TypeParam(0.156316f), TypeParam(-0.423839f), TypeParam(0.892139f), TypeParam(0.753924f), TypeParam(0.215530f), TypeParam(-0.328214f), TypeParam(0.050592f), TypeParam(1.069553f), TypeParam(0.130134f), TypeParam(-0.236478f), TypeParam(-1.015986f), TypeParam(-0.643059f), TypeParam(0.866682f), TypeParam(-0.042256f), TypeParam(-0.079912f), TypeParam(0.467233f), TypeParam(-0.789513f), TypeParam(-0.081063f), TypeParam(-0.337505f), TypeParam(0.627865f), TypeParam(0.976589f), TypeParam(0.753489f), TypeParam(0.894667f), TypeParam(-1.072442f), TypeParam(-0.426020f), TypeParam(0.142099f), TypeParam(-1.019226f), TypeParam(0.325527f), TypeParam(-0.786578f), TypeParam(0.514215f), TypeParam(0.971223f), TypeParam(-1.026539f), TypeParam(1.005531f), TypeParam(0.559922f), TypeParam(-0.791906f), TypeParam(1.148613f), TypeParam(-1.039306f), TypeParam(-0.807864f), TypeParam(-0.596935f), TypeParam(-0.060766f), TypeParam(0.215484f), TypeParam(-0.352165f), TypeParam(-1.137417f), TypeParam(-0.138518f), TypeParam(0.910459f), TypeParam(0.923925f), TypeParam(0.600710f), TypeParam(0.174227f), TypeParam(0.298169f), TypeParam(-0.925092f), TypeParam(0.485927f), TypeParam(-1.194283f), TypeParam(-0.495564f), TypeParam(-0.315357f), TypeParam(0.881199f), TypeParam(-0.034981f), TypeParam(-0.546611f), TypeParam(0.209651f), TypeParam(-0.995724f), TypeParam(-0.317709f), TypeParam(0.332343f), TypeParam(-0.079474f), TypeParam(-0.126024f), TypeParam(0.733410f), TypeParam(-0.911554f), TypeParam(-0.605911f), TypeParam(1.161566f), TypeParam(0.238787f), TypeParam(-0.194293f), TypeParam(0.621583f), TypeParam(0.721901f), TypeParam(-0.200521f), TypeParam(-0.499850f), TypeParam(-0.196149f), TypeParam(0.435730f), TypeParam(-0.153196f), TypeParam(0.698401f), TypeParam(-0.978582f), TypeParam(-0.588758f), TypeParam(0.914808f), TypeParam(0.157427f), TypeParam(0.241646f), TypeParam(0.394674f), TypeParam(-0.283552f), TypeParam(-0.479889f), TypeParam(0.344261f)}; std::initializer_list Y_shape{2, 2, 3, 3, 2}; - std::initializer_list Y_data{-0.870849f, -0.337086f, 1.731257f, -0.870849f, -0.389573f, -0.203690f, 1.095352f, -0.389573f, -2.008687f, 1.095352f, -0.389573f, 0.312170f, -0.083941f, 1.731257f, 0.521962f, 0.719508f, -0.870849f, 1.306894f, -0.836257f, -0.186813f, -0.277259f, -0.836257f, -1.835450f, 1.374159f, -0.340940f, -1.835450f, -0.195314f, -0.340940f, -1.835450f, -0.704226f, 0.476500f, -0.277259f, -2.719402f, 0.246604f, -0.836257f, -0.123386f, 1.874877f, -1.165777f, 
0.604298f, -0.849095f, 0.884698f, 1.622836f, -1.165777f, -0.800372f, 0.566071f, 0.604298f, -0.886798f, -0.800372f, 0.665697f, -0.849095f, -0.827586f, -1.576251f, -0.827586f, -1.576251f, -0.206175f, -0.645219f, 1.937234f, -0.173385f, -0.453360f, 0.032418f, -0.645219f, -0.311580f, -0.510110f, 1.937234f, -1.056009f, -0.311580f, 1.113868f, -0.173385f, 1.609346f, -0.841602f, 1.609346f, -0.841602f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-0.870849f), TypeParam(-0.337086f), TypeParam(1.731257f), TypeParam(-0.870849f), TypeParam(-0.389573f), TypeParam(-0.203690f), TypeParam(1.095352f), TypeParam(-0.389573f), TypeParam(-2.008687f), TypeParam(1.095352f), TypeParam(-0.389573f), TypeParam(0.312170f), TypeParam(-0.083941f), TypeParam(1.731257f), TypeParam(0.521962f), TypeParam(0.719508f), TypeParam(-0.870849f), TypeParam(1.306894f), TypeParam(-0.836257f), TypeParam(-0.186813f), TypeParam(-0.277259f), TypeParam(-0.836257f), TypeParam(-1.835450f), TypeParam(1.374159f), TypeParam(-0.340940f), TypeParam(-1.835450f), TypeParam(-0.195314f), TypeParam(-0.340940f), TypeParam(-1.835450f), TypeParam(-0.704226f), TypeParam(0.476500f), TypeParam(-0.277259f), TypeParam(-2.719402f), TypeParam(0.246604f), TypeParam(-0.836257f), TypeParam(-0.123386f), TypeParam(1.874877f), TypeParam(-1.165777f), TypeParam(0.604298f), TypeParam(-0.849095f), TypeParam(0.884698f), TypeParam(1.622836f), TypeParam(-1.165777f), TypeParam(-0.800372f), TypeParam(0.566071f), TypeParam(0.604298f), TypeParam(-0.886798f), TypeParam(-0.800372f), TypeParam(0.665697f), TypeParam(-0.849095f), TypeParam(-0.827586f), TypeParam(-1.576251f), TypeParam(-0.827586f), TypeParam(-1.576251f), TypeParam(-0.206175f), TypeParam(-0.645219f), TypeParam(1.937234f), TypeParam(-0.173385f), TypeParam(-0.453360f), TypeParam(0.032418f), TypeParam(-0.645219f), TypeParam(-0.311580f), TypeParam(-0.510110f), TypeParam(1.937234f), TypeParam(-1.056009f), TypeParam(-0.311580f), TypeParam(1.113868f), TypeParam(-0.173385f), TypeParam(1.609346f), TypeParam(-0.841602f), TypeParam(1.609346f), TypeParam(-0.841602f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_4D_nearest_reflection_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_nearest_reflection_align_corners) { OpTester test("GridSample", 20); std::string mode = "nearest"; std::string padding_mode = "reflection"; int64_t align_corners = 1; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{0.079043f, 0.407494f, 1.038992f, -0.437542f, 0.991216f, 0.409636f, 1.050403f, -0.687172f, -2.021689f, 0.789633f, 0.538178f, 0.414847f, 2.221617f, -0.254833f, -0.179968f, -0.952356f, -1.213159f, 0.499103f, -0.374865f, 0.441938f, -0.114847f, 0.716887f, 1.059090f, 0.438870f}; + std::initializer_list X_data{TypeParam(0.079043f), TypeParam(0.407494f), TypeParam(1.038992f), TypeParam(-0.437542f), TypeParam(0.991216f), TypeParam(0.409636f), TypeParam(1.050403f), TypeParam(-0.687172f), TypeParam(-2.021689f), TypeParam(0.789633f), TypeParam(0.538178f), TypeParam(0.414847f), TypeParam(2.221617f), TypeParam(-0.254833f), TypeParam(-0.179968f), TypeParam(-0.952356f), TypeParam(-1.213159f), 
TypeParam(0.499103f), TypeParam(-0.374865f), TypeParam(0.441938f), TypeParam(-0.114847f), TypeParam(0.716887f), TypeParam(1.059090f), TypeParam(0.438870f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{0.355147f, -0.222342f, -1.197658f, 0.844060f, 1.188586f, 0.605435f, 1.174232f, 0.327060f, -0.094032f, -0.955794f, -1.048806f, -0.826196f, -0.304468f, 0.698768f, -0.495101f, -0.046607f, -0.016936f, -0.784415f, -0.032484f, 1.158664f, 0.959105f, 0.913943f, -0.118352f, 0.021282f}; + std::initializer_list Grid_data{TypeParam(0.355147f), TypeParam(-0.222342f), TypeParam(-1.197658f), TypeParam(0.844060f), TypeParam(1.188586f), TypeParam(0.605435f), TypeParam(1.174232f), TypeParam(0.327060f), TypeParam(-0.094032f), TypeParam(-0.955794f), TypeParam(-1.048806f), TypeParam(-0.826196f), TypeParam(-0.304468f), TypeParam(0.698768f), TypeParam(-0.495101f), TypeParam(-0.046607f), TypeParam(-0.016936f), TypeParam(-0.784415f), TypeParam(-0.032484f), TypeParam(1.158664f), TypeParam(0.959105f), TypeParam(0.913943f), TypeParam(-0.118352f), TypeParam(0.021282f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{-0.437542f, 0.991216f, 0.409636f, -0.437542f, 0.079043f, 0.079043f, 0.789633f, 0.538178f, 0.414847f, 0.789633f, 1.050403f, 1.050403f, -1.213159f, -0.179968f, 2.221617f, -1.213159f, 0.499103f, -0.179968f, 1.059090f, -0.114847f, -0.374865f, 1.059090f, 0.438870f, -0.114847f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-0.437542f), TypeParam(0.991216f), TypeParam(0.409636f), TypeParam(-0.437542f), TypeParam(0.079043f), TypeParam(0.079043f), TypeParam(0.789633f), TypeParam(0.538178f), TypeParam(0.414847f), TypeParam(0.789633f), TypeParam(1.050403f), TypeParam(1.050403f), TypeParam(-1.213159f), TypeParam(-0.179968f), TypeParam(2.221617f), TypeParam(-1.213159f), TypeParam(0.499103f), TypeParam(-0.179968f), TypeParam(1.059090f), TypeParam(-0.114847f), TypeParam(-0.374865f), TypeParam(1.059090f), TypeParam(0.438870f), TypeParam(-0.114847f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_5D_nearest_reflection_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_nearest_reflection_align_corners) { OpTester test("GridSample", 20); std::string mode = "nearest"; std::string padding_mode = "reflection"; int64_t align_corners = 1; std::initializer_list X_shape{2, 2, 3, 3, 2}; - std::initializer_list X_data{0.189379f, 0.825309f, -0.701365f, 0.787800f, -1.102514f, 0.126954f, 1.824453f, -0.144635f, -1.712534f, 0.361739f, -0.462516f, -2.153102f, 0.536963f, 0.581639f, -1.325014f, -1.314673f, -0.524797f, -1.304159f, -1.093757f, -1.703444f, -0.672976f, 0.505303f, 1.497654f, -0.545441f, -1.334648f, 0.474489f, 0.484384f, 0.434399f, -0.733471f, 0.452991f, 0.324606f, -1.307459f, -0.640603f, -0.450100f, 0.772854f, 1.281813f, -0.481714f, 1.224667f, -0.437546f, 0.371986f, -0.320368f, -1.011020f, -1.199298f, 0.213302f, 1.795444f, 0.409271f, 1.328065f, -1.037527f, 0.224494f, 0.217863f, -0.925740f, 0.344755f, -1.445667f, -0.935542f, -0.427280f, -2.010803f, -1.174929f, 1.434105f, -1.168630f, 0.321896f, -0.561974f, -0.209305f, -1.063838f, 1.451708f, 
0.266913f, -0.132535f, 0.798299f, 0.619547f, -0.324459f, 0.255630f, 0.488773f, -0.142060f}; + std::initializer_list X_data{TypeParam(0.189379f), TypeParam(0.825309f), TypeParam(-0.701365f), TypeParam(0.787800f), TypeParam(-1.102514f), TypeParam(0.126954f), TypeParam(1.824453f), TypeParam(-0.144635f), TypeParam(-1.712534f), TypeParam(0.361739f), TypeParam(-0.462516f), TypeParam(-2.153102f), TypeParam(0.536963f), TypeParam(0.581639f), TypeParam(-1.325014f), TypeParam(-1.314673f), TypeParam(-0.524797f), TypeParam(-1.304159f), TypeParam(-1.093757f), TypeParam(-1.703444f), TypeParam(-0.672976f), TypeParam(0.505303f), TypeParam(1.497654f), TypeParam(-0.545441f), TypeParam(-1.334648f), TypeParam(0.474489f), TypeParam(0.484384f), TypeParam(0.434399f), TypeParam(-0.733471f), TypeParam(0.452991f), TypeParam(0.324606f), TypeParam(-1.307459f), TypeParam(-0.640603f), TypeParam(-0.450100f), TypeParam(0.772854f), TypeParam(1.281813f), TypeParam(-0.481714f), TypeParam(1.224667f), TypeParam(-0.437546f), TypeParam(0.371986f), TypeParam(-0.320368f), TypeParam(-1.011020f), TypeParam(-1.199298f), TypeParam(0.213302f), TypeParam(1.795444f), TypeParam(0.409271f), TypeParam(1.328065f), TypeParam(-1.037527f), TypeParam(0.224494f), TypeParam(0.217863f), TypeParam(-0.925740f), TypeParam(0.344755f), TypeParam(-1.445667f), TypeParam(-0.935542f), TypeParam(-0.427280f), TypeParam(-2.010803f), TypeParam(-1.174929f), TypeParam(1.434105f), TypeParam(-1.168630f), TypeParam(0.321896f), TypeParam(-0.561974f), TypeParam(-0.209305f), TypeParam(-1.063838f), TypeParam(1.451708f), TypeParam(0.266913f), TypeParam(-0.132535f), TypeParam(0.798299f), TypeParam(0.619547f), TypeParam(-0.324459f), TypeParam(0.255630f), TypeParam(0.488773f), TypeParam(-0.142060f)}; std::initializer_list Grid_shape{2, 3, 3, 2, 3}; - std::initializer_list Grid_data{-0.034431f, 1.048250f, 0.160255f, -0.446426f, 0.879791f, -0.683555f, 0.039704f, 0.269729f, 0.538601f, -1.107191f, 0.058867f, -0.310704f, 0.778040f, 0.403733f, 0.480956f, 0.721512f, -0.268657f, -0.076883f, 0.962704f, -0.967187f, -0.829464f, 0.087786f, -0.475353f, 0.068725f, 1.060032f, -0.139108f, -1.023162f, -0.545493f, 1.102040f, -0.263627f, -0.526173f, 0.540152f, 0.148556f, -1.058015f, 0.999344f, 0.675750f, 1.043022f, 0.525119f, -0.404585f, -0.391737f, 0.581547f, -0.232625f, 0.235264f, -1.162786f, -0.593187f, 0.445737f, -0.059159f, -0.576901f, -1.046721f, 0.762672f, -0.241271f, -1.179040f, 1.157741f, 0.583952f, -0.717767f, -0.875798f, 1.159575f, 0.005010f, -0.721707f, 0.690536f, -0.249959f, 0.082204f, -0.625120f, -1.016394f, -0.796947f, -0.131764f, -0.868737f, 1.182731f, 0.012988f, -0.459398f, 0.474264f, -1.063883f, -0.613791f, 0.450721f, -1.019595f, 0.598084f, 0.100866f, -1.000569f, -1.190919f, 0.379261f, 0.567202f, -0.239888f, -1.061107f, -0.691616f, 0.127540f, 0.043657f, 0.307172f, 0.212184f, -0.062900f, 0.633272f, 1.164016f, 0.999377f, 1.090411f, -0.405004f, -0.409578f, -0.132722f, 0.354671f, 0.485734f, -0.106963f, -0.775112f, -0.905400f, 1.155262f, -0.322627f, -0.162203f, -0.735432f, -0.594912f, 0.263568f, 0.505424f}; + std::initializer_list Grid_data{TypeParam(-0.034431f), TypeParam(1.048250f), TypeParam(0.160255f), TypeParam(-0.446426f), TypeParam(0.879791f), TypeParam(-0.683555f), TypeParam(0.039704f), TypeParam(0.269729f), TypeParam(0.538601f), TypeParam(-1.107191f), TypeParam(0.058867f), TypeParam(-0.310704f), TypeParam(0.778040f), TypeParam(0.403733f), TypeParam(0.480956f), TypeParam(0.721512f), TypeParam(-0.268657f), TypeParam(-0.076883f), TypeParam(0.962704f), TypeParam(-0.967187f), 
TypeParam(-0.829464f), TypeParam(0.087786f), TypeParam(-0.475353f), TypeParam(0.068725f), TypeParam(1.060032f), TypeParam(-0.139108f), TypeParam(-1.023162f), TypeParam(-0.545493f), TypeParam(1.102040f), TypeParam(-0.263627f), TypeParam(-0.526173f), TypeParam(0.540152f), TypeParam(0.148556f), TypeParam(-1.058015f), TypeParam(0.999344f), TypeParam(0.675750f), TypeParam(1.043022f), TypeParam(0.525119f), TypeParam(-0.404585f), TypeParam(-0.391737f), TypeParam(0.581547f), TypeParam(-0.232625f), TypeParam(0.235264f), TypeParam(-1.162786f), TypeParam(-0.593187f), TypeParam(0.445737f), TypeParam(-0.059159f), TypeParam(-0.576901f), TypeParam(-1.046721f), TypeParam(0.762672f), TypeParam(-0.241271f), TypeParam(-1.179040f), TypeParam(1.157741f), TypeParam(0.583952f), TypeParam(-0.717767f), TypeParam(-0.875798f), TypeParam(1.159575f), TypeParam(0.005010f), TypeParam(-0.721707f), TypeParam(0.690536f), TypeParam(-0.249959f), TypeParam(0.082204f), TypeParam(-0.625120f), TypeParam(-1.016394f), TypeParam(-0.796947f), TypeParam(-0.131764f), TypeParam(-0.868737f), TypeParam(1.182731f), TypeParam(0.012988f), TypeParam(-0.459398f), TypeParam(0.474264f), TypeParam(-1.063883f), TypeParam(-0.613791f), TypeParam(0.450721f), TypeParam(-1.019595f), TypeParam(0.598084f), TypeParam(0.100866f), TypeParam(-1.000569f), TypeParam(-1.190919f), TypeParam(0.379261f), TypeParam(0.567202f), TypeParam(-0.239888f), TypeParam(-1.061107f), TypeParam(-0.691616f), TypeParam(0.127540f), TypeParam(0.043657f), TypeParam(0.307172f), TypeParam(0.212184f), TypeParam(-0.062900f), TypeParam(0.633272f), TypeParam(1.164016f), TypeParam(0.999377f), TypeParam(1.090411f), TypeParam(-0.405004f), TypeParam(-0.409578f), TypeParam(-0.132722f), TypeParam(0.354671f), TypeParam(0.485734f), TypeParam(-0.106963f), TypeParam(-0.775112f), TypeParam(-0.905400f), TypeParam(1.155262f), TypeParam(-0.322627f), TypeParam(-0.162203f), TypeParam(-0.735432f), TypeParam(-0.594912f), TypeParam(0.263568f), TypeParam(0.505424f)}; std::initializer_list Y_shape{2, 2, 3, 3, 2}; - std::initializer_list Y_data{-0.462516f, -1.102514f, -1.314673f, -1.712534f, 0.361739f, 0.361739f, 0.825309f, 0.361739f, 0.787800f, -0.462516f, -0.462516f, -0.524797f, -2.153102f, -0.462516f, 0.825309f, 0.787800f, -0.462516f, -0.524797f, -0.733471f, 1.497654f, -0.450100f, 0.484384f, 0.434399f, 0.434399f, -1.703444f, 0.434399f, 0.505303f, -0.733471f, -0.733471f, 0.772854f, 0.452991f, -0.733471f, -1.703444f, 0.505303f, -0.733471f, 0.772854f, 0.224494f, 0.217863f, -0.437546f, -1.199298f, 1.328065f, -0.437546f, -0.437546f, 0.371986f, -0.925740f, -0.481714f, 0.409271f, 0.344755f, -0.935542f, 1.795444f, 0.409271f, 0.224494f, -0.437546f, -0.925740f, 0.798299f, 0.619547f, -1.174929f, -0.561974f, 0.266913f, -1.174929f, -1.174929f, 1.434105f, -0.324459f, -0.427280f, 1.451708f, 0.255630f, -0.142060f, -1.063838f, 1.451708f, 0.798299f, -1.174929f, -0.324459f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-0.462516f), TypeParam(-1.102514f), TypeParam(-1.314673f), TypeParam(-1.712534f), TypeParam(0.361739f), TypeParam(0.361739f), TypeParam(0.825309f), TypeParam(0.361739f), TypeParam(0.787800f), TypeParam(-0.462516f), TypeParam(-0.462516f), TypeParam(-0.524797f), TypeParam(-2.153102f), TypeParam(-0.462516f), TypeParam(0.825309f), TypeParam(0.787800f), TypeParam(-0.462516f), TypeParam(-0.524797f), TypeParam(-0.733471f), TypeParam(1.497654f), TypeParam(-0.450100f), TypeParam(0.484384f), TypeParam(0.434399f), TypeParam(0.434399f), 
TypeParam(-1.703444f), TypeParam(0.434399f), TypeParam(0.505303f), TypeParam(-0.733471f), TypeParam(-0.733471f), TypeParam(0.772854f), TypeParam(0.452991f), TypeParam(-0.733471f), TypeParam(-1.703444f), TypeParam(0.505303f), TypeParam(-0.733471f), TypeParam(0.772854f), TypeParam(0.224494f), TypeParam(0.217863f), TypeParam(-0.437546f), TypeParam(-1.199298f), TypeParam(1.328065f), TypeParam(-0.437546f), TypeParam(-0.437546f), TypeParam(0.371986f), TypeParam(-0.925740f), TypeParam(-0.481714f), TypeParam(0.409271f), TypeParam(0.344755f), TypeParam(-0.935542f), TypeParam(1.795444f), TypeParam(0.409271f), TypeParam(0.224494f), TypeParam(-0.437546f), TypeParam(-0.925740f), TypeParam(0.798299f), TypeParam(0.619547f), TypeParam(-1.174929f), TypeParam(-0.561974f), TypeParam(0.266913f), TypeParam(-1.174929f), TypeParam(-1.174929f), TypeParam(1.434105f), TypeParam(-0.324459f), TypeParam(-0.427280f), TypeParam(1.451708f), TypeParam(0.255630f), TypeParam(-0.142060f), TypeParam(-1.063838f), TypeParam(1.451708f), TypeParam(0.798299f), TypeParam(-1.174929f), TypeParam(-0.324459f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_4D_nearest_reflection_no_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_nearest_reflection_no_align_corners) { OpTester test("GridSample", 20); std::string mode = "nearest"; std::string padding_mode = "reflection"; int64_t align_corners = 0; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{-0.769854f, -0.805659f, 0.813652f, -0.010183f, 0.276463f, -0.771678f, -2.563015f, -1.243904f, 2.365071f, 0.730651f, -0.068795f, -1.495438f, 0.211578f, -1.042373f, 0.884036f, -0.746288f, 1.011368f, 0.194463f, -0.307214f, 0.556053f, 0.629364f, 0.083601f, 0.248627f, -0.822453f}; + std::initializer_list X_data{TypeParam(-0.769854f), TypeParam(-0.805659f), TypeParam(0.813652f), TypeParam(-0.010183f), TypeParam(0.276463f), TypeParam(-0.771678f), TypeParam(-2.563015f), TypeParam(-1.243904f), TypeParam(2.365071f), TypeParam(0.730651f), TypeParam(-0.068795f), TypeParam(-1.495438f), TypeParam(0.211578f), TypeParam(-1.042373f), TypeParam(0.884036f), TypeParam(-0.746288f), TypeParam(1.011368f), TypeParam(0.194463f), TypeParam(-0.307214f), TypeParam(0.556053f), TypeParam(0.629364f), TypeParam(0.083601f), TypeParam(0.248627f), TypeParam(-0.822453f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{0.569884f, 1.163780f, -0.977608f, -0.145509f, 0.651234f, 1.099753f, -0.853766f, 0.509955f, 0.495437f, 0.723445f, -0.827299f, 0.856340f, -0.522676f, -0.738659f, 0.238269f, 1.016568f, -0.794666f, 0.640690f, -0.137431f, 0.383085f, 0.936085f, 0.325824f, -0.996188f, -0.361291f}; + std::initializer_list Grid_data{TypeParam(0.569884f), TypeParam(1.163780f), TypeParam(-0.977608f), TypeParam(-0.145509f), TypeParam(0.651234f), TypeParam(1.099753f), TypeParam(-0.853766f), TypeParam(0.509955f), TypeParam(0.495437f), TypeParam(0.723445f), TypeParam(-0.827299f), TypeParam(0.856340f), TypeParam(-0.522676f), TypeParam(-0.738659f), TypeParam(0.238269f), TypeParam(1.016568f), TypeParam(-0.794666f), TypeParam(0.640690f), TypeParam(-0.137431f), TypeParam(0.383085f), TypeParam(0.936085f), TypeParam(0.325824f), 
TypeParam(-0.996188f), TypeParam(-0.361291f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{-0.771678f, 0.813652f, -0.771678f, 0.276463f, -0.771678f, 0.276463f, -1.495438f, 2.365071f, -1.495438f, -0.068795f, -1.495438f, -0.068795f, 0.211578f, 0.194463f, 1.011368f, 1.011368f, -0.746288f, 0.211578f, -0.307214f, -0.822453f, 0.248627f, 0.248627f, 0.083601f, -0.307214f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-0.771678f), TypeParam(0.813652f), TypeParam(-0.771678f), TypeParam(0.276463f), TypeParam(-0.771678f), TypeParam(0.276463f), TypeParam(-1.495438f), TypeParam(2.365071f), TypeParam(-1.495438f), TypeParam(-0.068795f), TypeParam(-1.495438f), TypeParam(-0.068795f), TypeParam(0.211578f), TypeParam(0.194463f), TypeParam(1.011368f), TypeParam(1.011368f), TypeParam(-0.746288f), TypeParam(0.211578f), TypeParam(-0.307214f), TypeParam(-0.822453f), TypeParam(0.248627f), TypeParam(0.248627f), TypeParam(0.083601f), TypeParam(-0.307214f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_5D_nearest_reflection_no_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_nearest_reflection_no_align_corners) { OpTester test("GridSample", 20); std::string mode = "nearest"; std::string padding_mode = "reflection"; int64_t align_corners = 0; std::initializer_list X_shape{2, 2, 3, 3, 2}; - std::initializer_list X_data{-0.185898f, 0.403325f, 0.737314f, 0.545995f, -1.010481f, -1.204522f, -0.147342f, 0.232425f, -1.339485f, 0.013892f, -1.098319f, 0.478079f, 0.051159f, -0.906061f, -0.428560f, 0.583460f, 1.137472f, 1.487881f, 1.349931f, -0.118774f, 0.436410f, 1.334689f, -1.115846f, 0.159820f, 0.617671f, 0.546630f, 1.861115f, 0.500044f, 0.623446f, 0.541840f, -0.279259f, -0.573875f, 0.783115f, -1.125017f, -1.166457f, -0.827232f, 0.273074f, 0.702953f, 1.288608f, -1.037043f, 0.021860f, 0.575628f, -0.034170f, 1.400741f, 0.508057f, 0.994702f, -2.267981f, 1.677437f, 0.175134f, 0.712679f, -0.440408f, -1.248550f, 1.618839f, -0.214598f, 0.486398f, -0.478466f, 0.912471f, 0.499651f, -0.886606f, -0.929524f, 0.449260f, 0.017969f, -0.050906f, 1.799695f, -0.033007f, -1.884108f, -1.392415f, -0.852990f, -0.052969f, 0.819434f, 0.089723f, 0.598047f}; + std::initializer_list X_data{TypeParam(-0.185898f), TypeParam(0.403325f), TypeParam(0.737314f), TypeParam(0.545995f), TypeParam(-1.010481f), TypeParam(-1.204522f), TypeParam(-0.147342f), TypeParam(0.232425f), TypeParam(-1.339485f), TypeParam(0.013892f), TypeParam(-1.098319f), TypeParam(0.478079f), TypeParam(0.051159f), TypeParam(-0.906061f), TypeParam(-0.428560f), TypeParam(0.583460f), TypeParam(1.137472f), TypeParam(1.487881f), TypeParam(1.349931f), TypeParam(-0.118774f), TypeParam(0.436410f), TypeParam(1.334689f), TypeParam(-1.115846f), TypeParam(0.159820f), TypeParam(0.617671f), TypeParam(0.546630f), TypeParam(1.861115f), TypeParam(0.500044f), TypeParam(0.623446f), TypeParam(0.541840f), TypeParam(-0.279259f), TypeParam(-0.573875f), TypeParam(0.783115f), TypeParam(-1.125017f), TypeParam(-1.166457f), TypeParam(-0.827232f), TypeParam(0.273074f), TypeParam(0.702953f), TypeParam(1.288608f), TypeParam(-1.037043f), 
TypeParam(0.021860f), TypeParam(0.575628f), TypeParam(-0.034170f), TypeParam(1.400741f), TypeParam(0.508057f), TypeParam(0.994702f), TypeParam(-2.267981f), TypeParam(1.677437f), TypeParam(0.175134f), TypeParam(0.712679f), TypeParam(-0.440408f), TypeParam(-1.248550f), TypeParam(1.618839f), TypeParam(-0.214598f), TypeParam(0.486398f), TypeParam(-0.478466f), TypeParam(0.912471f), TypeParam(0.499651f), TypeParam(-0.886606f), TypeParam(-0.929524f), TypeParam(0.449260f), TypeParam(0.017969f), TypeParam(-0.050906f), TypeParam(1.799695f), TypeParam(-0.033007f), TypeParam(-1.884108f), TypeParam(-1.392415f), TypeParam(-0.852990f), TypeParam(-0.052969f), TypeParam(0.819434f), TypeParam(0.089723f), TypeParam(0.598047f)}; std::initializer_list Grid_shape{2, 3, 3, 2, 3}; - std::initializer_list Grid_data{-0.118828f, 0.082315f, 0.328488f, -0.834821f, -0.138863f, -0.988801f, -0.976128f, 0.156412f, -1.171383f, 0.319534f, -1.105438f, -0.834991f, -0.248995f, -1.145138f, 0.969159f, 0.983228f, -0.626795f, 0.251376f, 0.613890f, 0.381328f, -0.160747f, -1.131853f, 0.872567f, -1.052516f, -0.222240f, 0.074438f, -0.395210f, -0.438906f, -1.037125f, 0.066119f, -0.136254f, 1.046163f, -0.395065f, 0.927498f, 0.056808f, -0.539139f, -0.285382f, -0.136177f, 0.012430f, -0.197703f, 0.356128f, 0.988219f, 0.188620f, 0.434655f, 0.741024f, 0.258662f, 0.553165f, 0.629461f, 1.123216f, -1.095185f, 0.410630f, -0.054374f, -0.215508f, -0.462650f, 0.721441f, 1.097745f, -0.979308f, 0.648336f, 0.827460f, 0.209729f, 0.014136f, 0.923431f, 0.035578f, -0.299309f, -0.088614f, 0.385002f, 0.300407f, -0.064744f, 0.378800f, 0.323185f, -0.972071f, 0.299012f, 0.734213f, 0.137618f, -0.109532f, 0.919238f, -1.048417f, -0.547724f, -0.542389f, 1.036863f, -1.160666f, 0.119013f, -1.162427f, -0.039461f, 0.447285f, -0.280625f, 1.164882f, 0.003820f, -0.611796f, 0.309439f, 0.624077f, -0.002384f, 1.026569f, -0.759499f, 0.512014f, 0.681403f, 0.596030f, -0.000440f, 0.342557f, -0.941414f, -0.941707f, -0.074588f, -0.150400f, 0.891031f, 0.871352f, 0.813657f, -0.549640f, -0.942044f}; + std::initializer_list Grid_data{TypeParam(-0.118828f), TypeParam(0.082315f), TypeParam(0.328488f), TypeParam(-0.834821f), TypeParam(-0.138863f), TypeParam(-0.988801f), TypeParam(-0.976128f), TypeParam(0.156412f), TypeParam(-1.171383f), TypeParam(0.319534f), TypeParam(-1.105438f), TypeParam(-0.834991f), TypeParam(-0.248995f), TypeParam(-1.145138f), TypeParam(0.969159f), TypeParam(0.983228f), TypeParam(-0.626795f), TypeParam(0.251376f), TypeParam(0.613890f), TypeParam(0.381328f), TypeParam(-0.160747f), TypeParam(-1.131853f), TypeParam(0.872567f), TypeParam(-1.052516f), TypeParam(-0.222240f), TypeParam(0.074438f), TypeParam(-0.395210f), TypeParam(-0.438906f), TypeParam(-1.037125f), TypeParam(0.066119f), TypeParam(-0.136254f), TypeParam(1.046163f), TypeParam(-0.395065f), TypeParam(0.927498f), TypeParam(0.056808f), TypeParam(-0.539139f), TypeParam(-0.285382f), TypeParam(-0.136177f), TypeParam(0.012430f), TypeParam(-0.197703f), TypeParam(0.356128f), TypeParam(0.988219f), TypeParam(0.188620f), TypeParam(0.434655f), TypeParam(0.741024f), TypeParam(0.258662f), TypeParam(0.553165f), TypeParam(0.629461f), TypeParam(1.123216f), TypeParam(-1.095185f), TypeParam(0.410630f), TypeParam(-0.054374f), TypeParam(-0.215508f), TypeParam(-0.462650f), TypeParam(0.721441f), TypeParam(1.097745f), TypeParam(-0.979308f), TypeParam(0.648336f), TypeParam(0.827460f), TypeParam(0.209729f), TypeParam(0.014136f), TypeParam(0.923431f), TypeParam(0.035578f), TypeParam(-0.299309f), TypeParam(-0.088614f), 
TypeParam(0.385002f), TypeParam(0.300407f), TypeParam(-0.064744f), TypeParam(0.378800f), TypeParam(0.323185f), TypeParam(-0.972071f), TypeParam(0.299012f), TypeParam(0.734213f), TypeParam(0.137618f), TypeParam(-0.109532f), TypeParam(0.919238f), TypeParam(-1.048417f), TypeParam(-0.547724f), TypeParam(-0.542389f), TypeParam(1.036863f), TypeParam(-1.160666f), TypeParam(0.119013f), TypeParam(-1.162427f), TypeParam(-0.039461f), TypeParam(0.447285f), TypeParam(-0.280625f), TypeParam(1.164882f), TypeParam(0.003820f), TypeParam(-0.611796f), TypeParam(0.309439f), TypeParam(0.624077f), TypeParam(-0.002384f), TypeParam(1.026569f), TypeParam(-0.759499f), TypeParam(0.512014f), TypeParam(0.681403f), TypeParam(0.596030f), TypeParam(-0.000440f), TypeParam(0.342557f), TypeParam(-0.941414f), TypeParam(-0.941707f), TypeParam(-0.074588f), TypeParam(-0.150400f), TypeParam(0.891031f), TypeParam(0.871352f), TypeParam(0.813657f), TypeParam(-0.549640f), TypeParam(-0.942044f)}; std::initializer_list Y_shape{2, 2, 3, 3, 2}; - std::initializer_list Y_data{-1.339485f, 0.737314f, 0.737314f, 0.403325f, 0.051159f, 0.232425f, 0.478079f, -1.010481f, 0.737314f, -0.147342f, -1.010481f, 0.545995f, -1.339485f, 1.137472f, 1.487881f, 1.487881f, -0.906061f, 0.737314f, 1.861115f, 0.436410f, 0.436410f, -0.118774f, -0.279259f, 0.546630f, 0.541840f, -1.115846f, 0.436410f, 0.617671f, -1.115846f, 1.334689f, 1.861115f, -1.166457f, -0.827232f, -0.827232f, -0.573875f, 0.436410f, 0.575628f, 1.677437f, 1.677437f, -0.440408f, -1.248550f, 1.400741f, 0.994702f, 0.702953f, 0.021860f, 1.400741f, -1.248550f, 1.400741f, -1.248550f, 1.618839f, -1.248550f, -0.034170f, 1.618839f, 0.702953f, -0.929524f, -1.884108f, -1.884108f, -0.052969f, 0.819434f, 0.017969f, 1.799695f, -0.478466f, -0.886606f, 0.017969f, 0.819434f, 0.017969f, 0.819434f, 0.089723f, 0.819434f, 0.449260f, 0.089723f, -0.478466f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-1.339485f), TypeParam(0.737314f), TypeParam(0.737314f), TypeParam(0.403325f), TypeParam(0.051159f), TypeParam(0.232425f), TypeParam(0.478079f), TypeParam(-1.010481f), TypeParam(0.737314f), TypeParam(-0.147342f), TypeParam(-1.010481f), TypeParam(0.545995f), TypeParam(-1.339485f), TypeParam(1.137472f), TypeParam(1.487881f), TypeParam(1.487881f), TypeParam(-0.906061f), TypeParam(0.737314f), TypeParam(1.861115f), TypeParam(0.436410f), TypeParam(0.436410f), TypeParam(-0.118774f), TypeParam(-0.279259f), TypeParam(0.546630f), TypeParam(0.541840f), TypeParam(-1.115846f), TypeParam(0.436410f), TypeParam(0.617671f), TypeParam(-1.115846f), TypeParam(1.334689f), TypeParam(1.861115f), TypeParam(-1.166457f), TypeParam(-0.827232f), TypeParam(-0.827232f), TypeParam(-0.573875f), TypeParam(0.436410f), TypeParam(0.575628f), TypeParam(1.677437f), TypeParam(1.677437f), TypeParam(-0.440408f), TypeParam(-1.248550f), TypeParam(1.400741f), TypeParam(0.994702f), TypeParam(0.702953f), TypeParam(0.021860f), TypeParam(1.400741f), TypeParam(-1.248550f), TypeParam(1.400741f), TypeParam(-1.248550f), TypeParam(1.618839f), TypeParam(-1.248550f), TypeParam(-0.034170f), TypeParam(1.618839f), TypeParam(0.702953f), TypeParam(-0.929524f), TypeParam(-1.884108f), TypeParam(-1.884108f), TypeParam(-0.052969f), TypeParam(0.819434f), TypeParam(0.017969f), TypeParam(1.799695f), TypeParam(-0.478466f), TypeParam(-0.886606f), TypeParam(0.017969f), TypeParam(0.819434f), TypeParam(0.017969f), TypeParam(0.819434f), TypeParam(0.089723f), TypeParam(0.819434f), TypeParam(0.449260f), 
TypeParam(0.089723f), TypeParam(-0.478466f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_zeros_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_zeros_align_corners) { OpTester test("GridSample", 20); std::string mode = "linear"; std::string padding_mode = "zeros"; int64_t align_corners = 1; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{0.010274f, 1.493496f, -0.264303f, 0.035897f, -0.751962f, -0.370195f, -0.514836f, 0.399928f, -0.191651f, -0.239505f, -1.931184f, -1.074773f, -0.121908f, 0.050673f, -0.741501f, -0.229127f, -0.360925f, 0.264077f, 1.537180f, 1.603202f, -1.241810f, -0.388456f, -0.609742f, 0.095097f}; + std::initializer_list X_data{TypeParam(0.010274f), TypeParam(1.493496f), TypeParam(-0.264303f), TypeParam(0.035897f), TypeParam(-0.751962f), TypeParam(-0.370195f), TypeParam(-0.514836f), TypeParam(0.399928f), TypeParam(-0.191651f), TypeParam(-0.239505f), TypeParam(-1.931184f), TypeParam(-1.074773f), TypeParam(-0.121908f), TypeParam(0.050673f), TypeParam(-0.741501f), TypeParam(-0.229127f), TypeParam(-0.360925f), TypeParam(0.264077f), TypeParam(1.537180f), TypeParam(1.603202f), TypeParam(-1.241810f), TypeParam(-0.388456f), TypeParam(-0.609742f), TypeParam(0.095097f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{-0.118589f, -0.020968f, -0.893597f, 1.170924f, -0.517539f, 0.698168f, -0.672718f, 0.008056f, 0.410793f, -1.101817f, 0.550440f, -0.918534f, 0.167456f, -0.237959f, 0.687868f, 1.166281f, 0.270439f, -0.034265f, -0.594534f, 0.447403f, -0.577587f, 0.495680f, -0.520113f, 0.813977f}; + std::initializer_list Grid_data{TypeParam(-0.118589f), TypeParam(-0.020968f), TypeParam(-0.893597f), TypeParam(1.170924f), TypeParam(-0.517539f), TypeParam(0.698168f), TypeParam(-0.672718f), TypeParam(0.008056f), TypeParam(0.410793f), TypeParam(-1.101817f), TypeParam(0.550440f), TypeParam(-0.918534f), TypeParam(0.167456f), TypeParam(-0.237959f), TypeParam(0.687868f), TypeParam(1.166281f), TypeParam(0.270439f), TypeParam(-0.034265f), TypeParam(-0.594534f), TypeParam(0.447403f), TypeParam(-0.577587f), TypeParam(0.495680f), TypeParam(-0.520113f), TypeParam(0.813977f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{-0.115313f, -0.606595f, -0.518616f, -0.218999f, 0.948961f, 1.063015f, -0.210622f, -1.563324f, -1.265386f, -0.212304f, 0.117155f, 0.159843f, -0.342175f, 0.138844f, -0.402196f, -0.457139f, -0.432849f, -0.286783f, -0.191760f, -0.012426f, -0.621658f, -0.799488f, -0.763820f, -0.551571f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-0.115313f), TypeParam(-0.606595f), TypeParam(-0.518616f), TypeParam(-0.218999f), TypeParam(0.948961f), TypeParam(1.063015f), TypeParam(-0.210622f), TypeParam(-1.563324f), TypeParam(-1.265386f), TypeParam(-0.212304f), TypeParam(0.117155f), TypeParam(0.159843f), TypeParam(-0.342175f), TypeParam(0.138844f), TypeParam(-0.402196f), TypeParam(-0.457139f), TypeParam(-0.432849f), TypeParam(-0.286783f), TypeParam(-0.191760f), TypeParam(-0.012426f), TypeParam(-0.621658f), TypeParam(-0.799488f), TypeParam(-0.763820f), 
TypeParam(-0.551571f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_5D_bilinear_zeros_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_bilinear_zeros_align_corners) { OpTester test("GridSample", 20); std::string mode = "linear"; std::string padding_mode = "zeros"; int64_t align_corners = 1; std::initializer_list X_shape{2, 2, 3, 3, 2}; - std::initializer_list X_data{-1.787070f, -0.894227f, -0.113069f, 0.713917f, 0.041566f, -1.847208f, 0.013441f, -1.439041f, 1.051864f, 1.576791f, 1.180527f, -1.457019f, 0.298446f, 1.142738f, -0.961347f, -0.471509f, -0.074154f, 0.047739f, -0.679950f, -2.306940f, -0.552171f, -0.357144f, -0.492247f, -0.455872f, 0.399680f, 0.057915f, -0.362704f, 1.083763f, -0.084941f, -1.691393f, -1.913178f, 0.696366f, 1.172833f, 0.901506f, -1.189840f, -1.197158f, 0.007338f, 0.161468f, -1.048452f, -0.480832f, 0.391235f, 1.056413f, -0.116648f, 0.632195f, 0.840261f, -2.187738f, 0.302910f, -0.956190f, -0.362645f, 0.771747f, 0.524840f, -0.954672f, -1.084612f, -0.525794f, -0.969691f, -1.056405f, -0.364709f, 0.336189f, -0.178281f, 1.015025f, -0.532580f, 0.036602f, -0.434395f, -1.208987f, -1.084039f, 0.642844f, -0.819208f, -0.982898f, -0.109210f, -1.231957f, 1.083089f, -0.870451f}; + std::initializer_list X_data{TypeParam(-1.787070f), TypeParam(-0.894227f), TypeParam(-0.113069f), TypeParam(0.713917f), TypeParam(0.041566f), TypeParam(-1.847208f), TypeParam(0.013441f), TypeParam(-1.439041f), TypeParam(1.051864f), TypeParam(1.576791f), TypeParam(1.180527f), TypeParam(-1.457019f), TypeParam(0.298446f), TypeParam(1.142738f), TypeParam(-0.961347f), TypeParam(-0.471509f), TypeParam(-0.074154f), TypeParam(0.047739f), TypeParam(-0.679950f), TypeParam(-2.306940f), TypeParam(-0.552171f), TypeParam(-0.357144f), TypeParam(-0.492247f), TypeParam(-0.455872f), TypeParam(0.399680f), TypeParam(0.057915f), TypeParam(-0.362704f), TypeParam(1.083763f), TypeParam(-0.084941f), TypeParam(-1.691393f), TypeParam(-1.913178f), TypeParam(0.696366f), TypeParam(1.172833f), TypeParam(0.901506f), TypeParam(-1.189840f), TypeParam(-1.197158f), TypeParam(0.007338f), TypeParam(0.161468f), TypeParam(-1.048452f), TypeParam(-0.480832f), TypeParam(0.391235f), TypeParam(1.056413f), TypeParam(-0.116648f), TypeParam(0.632195f), TypeParam(0.840261f), TypeParam(-2.187738f), TypeParam(0.302910f), TypeParam(-0.956190f), TypeParam(-0.362645f), TypeParam(0.771747f), TypeParam(0.524840f), TypeParam(-0.954672f), TypeParam(-1.084612f), TypeParam(-0.525794f), TypeParam(-0.969691f), TypeParam(-1.056405f), TypeParam(-0.364709f), TypeParam(0.336189f), TypeParam(-0.178281f), TypeParam(1.015025f), TypeParam(-0.532580f), TypeParam(0.036602f), TypeParam(-0.434395f), TypeParam(-1.208987f), TypeParam(-1.084039f), TypeParam(0.642844f), TypeParam(-0.819208f), TypeParam(-0.982898f), TypeParam(-0.109210f), TypeParam(-1.231957f), TypeParam(1.083089f), TypeParam(-0.870451f)}; std::initializer_list Grid_shape{2, 3, 3, 2, 3}; - std::initializer_list Grid_data{0.350638f, -0.554259f, 0.740901f, -1.134597f, -0.450763f, -0.706065f, -0.712365f, -0.727142f, -1.130749f, 0.205940f, -0.237380f, -1.010413f, -0.000494f, -0.199898f, 0.495032f, -0.939943f, -0.337590f, 0.247001f, 0.508664f, 0.090780f, 0.325198f, 
1.199561f, -0.415694f, 0.817854f, 1.033666f, -1.061540f, 0.290273f, 0.679739f, -0.187185f, 0.662278f, 0.040817f, 0.913540f, 0.025838f, -0.768267f, 0.911326f, 0.356885f, 1.020923f, 0.297892f, 0.637209f, 0.748214f, 0.202064f, -0.278959f, 0.247841f, -0.836700f, 0.040996f, -0.385697f, 0.075869f, -0.950110f, 0.733227f, -1.107135f, 0.513890f, 0.790272f, -1.099795f, 1.084212f, -0.892061f, -0.235640f, 0.621837f, -0.380523f, 1.069422f, -0.529383f, -0.160661f, -0.784422f, -0.556715f, 1.171015f, 0.902476f, 0.088357f, 0.098667f, -1.018314f, 0.905937f, -0.179914f, -0.500513f, -0.954987f, 0.986618f, 0.569025f, 0.722795f, 0.124254f, -0.814285f, 0.491561f, 0.138395f, 0.402690f, -0.298810f, -0.566298f, 0.985118f, 0.402260f, -0.487031f, 0.107159f, -0.260850f, -0.102620f, 0.672911f, -0.955102f, 1.086040f, 0.807667f, 0.001031f, -0.490841f, 0.244670f, -0.794290f, 0.779461f, -0.634633f, 0.229290f, -1.180597f, 0.574650f, 0.812338f, 0.900697f, 0.097950f, 0.708525f, 0.409153f, 0.804739f, 0.677169f}; + std::initializer_list Grid_data{TypeParam(0.350638f), TypeParam(-0.554259f), TypeParam(0.740901f), TypeParam(-1.134597f), TypeParam(-0.450763f), TypeParam(-0.706065f), TypeParam(-0.712365f), TypeParam(-0.727142f), TypeParam(-1.130749f), TypeParam(0.205940f), TypeParam(-0.237380f), TypeParam(-1.010413f), TypeParam(-0.000494f), TypeParam(-0.199898f), TypeParam(0.495032f), TypeParam(-0.939943f), TypeParam(-0.337590f), TypeParam(0.247001f), TypeParam(0.508664f), TypeParam(0.090780f), TypeParam(0.325198f), TypeParam(1.199561f), TypeParam(-0.415694f), TypeParam(0.817854f), TypeParam(1.033666f), TypeParam(-1.061540f), TypeParam(0.290273f), TypeParam(0.679739f), TypeParam(-0.187185f), TypeParam(0.662278f), TypeParam(0.040817f), TypeParam(0.913540f), TypeParam(0.025838f), TypeParam(-0.768267f), TypeParam(0.911326f), TypeParam(0.356885f), TypeParam(1.020923f), TypeParam(0.297892f), TypeParam(0.637209f), TypeParam(0.748214f), TypeParam(0.202064f), TypeParam(-0.278959f), TypeParam(0.247841f), TypeParam(-0.836700f), TypeParam(0.040996f), TypeParam(-0.385697f), TypeParam(0.075869f), TypeParam(-0.950110f), TypeParam(0.733227f), TypeParam(-1.107135f), TypeParam(0.513890f), TypeParam(0.790272f), TypeParam(-1.099795f), TypeParam(1.084212f), TypeParam(-0.892061f), TypeParam(-0.235640f), TypeParam(0.621837f), TypeParam(-0.380523f), TypeParam(1.069422f), TypeParam(-0.529383f), TypeParam(-0.160661f), TypeParam(-0.784422f), TypeParam(-0.556715f), TypeParam(1.171015f), TypeParam(0.902476f), TypeParam(0.088357f), TypeParam(0.098667f), TypeParam(-1.018314f), TypeParam(0.905937f), TypeParam(-0.179914f), TypeParam(-0.500513f), TypeParam(-0.954987f), TypeParam(0.986618f), TypeParam(0.569025f), TypeParam(0.722795f), TypeParam(0.124254f), TypeParam(-0.814285f), TypeParam(0.491561f), TypeParam(0.138395f), TypeParam(0.402690f), TypeParam(-0.298810f), TypeParam(-0.566298f), TypeParam(0.985118f), TypeParam(0.402260f), TypeParam(-0.487031f), TypeParam(0.107159f), TypeParam(-0.260850f), TypeParam(-0.102620f), TypeParam(0.672911f), TypeParam(-0.955102f), TypeParam(1.086040f), TypeParam(0.807667f), TypeParam(0.001031f), TypeParam(-0.490841f), TypeParam(0.244670f), TypeParam(-0.794290f), TypeParam(0.779461f), TypeParam(-0.634633f), TypeParam(0.229290f), TypeParam(-1.180597f), TypeParam(0.574650f), TypeParam(0.812338f), TypeParam(0.900697f), TypeParam(0.097950f), TypeParam(0.708525f), TypeParam(0.409153f), TypeParam(0.804739f), TypeParam(0.677169f)}; std::initializer_list Y_shape{2, 2, 3, 3, 2}; - std::initializer_list Y_data{0.171946f, -0.411342f, 
-1.046998f, -0.002345f, 0.246533f, 0.396970f, 0.664278f, 0.199883f, -0.636287f, 0.162358f, -0.061161f, 0.528084f, 0.041846f, 0.750291f, -0.476442f, 0.142258f, -0.067844f, 0.869081f, 0.360025f, -0.406785f, -0.701985f, -0.718142f, 0.519179f, -0.022693f, 0.618451f, 0.708731f, 0.224429f, 0.784241f, -0.812606f, -0.521137f, 0.266524f, 0.190886f, 0.231077f, -0.465330f, 0.204730f, 0.348489f, 0.356190f, 0.256096f, -0.038212f, -0.943162f, 0.258902f, -0.360112f, -0.920536f, 0.126677f, -0.523600f, -0.361337f, -0.154168f, 0.179761f, -1.141155f, -0.423488f, -0.225410f, -0.204886f, -1.162816f, -0.678226f, -0.384409f, -0.146245f, -0.622531f, 0.312188f, -0.828836f, -0.541017f, -0.778291f, -0.602484f, -0.328754f, -0.163964f, -0.508068f, 0.193021f, 0.273133f, -0.217934f, -0.562420f, 0.287725f, -1.097279f, -0.306201f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(0.171946f), TypeParam(-0.411342f), TypeParam(-1.046998f), TypeParam(-0.002345f), TypeParam(0.246533f), TypeParam(0.396970f), TypeParam(0.664278f), TypeParam(0.199883f), TypeParam(-0.636287f), TypeParam(0.162358f), TypeParam(-0.061161f), TypeParam(0.528084f), TypeParam(0.041846f), TypeParam(0.750291f), TypeParam(-0.476442f), TypeParam(0.142258f), TypeParam(-0.067844f), TypeParam(0.869081f), TypeParam(0.360025f), TypeParam(-0.406785f), TypeParam(-0.701985f), TypeParam(-0.718142f), TypeParam(0.519179f), TypeParam(-0.022693f), TypeParam(0.618451f), TypeParam(0.708731f), TypeParam(0.224429f), TypeParam(0.784241f), TypeParam(-0.812606f), TypeParam(-0.521137f), TypeParam(0.266524f), TypeParam(0.190886f), TypeParam(0.231077f), TypeParam(-0.465330f), TypeParam(0.204730f), TypeParam(0.348489f), TypeParam(0.356190f), TypeParam(0.256096f), TypeParam(-0.038212f), TypeParam(-0.943162f), TypeParam(0.258902f), TypeParam(-0.360112f), TypeParam(-0.920536f), TypeParam(0.126677f), TypeParam(-0.523600f), TypeParam(-0.361337f), TypeParam(-0.154168f), TypeParam(0.179761f), TypeParam(-1.141155f), TypeParam(-0.423488f), TypeParam(-0.225410f), TypeParam(-0.204886f), TypeParam(-1.162816f), TypeParam(-0.678226f), TypeParam(-0.384409f), TypeParam(-0.146245f), TypeParam(-0.622531f), TypeParam(0.312188f), TypeParam(-0.828836f), TypeParam(-0.541017f), TypeParam(-0.778291f), TypeParam(-0.602484f), TypeParam(-0.328754f), TypeParam(-0.163964f), TypeParam(-0.508068f), TypeParam(0.193021f), TypeParam(0.273133f), TypeParam(-0.217934f), TypeParam(-0.562420f), TypeParam(0.287725f), TypeParam(-1.097279f), TypeParam(-0.306201f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_zeros_no_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_zeros_no_align_corners) { OpTester test("GridSample", 20); std::string mode = "linear"; std::string padding_mode = "zeros"; int64_t align_corners = 0; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{0.185965f, 0.133937f, -0.763030f, 0.733342f, 1.932445f, -0.582571f, -1.312078f, 0.738952f, 0.444459f, 0.742593f, -0.805960f, -0.202535f, 0.970323f, -0.801176f, 0.277655f, -1.938051f, -1.879800f, 0.287116f, 0.261958f, -0.358247f, -0.107750f, 0.748162f, -0.742330f, 0.344665f}; + 
std::initializer_list X_data{TypeParam(0.185965f), TypeParam(0.133937f), TypeParam(-0.763030f), TypeParam(0.733342f), TypeParam(1.932445f), TypeParam(-0.582571f), TypeParam(-1.312078f), TypeParam(0.738952f), TypeParam(0.444459f), TypeParam(0.742593f), TypeParam(-0.805960f), TypeParam(-0.202535f), TypeParam(0.970323f), TypeParam(-0.801176f), TypeParam(0.277655f), TypeParam(-1.938051f), TypeParam(-1.879800f), TypeParam(0.287116f), TypeParam(0.261958f), TypeParam(-0.358247f), TypeParam(-0.107750f), TypeParam(0.748162f), TypeParam(-0.742330f), TypeParam(0.344665f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{-0.460252f, 0.734353f, -1.069308f, 1.005361f, 1.198595f, -0.327629f, 0.474026f, 1.196645f, 0.361782f, 0.469280f, 0.440632f, -0.490951f, 0.292918f, -0.639568f, 1.024697f, -0.514217f, 0.274326f, -0.347614f, 0.600117f, 0.019780f, 0.659824f, -0.324940f, -0.704174f, 0.460072f}; + std::initializer_list Grid_data{TypeParam(-0.460252f), TypeParam(0.734353f), TypeParam(-1.069308f), TypeParam(1.005361f), TypeParam(1.198595f), TypeParam(-0.327629f), TypeParam(0.474026f), TypeParam(1.196645f), TypeParam(0.361782f), TypeParam(0.469280f), TypeParam(0.440632f), TypeParam(-0.490951f), TypeParam(0.292918f), TypeParam(-0.639568f), TypeParam(1.024697f), TypeParam(-0.514217f), TypeParam(0.274326f), TypeParam(-0.347614f), TypeParam(0.600117f), TypeParam(0.019780f), TypeParam(0.659824f), TypeParam(-0.324940f), TypeParam(-0.704174f), TypeParam(0.460072f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{1.646426f, 0.409452f, 0.132247f, -0.106052f, -0.009495f, 0.270785f, -0.702581f, -0.170769f, 0.223282f, -0.044740f, 0.006388f, 0.645576f, -0.476802f, -0.504368f, -0.897503f, -1.684608f, -1.162742f, -0.963921f, -0.197266f, -0.050021f, 0.151796f, 0.662485f, 0.175502f, -0.434265f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(1.646426f), TypeParam(0.409452f), TypeParam(0.132247f), TypeParam(-0.106052f), TypeParam(-0.009495f), TypeParam(0.270785f), TypeParam(-0.702581f), TypeParam(-0.170769f), TypeParam(0.223282f), TypeParam(-0.044740f), TypeParam(0.006388f), TypeParam(0.645576f), TypeParam(-0.476802f), TypeParam(-0.504368f), TypeParam(-0.897503f), TypeParam(-1.684608f), TypeParam(-1.162742f), TypeParam(-0.963921f), TypeParam(-0.197266f), TypeParam(-0.050021f), TypeParam(0.151796f), TypeParam(0.662485f), TypeParam(0.175502f), TypeParam(-0.434265f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_5D_bilinear_zeros_no_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_bilinear_zeros_no_align_corners) { OpTester test("GridSample", 20); std::string mode = "linear"; std::string padding_mode = "zeros"; int64_t align_corners = 0; std::initializer_list X_shape{2, 2, 3, 3, 2}; - std::initializer_list X_data{-0.299262f, -0.304887f, 0.906636f, -0.392850f, -0.050410f, 0.548199f, -1.235108f, -0.475848f, 0.635455f, 0.307462f, -1.241370f, -0.538672f, 0.863466f, 0.799983f, -0.090064f, -0.751721f, 0.956040f, -0.117709f, -2.183699f, -0.484444f, 1.105900f, 0.164466f, 0.720736f, 0.168044f, -0.656400f, 1.770106f, -0.544832f, 1.358424f, 
0.981648f, -1.759268f, -0.526924f, 1.322339f, 0.148774f, 0.321413f, -1.257438f, -0.383775f, -2.117908f, -0.077921f, -0.197889f, 0.555813f, -1.517724f, 1.419652f, -0.891774f, 1.684663f, -1.524669f, -2.055758f, -0.299843f, -0.644860f, 0.428609f, -1.704372f, 1.257671f, -0.886508f, -0.029344f, -1.718824f, -0.294273f, 1.537690f, -1.366837f, -1.610098f, 0.650240f, -0.288219f, 0.837292f, 0.431683f, -0.405852f, 0.492271f, 0.416507f, 0.971658f, -0.183526f, 0.615709f, -0.081615f, 1.160796f, 1.431487f, 0.485687f}; + std::initializer_list X_data{TypeParam(-0.299262f), TypeParam(-0.304887f), TypeParam(0.906636f), TypeParam(-0.392850f), TypeParam(-0.050410f), TypeParam(0.548199f), TypeParam(-1.235108f), TypeParam(-0.475848f), TypeParam(0.635455f), TypeParam(0.307462f), TypeParam(-1.241370f), TypeParam(-0.538672f), TypeParam(0.863466f), TypeParam(0.799983f), TypeParam(-0.090064f), TypeParam(-0.751721f), TypeParam(0.956040f), TypeParam(-0.117709f), TypeParam(-2.183699f), TypeParam(-0.484444f), TypeParam(1.105900f), TypeParam(0.164466f), TypeParam(0.720736f), TypeParam(0.168044f), TypeParam(-0.656400f), TypeParam(1.770106f), TypeParam(-0.544832f), TypeParam(1.358424f), TypeParam(0.981648f), TypeParam(-1.759268f), TypeParam(-0.526924f), TypeParam(1.322339f), TypeParam(0.148774f), TypeParam(0.321413f), TypeParam(-1.257438f), TypeParam(-0.383775f), TypeParam(-2.117908f), TypeParam(-0.077921f), TypeParam(-0.197889f), TypeParam(0.555813f), TypeParam(-1.517724f), TypeParam(1.419652f), TypeParam(-0.891774f), TypeParam(1.684663f), TypeParam(-1.524669f), TypeParam(-2.055758f), TypeParam(-0.299843f), TypeParam(-0.644860f), TypeParam(0.428609f), TypeParam(-1.704372f), TypeParam(1.257671f), TypeParam(-0.886508f), TypeParam(-0.029344f), TypeParam(-1.718824f), TypeParam(-0.294273f), TypeParam(1.537690f), TypeParam(-1.366837f), TypeParam(-1.610098f), TypeParam(0.650240f), TypeParam(-0.288219f), TypeParam(0.837292f), TypeParam(0.431683f), TypeParam(-0.405852f), TypeParam(0.492271f), TypeParam(0.416507f), TypeParam(0.971658f), TypeParam(-0.183526f), TypeParam(0.615709f), TypeParam(-0.081615f), TypeParam(1.160796f), TypeParam(1.431487f), TypeParam(0.485687f)}; std::initializer_list Grid_shape{2, 3, 3, 2, 3}; - std::initializer_list Grid_data{0.884040f, -0.825214f, 0.496720f, -0.440955f, 1.195811f, 0.169268f, -1.042100f, 0.206524f, 0.145895f, -1.160650f, 0.240829f, 1.144915f, 0.345332f, -0.006382f, -0.248763f, 0.318888f, -0.534619f, 1.181719f, 1.037350f, 0.560600f, -0.446974f, -1.126746f, -0.690807f, 1.166754f, -1.101454f, -1.145775f, -0.086488f, 0.381780f, -1.194351f, -1.114106f, 0.006524f, -0.402521f, 0.836016f, 0.344533f, -1.041627f, -1.081571f, 0.824102f, -0.212785f, -0.524949f, 0.377977f, -0.235842f, 0.573897f, 0.304308f, -0.519568f, -0.961787f, 0.649611f, -0.720973f, -0.132725f, 0.164074f, -0.698360f, 0.653669f, -0.844065f, 0.294728f, 0.128341f, 0.440293f, -1.177701f, 0.069319f, 0.585007f, -0.768260f, 0.296941f, 0.004702f, 1.018020f, -0.254096f, 0.008198f, -0.521925f, -0.295744f, 0.343532f, -1.157334f, 0.910329f, 0.862921f, 0.508195f, 0.898317f, -0.373544f, 0.273330f, 0.061050f, -0.829794f, -0.461335f, -0.426012f, -0.296704f, -1.065526f, -0.843948f, -0.113955f, -0.182548f, -1.089296f, 0.256401f, 0.653393f, 0.999377f, 1.009925f, -0.838519f, -0.384579f, -0.569276f, 0.220093f, 0.321562f, 0.266984f, 0.701244f, 0.633093f, -0.644096f, 0.823778f, 0.809482f, 0.158802f, -1.044029f, -0.735991f, 0.334411f, 0.414891f, 1.118940f, 0.610743f, 0.434932f, -0.040928f}; + std::initializer_list Grid_data{TypeParam(0.884040f), 
TypeParam(-0.825214f), TypeParam(0.496720f), TypeParam(-0.440955f), TypeParam(1.195811f), TypeParam(0.169268f), TypeParam(-1.042100f), TypeParam(0.206524f), TypeParam(0.145895f), TypeParam(-1.160650f), TypeParam(0.240829f), TypeParam(1.144915f), TypeParam(0.345332f), TypeParam(-0.006382f), TypeParam(-0.248763f), TypeParam(0.318888f), TypeParam(-0.534619f), TypeParam(1.181719f), TypeParam(1.037350f), TypeParam(0.560600f), TypeParam(-0.446974f), TypeParam(-1.126746f), TypeParam(-0.690807f), TypeParam(1.166754f), TypeParam(-1.101454f), TypeParam(-1.145775f), TypeParam(-0.086488f), TypeParam(0.381780f), TypeParam(-1.194351f), TypeParam(-1.114106f), TypeParam(0.006524f), TypeParam(-0.402521f), TypeParam(0.836016f), TypeParam(0.344533f), TypeParam(-1.041627f), TypeParam(-1.081571f), TypeParam(0.824102f), TypeParam(-0.212785f), TypeParam(-0.524949f), TypeParam(0.377977f), TypeParam(-0.235842f), TypeParam(0.573897f), TypeParam(0.304308f), TypeParam(-0.519568f), TypeParam(-0.961787f), TypeParam(0.649611f), TypeParam(-0.720973f), TypeParam(-0.132725f), TypeParam(0.164074f), TypeParam(-0.698360f), TypeParam(0.653669f), TypeParam(-0.844065f), TypeParam(0.294728f), TypeParam(0.128341f), TypeParam(0.440293f), TypeParam(-1.177701f), TypeParam(0.069319f), TypeParam(0.585007f), TypeParam(-0.768260f), TypeParam(0.296941f), TypeParam(0.004702f), TypeParam(1.018020f), TypeParam(-0.254096f), TypeParam(0.008198f), TypeParam(-0.521925f), TypeParam(-0.295744f), TypeParam(0.343532f), TypeParam(-1.157334f), TypeParam(0.910329f), TypeParam(0.862921f), TypeParam(0.508195f), TypeParam(0.898317f), TypeParam(-0.373544f), TypeParam(0.273330f), TypeParam(0.061050f), TypeParam(-0.829794f), TypeParam(-0.461335f), TypeParam(-0.426012f), TypeParam(-0.296704f), TypeParam(-1.065526f), TypeParam(-0.843948f), TypeParam(-0.113955f), TypeParam(-0.182548f), TypeParam(-1.089296f), TypeParam(0.256401f), TypeParam(0.653393f), TypeParam(0.999377f), TypeParam(1.009925f), TypeParam(-0.838519f), TypeParam(-0.384579f), TypeParam(-0.569276f), TypeParam(0.220093f), TypeParam(0.321562f), TypeParam(0.266984f), TypeParam(0.701244f), TypeParam(0.633093f), TypeParam(-0.644096f), TypeParam(0.823778f), TypeParam(0.809482f), TypeParam(0.158802f), TypeParam(-1.044029f), TypeParam(-0.735991f), TypeParam(0.334411f), TypeParam(0.414891f), TypeParam(1.118940f), TypeParam(0.610743f), TypeParam(0.434932f), TypeParam(-0.040928f)}; std::initializer_list Y_shape{2, 2, 3, 3, 2}; - std::initializer_list Y_data{0.222880f, -0.137918f, 0.042779f, 0.027606f, 0.146833f, 0.119531f, 0.062001f, 0.077615f, -0.124874f, -0.020856f, 0.248748f, -0.050235f, -0.185885f, -0.124030f, -0.148987f, -0.345107f, 0.753440f, -0.055873f, 0.674388f, 0.063018f, -0.054480f, -0.034452f, 0.780917f, 0.193151f, -0.140647f, -0.047364f, -0.095816f, -0.046983f, 0.254384f, -0.123703f, 0.191358f, 0.674903f, -0.311971f, 1.032054f, 0.672506f, 0.009147f, 0.281933f, 0.135835f, -0.145082f, -0.392560f, -0.229593f, -0.632284f, -0.936929f, -0.916689f, -0.502247f, -0.108609f, -0.645451f, 0.242939f, -0.165902f, -1.220095f, -0.015084f, -0.300940f, -0.352557f, -0.886474f, 0.109150f, 0.398365f, 0.235757f, 0.358618f, 0.082189f, 0.268617f, 0.077955f, -0.157573f, 0.023048f, -0.346908f, 0.360128f, 0.389098f, 0.122882f, 0.675956f, 0.735857f, 0.354858f, 0.244544f, 0.631102f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(0.222880f), TypeParam(-0.137918f), TypeParam(0.042779f), TypeParam(0.027606f), TypeParam(0.146833f), 
TypeParam(0.119531f), TypeParam(0.062001f), TypeParam(0.077615f), TypeParam(-0.124874f), TypeParam(-0.020856f), TypeParam(0.248748f), TypeParam(-0.050235f), TypeParam(-0.185885f), TypeParam(-0.124030f), TypeParam(-0.148987f), TypeParam(-0.345107f), TypeParam(0.753440f), TypeParam(-0.055873f), TypeParam(0.674388f), TypeParam(0.063018f), TypeParam(-0.054480f), TypeParam(-0.034452f), TypeParam(0.780917f), TypeParam(0.193151f), TypeParam(-0.140647f), TypeParam(-0.047364f), TypeParam(-0.095816f), TypeParam(-0.046983f), TypeParam(0.254384f), TypeParam(-0.123703f), TypeParam(0.191358f), TypeParam(0.674903f), TypeParam(-0.311971f), TypeParam(1.032054f), TypeParam(0.672506f), TypeParam(0.009147f), TypeParam(0.281933f), TypeParam(0.135835f), TypeParam(-0.145082f), TypeParam(-0.392560f), TypeParam(-0.229593f), TypeParam(-0.632284f), TypeParam(-0.936929f), TypeParam(-0.916689f), TypeParam(-0.502247f), TypeParam(-0.108609f), TypeParam(-0.645451f), TypeParam(0.242939f), TypeParam(-0.165902f), TypeParam(-1.220095f), TypeParam(-0.015084f), TypeParam(-0.300940f), TypeParam(-0.352557f), TypeParam(-0.886474f), TypeParam(0.109150f), TypeParam(0.398365f), TypeParam(0.235757f), TypeParam(0.358618f), TypeParam(0.082189f), TypeParam(0.268617f), TypeParam(0.077955f), TypeParam(-0.157573f), TypeParam(0.023048f), TypeParam(-0.346908f), TypeParam(0.360128f), TypeParam(0.389098f), TypeParam(0.122882f), TypeParam(0.675956f), TypeParam(0.735857f), TypeParam(0.354858f), TypeParam(0.244544f), TypeParam(0.631102f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_border_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_border_align_corners) { OpTester test("GridSample", 20); std::string mode = "linear"; std::string padding_mode = "border"; int64_t align_corners = 1; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{-1.916003f, 0.150784f, -0.179898f, 0.402727f, -0.549764f, 1.772484f, 1.014343f, 0.502823f, 0.976771f, -0.071957f, 0.519875f, 0.408665f, 1.435640f, -0.807775f, -0.181661f, -0.574026f, -0.335351f, -0.155602f, 0.348749f, 1.055618f, 0.737784f, -0.394725f, 0.597608f, 0.006105f}; + std::initializer_list X_data{TypeParam(-1.916003f), TypeParam(0.150784f), TypeParam(-0.179898f), TypeParam(0.402727f), TypeParam(-0.549764f), TypeParam(1.772484f), TypeParam(1.014343f), TypeParam(0.502823f), TypeParam(0.976771f), TypeParam(-0.071957f), TypeParam(0.519875f), TypeParam(0.408665f), TypeParam(1.435640f), TypeParam(-0.807775f), TypeParam(-0.181661f), TypeParam(-0.574026f), TypeParam(-0.335351f), TypeParam(-0.155602f), TypeParam(0.348749f), TypeParam(1.055618f), TypeParam(0.737784f), TypeParam(-0.394725f), TypeParam(0.597608f), TypeParam(0.006105f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{-0.189838f, -1.050410f, -1.072351f, -0.930754f, -0.502573f, 0.186642f, -0.564332f, -0.042774f, -0.143740f, 1.097448f, -0.547044f, 1.127440f, -0.921224f, -1.001202f, 0.390232f, -0.698394f, 0.615509f, -0.663897f, 0.944958f, 1.161950f, 0.076823f, 0.256464f, 1.118784f, 0.711380f}; + std::initializer_list Grid_data{TypeParam(-0.189838f), TypeParam(-1.050410f), TypeParam(-1.072351f), TypeParam(-0.930754f), 
TypeParam(-0.502573f), TypeParam(0.186642f), TypeParam(-0.564332f), TypeParam(-0.042774f), TypeParam(-0.143740f), TypeParam(1.097448f), TypeParam(-0.547044f), TypeParam(1.127440f), TypeParam(-0.921224f), TypeParam(-1.001202f), TypeParam(0.390232f), TypeParam(-0.698394f), TypeParam(0.615509f), TypeParam(-0.663897f), TypeParam(0.944958f), TypeParam(1.161950f), TypeParam(0.076823f), TypeParam(0.256464f), TypeParam(1.118784f), TypeParam(0.711380f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{-1.078787f, -1.795786f, -0.023270f, -0.113413f, 0.444460f, -0.023826f, 0.807136f, 1.011742f, 0.674182f, 0.754935f, 0.472262f, 0.494688f, 1.347277f, -0.223507f, -0.417529f, -0.160549f, -0.353331f, -0.276367f, 0.376591f, 0.571813f, 0.551111f, 0.022384f, 0.166782f, -0.109583f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-1.078787f), TypeParam(-1.795786f), TypeParam(-0.023270f), TypeParam(-0.113413f), TypeParam(0.444460f), TypeParam(-0.023826f), TypeParam(0.807136f), TypeParam(1.011742f), TypeParam(0.674182f), TypeParam(0.754935f), TypeParam(0.472262f), TypeParam(0.494688f), TypeParam(1.347277f), TypeParam(-0.223507f), TypeParam(-0.417529f), TypeParam(-0.160549f), TypeParam(-0.353331f), TypeParam(-0.276367f), TypeParam(0.376591f), TypeParam(0.571813f), TypeParam(0.551111f), TypeParam(0.022384f), TypeParam(0.166782f), TypeParam(-0.109583f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_5D_bilinear_border_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_bilinear_border_align_corners) { OpTester test("GridSample", 20); std::string mode = "linear"; std::string padding_mode = "border"; int64_t align_corners = 1; std::initializer_list X_shape{2, 2, 3, 3, 2}; - std::initializer_list X_data{-0.332555f, 0.980958f, 0.002632f, -1.976749f, 0.979548f, 1.109773f, -0.534887f, 0.705692f, -0.143637f, -0.600830f, 0.315853f, -0.604687f, -0.300652f, -0.375240f, 0.377196f, -0.140920f, 1.159946f, 2.364598f, 0.320719f, 0.397938f, -0.680097f, -1.201632f, 0.270077f, -0.036712f, -0.972864f, 0.792393f, -1.159168f, -0.016679f, -0.665027f, 0.809646f, -1.684452f, 0.049476f, 0.065748f, 0.279619f, -1.079668f, 0.301309f, 1.010100f, -0.119015f, -0.104838f, 0.916627f, -0.522838f, 0.485269f, -1.221088f, 2.044754f, -0.669823f, 0.128370f, 0.080480f, 0.372679f, -0.046427f, -0.732652f, -0.395790f, 0.012594f, -0.170518f, -0.706783f, -0.862588f, -1.177275f, -1.165262f, 0.914826f, -0.661128f, -0.386656f, -0.599246f, 0.544643f, 0.930679f, -1.146137f, 0.212913f, -0.022433f, 1.692830f, 0.187511f, -0.631569f, -0.311540f, -0.885167f, -0.429959f}; + std::initializer_list X_data{TypeParam(-0.332555f), TypeParam(0.980958f), TypeParam(0.002632f), TypeParam(-1.976749f), TypeParam(0.979548f), TypeParam(1.109773f), TypeParam(-0.534887f), TypeParam(0.705692f), TypeParam(-0.143637f), TypeParam(-0.600830f), TypeParam(0.315853f), TypeParam(-0.604687f), TypeParam(-0.300652f), TypeParam(-0.375240f), TypeParam(0.377196f), TypeParam(-0.140920f), TypeParam(1.159946f), TypeParam(2.364598f), TypeParam(0.320719f), TypeParam(0.397938f), TypeParam(-0.680097f), TypeParam(-1.201632f), TypeParam(0.270077f), 
TypeParam(-0.036712f), TypeParam(-0.972864f), TypeParam(0.792393f), TypeParam(-1.159168f), TypeParam(-0.016679f), TypeParam(-0.665027f), TypeParam(0.809646f), TypeParam(-1.684452f), TypeParam(0.049476f), TypeParam(0.065748f), TypeParam(0.279619f), TypeParam(-1.079668f), TypeParam(0.301309f), TypeParam(1.010100f), TypeParam(-0.119015f), TypeParam(-0.104838f), TypeParam(0.916627f), TypeParam(-0.522838f), TypeParam(0.485269f), TypeParam(-1.221088f), TypeParam(2.044754f), TypeParam(-0.669823f), TypeParam(0.128370f), TypeParam(0.080480f), TypeParam(0.372679f), TypeParam(-0.046427f), TypeParam(-0.732652f), TypeParam(-0.395790f), TypeParam(0.012594f), TypeParam(-0.170518f), TypeParam(-0.706783f), TypeParam(-0.862588f), TypeParam(-1.177275f), TypeParam(-1.165262f), TypeParam(0.914826f), TypeParam(-0.661128f), TypeParam(-0.386656f), TypeParam(-0.599246f), TypeParam(0.544643f), TypeParam(0.930679f), TypeParam(-1.146137f), TypeParam(0.212913f), TypeParam(-0.022433f), TypeParam(1.692830f), TypeParam(0.187511f), TypeParam(-0.631569f), TypeParam(-0.311540f), TypeParam(-0.885167f), TypeParam(-0.429959f)}; std::initializer_list Grid_shape{2, 3, 3, 2, 3}; - std::initializer_list Grid_data{-0.453992f, 0.394222f, 0.755023f, -0.025610f, 0.658840f, 0.982105f, -0.642922f, -0.265292f, -1.080379f, 0.275464f, 0.855228f, -0.233029f, 0.191483f, 0.383441f, -0.025595f, 0.932929f, 0.174866f, -1.179535f, -0.990943f, -1.188918f, 0.049460f, 0.648682f, -0.158317f, 1.078936f, -0.215883f, 0.245340f, 1.082089f, 0.607310f, -0.038283f, 1.155868f, -0.716957f, 0.446971f, 0.757844f, -0.743030f, -1.127212f, 0.383835f, -0.455267f, -0.605570f, 0.238686f, -0.870514f, 1.079285f, -0.107719f, -0.384303f, 1.003178f, 0.334130f, 0.228627f, -0.573757f, 1.143690f, -0.365482f, 0.998076f, -0.088210f, 0.601965f, 0.843747f, -0.893403f, -0.799804f, -1.186625f, 0.865515f, 1.031983f, -0.438564f, -0.587735f, 0.200868f, 0.646055f, 0.296203f, -0.250092f, -0.763290f, 1.026321f, -0.777136f, -1.159559f, -0.479127f, 0.239290f, 0.446029f, 0.464001f, -0.695158f, -0.460548f, -0.533616f, -0.581111f, -1.010728f, 0.245640f, -0.348981f, -1.155007f, -0.700701f, -0.720655f, -0.517635f, -0.741485f, -0.208103f, 0.430035f, -0.971177f, -0.102798f, -0.345348f, -0.613510f, -0.266458f, -0.508597f, 0.038577f, -0.866220f, 0.227567f, 1.101759f, 0.994334f, -0.538031f, 0.369874f, -1.134245f, 1.010332f, -1.195878f, -1.072351f, -1.077155f, -1.114385f, 0.162516f, -0.317319f, 0.287217f}; + std::initializer_list Grid_data{TypeParam(-0.453992f), TypeParam(0.394222f), TypeParam(0.755023f), TypeParam(-0.025610f), TypeParam(0.658840f), TypeParam(0.982105f), TypeParam(-0.642922f), TypeParam(-0.265292f), TypeParam(-1.080379f), TypeParam(0.275464f), TypeParam(0.855228f), TypeParam(-0.233029f), TypeParam(0.191483f), TypeParam(0.383441f), TypeParam(-0.025595f), TypeParam(0.932929f), TypeParam(0.174866f), TypeParam(-1.179535f), TypeParam(-0.990943f), TypeParam(-1.188918f), TypeParam(0.049460f), TypeParam(0.648682f), TypeParam(-0.158317f), TypeParam(1.078936f), TypeParam(-0.215883f), TypeParam(0.245340f), TypeParam(1.082089f), TypeParam(0.607310f), TypeParam(-0.038283f), TypeParam(1.155868f), TypeParam(-0.716957f), TypeParam(0.446971f), TypeParam(0.757844f), TypeParam(-0.743030f), TypeParam(-1.127212f), TypeParam(0.383835f), TypeParam(-0.455267f), TypeParam(-0.605570f), TypeParam(0.238686f), TypeParam(-0.870514f), TypeParam(1.079285f), TypeParam(-0.107719f), TypeParam(-0.384303f), TypeParam(1.003178f), TypeParam(0.334130f), TypeParam(0.228627f), TypeParam(-0.573757f), TypeParam(1.143690f), 
TypeParam(-0.365482f), TypeParam(0.998076f), TypeParam(-0.088210f), TypeParam(0.601965f), TypeParam(0.843747f), TypeParam(-0.893403f), TypeParam(-0.799804f), TypeParam(-1.186625f), TypeParam(0.865515f), TypeParam(1.031983f), TypeParam(-0.438564f), TypeParam(-0.587735f), TypeParam(0.200868f), TypeParam(0.646055f), TypeParam(0.296203f), TypeParam(-0.250092f), TypeParam(-0.763290f), TypeParam(1.026321f), TypeParam(-0.777136f), TypeParam(-1.159559f), TypeParam(-0.479127f), TypeParam(0.239290f), TypeParam(0.446029f), TypeParam(0.464001f), TypeParam(-0.695158f), TypeParam(-0.460548f), TypeParam(-0.533616f), TypeParam(-0.581111f), TypeParam(-1.010728f), TypeParam(0.245640f), TypeParam(-0.348981f), TypeParam(-1.155007f), TypeParam(-0.700701f), TypeParam(-0.720655f), TypeParam(-0.517635f), TypeParam(-0.741485f), TypeParam(-0.208103f), TypeParam(0.430035f), TypeParam(-0.971177f), TypeParam(-0.102798f), TypeParam(-0.345348f), TypeParam(-0.613510f), TypeParam(-0.266458f), TypeParam(-0.508597f), TypeParam(0.038577f), TypeParam(-0.866220f), TypeParam(0.227567f), TypeParam(1.101759f), TypeParam(0.994334f), TypeParam(-0.538031f), TypeParam(0.369874f), TypeParam(-1.134245f), TypeParam(1.010332f), TypeParam(-1.195878f), TypeParam(-1.072351f), TypeParam(-1.077155f), TypeParam(-1.114385f), TypeParam(0.162516f), TypeParam(-0.317319f), TypeParam(0.287217f)}; std::initializer_list Y_shape{2, 2, 3, 3, 2}; - std::initializer_list Y_data{0.517362f, 1.168304f, -0.283719f, -0.056944f, -0.345007f, -1.383013f, -0.517978f, -0.099340f, 0.531814f, -0.051495f, 0.570203f, -0.350444f, -0.195512f, 0.335075f, 0.533103f, -0.173681f, 0.110927f, 0.549661f, -0.303447f, -0.209369f, -0.479343f, 0.113517f, -0.222508f, -0.981697f, -1.000072f, 0.163343f, -0.019158f, 0.217390f, -0.442252f, -1.020732f, -0.645033f, -0.481248f, -0.359233f, -0.271288f, -0.165768f, -0.092544f, -0.219889f, 0.671201f, -0.041137f, -0.289275f, -0.022793f, -0.130253f, -0.072692f, -0.451858f, 0.402947f, 0.168711f, 0.110811f, 0.202315f, -0.200036f, -0.331588f, 0.583341f, -0.522838f, 1.010100f, -0.018650f, 1.269564f, -0.168394f, -0.209390f, 0.740205f, -0.675828f, -0.325915f, -0.404694f, 0.067064f, -0.744102f, -0.639736f, -0.416580f, -0.317643f, 0.004590f, -0.665815f, -0.163600f, -0.661128f, -0.862588f, -0.132515f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(0.517362f), TypeParam(1.168304f), TypeParam(-0.283719f), TypeParam(-0.056944f), TypeParam(-0.345007f), TypeParam(-1.383013f), TypeParam(-0.517978f), TypeParam(-0.099340f), TypeParam(0.531814f), TypeParam(-0.051495f), TypeParam(0.570203f), TypeParam(-0.350444f), TypeParam(-0.195512f), TypeParam(0.335075f), TypeParam(0.533103f), TypeParam(-0.173681f), TypeParam(0.110927f), TypeParam(0.549661f), TypeParam(-0.303447f), TypeParam(-0.209369f), TypeParam(-0.479343f), TypeParam(0.113517f), TypeParam(-0.222508f), TypeParam(-0.981697f), TypeParam(-1.000072f), TypeParam(0.163343f), TypeParam(-0.019158f), TypeParam(0.217390f), TypeParam(-0.442252f), TypeParam(-1.020732f), TypeParam(-0.645033f), TypeParam(-0.481248f), TypeParam(-0.359233f), TypeParam(-0.271288f), TypeParam(-0.165768f), TypeParam(-0.092544f), TypeParam(-0.219889f), TypeParam(0.671201f), TypeParam(-0.041137f), TypeParam(-0.289275f), TypeParam(-0.022793f), TypeParam(-0.130253f), TypeParam(-0.072692f), TypeParam(-0.451858f), TypeParam(0.402947f), TypeParam(0.168711f), TypeParam(0.110811f), TypeParam(0.202315f), TypeParam(-0.200036f), TypeParam(-0.331588f), TypeParam(0.583341f), 
TypeParam(-0.522838f), TypeParam(1.010100f), TypeParam(-0.018650f), TypeParam(1.269564f), TypeParam(-0.168394f), TypeParam(-0.209390f), TypeParam(0.740205f), TypeParam(-0.675828f), TypeParam(-0.325915f), TypeParam(-0.404694f), TypeParam(0.067064f), TypeParam(-0.744102f), TypeParam(-0.639736f), TypeParam(-0.416580f), TypeParam(-0.317643f), TypeParam(0.004590f), TypeParam(-0.665815f), TypeParam(-0.163600f), TypeParam(-0.661128f), TypeParam(-0.862588f), TypeParam(-0.132515f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_border_no_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_border_no_align_corners) { OpTester test("GridSample", 20); std::string mode = "linear"; std::string padding_mode = "border"; int64_t align_corners = 0; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{-0.050553f, -0.825690f, -0.616085f, 0.337113f, 0.370334f, -0.105073f, -0.565382f, 0.396842f, -0.373193f, -0.780451f, -1.932970f, 1.104960f, -2.569945f, 0.661190f, -0.192302f, 0.734279f, 0.351872f, -1.068136f, 0.173665f, -0.778153f, -0.981877f, 1.485344f, 0.431733f, 0.428167f}; + std::initializer_list X_data{TypeParam(-0.050553f), TypeParam(-0.825690f), TypeParam(-0.616085f), TypeParam(0.337113f), TypeParam(0.370334f), TypeParam(-0.105073f), TypeParam(-0.565382f), TypeParam(0.396842f), TypeParam(-0.373193f), TypeParam(-0.780451f), TypeParam(-1.932970f), TypeParam(1.104960f), TypeParam(-2.569945f), TypeParam(0.661190f), TypeParam(-0.192302f), TypeParam(0.734279f), TypeParam(0.351872f), TypeParam(-1.068136f), TypeParam(0.173665f), TypeParam(-0.778153f), TypeParam(-0.981877f), TypeParam(1.485344f), TypeParam(0.431733f), TypeParam(0.428167f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{-0.330875f, 0.589988f, 0.011588f, -1.144325f, -1.038357f, 0.435055f, -1.053243f, -0.957144f, -0.715458f, 1.143742f, -0.341215f, -0.494762f, -0.810255f, 0.767649f, -0.193763f, 0.231402f, 0.286668f, 0.338432f, 0.768106f, 0.062272f, 0.124125f, -0.077928f, -0.932481f, -0.274618f}; + std::initializer_list Grid_data{TypeParam(-0.330875f), TypeParam(0.589988f), TypeParam(0.011588f), TypeParam(-1.144325f), TypeParam(-1.038357f), TypeParam(0.435055f), TypeParam(-1.053243f), TypeParam(-0.957144f), TypeParam(-0.715458f), TypeParam(1.143742f), TypeParam(-0.341215f), TypeParam(-0.494762f), TypeParam(-0.810255f), TypeParam(0.767649f), TypeParam(-0.193763f), TypeParam(0.231402f), TypeParam(0.286668f), TypeParam(0.338432f), TypeParam(0.768106f), TypeParam(0.062272f), TypeParam(0.124125f), TypeParam(-0.077928f), TypeParam(-0.932481f), TypeParam(-0.274618f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{0.204265f, -0.447104f, 0.027635f, -0.050553f, 0.370334f, -0.248695f, -1.306797f, -0.073120f, -1.391077f, -0.565382f, -1.932970f, -0.419110f, 0.351872f, 0.030903f, -0.124253f, 0.565919f, 0.276202f, -1.171718f, 0.431733f, 0.001712f, 0.689913f, 1.386595f, 0.443614f, -0.505878f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(0.204265f), TypeParam(-0.447104f), TypeParam(0.027635f), TypeParam(-0.050553f), 
TypeParam(0.370334f), TypeParam(-0.248695f), TypeParam(-1.306797f), TypeParam(-0.073120f), TypeParam(-1.391077f), TypeParam(-0.565382f), TypeParam(-1.932970f), TypeParam(-0.419110f), TypeParam(0.351872f), TypeParam(0.030903f), TypeParam(-0.124253f), TypeParam(0.565919f), TypeParam(0.276202f), TypeParam(-1.171718f), TypeParam(0.431733f), TypeParam(0.001712f), TypeParam(0.689913f), TypeParam(1.386595f), TypeParam(0.443614f), TypeParam(-0.505878f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_5D_bilinear_border_no_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_bilinear_border_no_align_corners) { OpTester test("GridSample", 20); std::string mode = "linear"; std::string padding_mode = "border"; int64_t align_corners = 0; std::initializer_list X_shape{2, 2, 3, 3, 2}; - std::initializer_list X_data{-0.727099f, 0.057663f, -0.548384f, 0.078163f, -0.133679f, 0.211872f, 0.271687f, -1.221973f, -2.630687f, -0.558102f, -0.327183f, 0.039894f, 1.222102f, 0.144418f, 0.696676f, -2.231791f, 0.910544f, 2.749837f, -0.354036f, -0.106102f, 2.453576f, 0.332319f, -1.743712f, 1.416859f, 0.260041f, -1.179930f, 0.407328f, 0.375476f, 2.028488f, 0.174825f, -1.467126f, 0.079045f, 0.870076f, -0.895165f, 0.631429f, 0.358222f, 1.484120f, -0.622331f, 0.727481f, 0.644213f, 1.299103f, -0.378573f, 1.360908f, 0.905514f, 0.180065f, 0.972162f, 1.246238f, -0.537204f, -1.241497f, -0.772822f, -0.149044f, -1.642060f, 0.120091f, 0.937023f, 0.422106f, 0.652040f, 0.045585f, -1.089530f, 0.356099f, 0.536075f, -1.840257f, -1.035736f, 0.348653f, 0.187942f, 0.150011f, 0.521798f, 1.271739f, 0.977495f, 0.811927f, 0.641729f, 0.964401f, -0.693074f}; + std::initializer_list X_data{TypeParam(-0.727099f), TypeParam(0.057663f), TypeParam(-0.548384f), TypeParam(0.078163f), TypeParam(-0.133679f), TypeParam(0.211872f), TypeParam(0.271687f), TypeParam(-1.221973f), TypeParam(-2.630687f), TypeParam(-0.558102f), TypeParam(-0.327183f), TypeParam(0.039894f), TypeParam(1.222102f), TypeParam(0.144418f), TypeParam(0.696676f), TypeParam(-2.231791f), TypeParam(0.910544f), TypeParam(2.749837f), TypeParam(-0.354036f), TypeParam(-0.106102f), TypeParam(2.453576f), TypeParam(0.332319f), TypeParam(-1.743712f), TypeParam(1.416859f), TypeParam(0.260041f), TypeParam(-1.179930f), TypeParam(0.407328f), TypeParam(0.375476f), TypeParam(2.028488f), TypeParam(0.174825f), TypeParam(-1.467126f), TypeParam(0.079045f), TypeParam(0.870076f), TypeParam(-0.895165f), TypeParam(0.631429f), TypeParam(0.358222f), TypeParam(1.484120f), TypeParam(-0.622331f), TypeParam(0.727481f), TypeParam(0.644213f), TypeParam(1.299103f), TypeParam(-0.378573f), TypeParam(1.360908f), TypeParam(0.905514f), TypeParam(0.180065f), TypeParam(0.972162f), TypeParam(1.246238f), TypeParam(-0.537204f), TypeParam(-1.241497f), TypeParam(-0.772822f), TypeParam(-0.149044f), TypeParam(-1.642060f), TypeParam(0.120091f), TypeParam(0.937023f), TypeParam(0.422106f), TypeParam(0.652040f), TypeParam(0.045585f), TypeParam(-1.089530f), TypeParam(0.356099f), TypeParam(0.536075f), TypeParam(-1.840257f), TypeParam(-1.035736f), TypeParam(0.348653f), TypeParam(0.187942f), TypeParam(0.150011f), TypeParam(0.521798f), TypeParam(1.271739f), TypeParam(0.977495f), 
TypeParam(0.811927f), TypeParam(0.641729f), TypeParam(0.964401f), TypeParam(-0.693074f)}; std::initializer_list Grid_shape{2, 3, 3, 2, 3}; - std::initializer_list Grid_data{1.017692f, -0.818194f, 0.525611f, -0.556812f, -0.124601f, 1.120205f, 0.153552f, -1.144168f, 1.103147f, -0.050771f, -0.600881f, -0.633732f, 1.029039f, 0.020253f, 0.662802f, 0.788674f, -0.465758f, 0.101853f, -0.776226f, 1.002064f, -0.634553f, 0.797064f, 0.304043f, 0.740241f, -0.845484f, -0.037319f, 0.621792f, -0.047898f, -0.017218f, 0.584766f, -0.896882f, -0.240587f, 0.546590f, 0.588539f, 1.114539f, -0.237379f, 0.284327f, -0.590432f, -0.201402f, -0.602420f, 0.889284f, 0.007310f, 0.488176f, 0.660055f, 0.223618f, 0.127703f, -0.087830f, -1.016490f, 0.193341f, -0.265853f, -1.008634f, 1.118021f, -0.127930f, -0.598904f, -1.168221f, -1.105256f, 0.456964f, -0.547805f, -0.518368f, -0.694346f, 0.968648f, -0.288466f, 0.777819f, 0.952657f, -0.930362f, 0.895254f, -0.229149f, 1.149323f, 0.612939f, -1.162419f, 0.222934f, 0.421831f, -0.435327f, 0.909973f, -0.993750f, -0.380767f, 1.143396f, 1.171977f, 0.599451f, -0.716336f, -1.032482f, -0.975683f, -0.299985f, 0.679795f, 0.379920f, -0.145729f, 1.079221f, 0.942322f, -0.560859f, -0.519668f, -0.014079f, 0.249021f, -0.008590f, 0.463277f, 0.827937f, -0.216375f, 0.589310f, 0.163207f, 0.460623f, 0.494016f, -0.320739f, -0.535032f, 0.512922f, -0.768302f, 0.630003f, -0.769945f, 0.823242f, 0.481487f}; + std::initializer_list Grid_data{TypeParam(1.017692f), TypeParam(-0.818194f), TypeParam(0.525611f), TypeParam(-0.556812f), TypeParam(-0.124601f), TypeParam(1.120205f), TypeParam(0.153552f), TypeParam(-1.144168f), TypeParam(1.103147f), TypeParam(-0.050771f), TypeParam(-0.600881f), TypeParam(-0.633732f), TypeParam(1.029039f), TypeParam(0.020253f), TypeParam(0.662802f), TypeParam(0.788674f), TypeParam(-0.465758f), TypeParam(0.101853f), TypeParam(-0.776226f), TypeParam(1.002064f), TypeParam(-0.634553f), TypeParam(0.797064f), TypeParam(0.304043f), TypeParam(0.740241f), TypeParam(-0.845484f), TypeParam(-0.037319f), TypeParam(0.621792f), TypeParam(-0.047898f), TypeParam(-0.017218f), TypeParam(0.584766f), TypeParam(-0.896882f), TypeParam(-0.240587f), TypeParam(0.546590f), TypeParam(0.588539f), TypeParam(1.114539f), TypeParam(-0.237379f), TypeParam(0.284327f), TypeParam(-0.590432f), TypeParam(-0.201402f), TypeParam(-0.602420f), TypeParam(0.889284f), TypeParam(0.007310f), TypeParam(0.488176f), TypeParam(0.660055f), TypeParam(0.223618f), TypeParam(0.127703f), TypeParam(-0.087830f), TypeParam(-1.016490f), TypeParam(0.193341f), TypeParam(-0.265853f), TypeParam(-1.008634f), TypeParam(1.118021f), TypeParam(-0.127930f), TypeParam(-0.598904f), TypeParam(-1.168221f), TypeParam(-1.105256f), TypeParam(0.456964f), TypeParam(-0.547805f), TypeParam(-0.518368f), TypeParam(-0.694346f), TypeParam(0.968648f), TypeParam(-0.288466f), TypeParam(0.777819f), TypeParam(0.952657f), TypeParam(-0.930362f), TypeParam(0.895254f), TypeParam(-0.229149f), TypeParam(1.149323f), TypeParam(0.612939f), TypeParam(-1.162419f), TypeParam(0.222934f), TypeParam(0.421831f), TypeParam(-0.435327f), TypeParam(0.909973f), TypeParam(-0.993750f), TypeParam(-0.380767f), TypeParam(1.143396f), TypeParam(1.171977f), TypeParam(0.599451f), TypeParam(-0.716336f), TypeParam(-1.032482f), TypeParam(-0.975683f), TypeParam(-0.299985f), TypeParam(0.679795f), TypeParam(0.379920f), TypeParam(-0.145729f), TypeParam(1.079221f), TypeParam(0.942322f), TypeParam(-0.560859f), TypeParam(-0.519668f), TypeParam(-0.014079f), TypeParam(0.249021f), TypeParam(-0.008590f), 
TypeParam(0.463277f), TypeParam(0.827937f), TypeParam(-0.216375f), TypeParam(0.589310f), TypeParam(0.163207f), TypeParam(0.460623f), TypeParam(0.494016f), TypeParam(-0.320739f), TypeParam(-0.535032f), TypeParam(0.512922f), TypeParam(-0.768302f), TypeParam(0.630003f), TypeParam(-0.769945f), TypeParam(0.823242f), TypeParam(0.481487f)}; std::initializer_list Y_shape{2, 2, 3, 3, 2}; - std::initializer_list Y_data{-0.144687f, 0.794879f, 0.517780f, -0.372025f, -2.071523f, -0.953122f, -0.143000f, 0.040151f, 0.511071f, -0.723342f, 0.441486f, 0.101130f, -0.668215f, -0.313612f, 0.918245f, -0.165560f, -0.141496f, -0.002992f, -0.187333f, 0.433250f, -0.456623f, -0.082449f, -0.849978f, -0.635311f, -1.562003f, -0.323540f, 0.716348f, 0.089914f, 0.085623f, 0.617075f, -0.522245f, 2.013170f, 0.249061f, 0.948093f, 0.518262f, 0.230788f, -0.422900f, 1.315807f, -1.265941f, -0.772822f, 0.375354f, 0.159706f, 1.190603f, 0.217497f, -0.622331f, -0.640623f, -1.324261f, -0.126419f, 0.497220f, -0.421485f, -0.512049f, 0.218454f, -0.680520f, 0.432900f, 0.292848f, 0.338349f, 0.787015f, 0.977495f, 0.494135f, 0.649655f, 0.367739f, 0.766775f, 0.652040f, 1.018832f, 0.738819f, 0.107251f, 0.287288f, 0.515065f, 0.300961f, -0.279154f, 0.866776f, 0.738188f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-0.144687f), TypeParam(0.794879f), TypeParam(0.517780f), TypeParam(-0.372025f), TypeParam(-2.071523f), TypeParam(-0.953122f), TypeParam(-0.143000f), TypeParam(0.040151f), TypeParam(0.511071f), TypeParam(-0.723342f), TypeParam(0.441486f), TypeParam(0.101130f), TypeParam(-0.668215f), TypeParam(-0.313612f), TypeParam(0.918245f), TypeParam(-0.165560f), TypeParam(-0.141496f), TypeParam(-0.002992f), TypeParam(-0.187333f), TypeParam(0.433250f), TypeParam(-0.456623f), TypeParam(-0.082449f), TypeParam(-0.849978f), TypeParam(-0.635311f), TypeParam(-1.562003f), TypeParam(-0.323540f), TypeParam(0.716348f), TypeParam(0.089914f), TypeParam(0.085623f), TypeParam(0.617075f), TypeParam(-0.522245f), TypeParam(2.013170f), TypeParam(0.249061f), TypeParam(0.948093f), TypeParam(0.518262f), TypeParam(0.230788f), TypeParam(-0.422900f), TypeParam(1.315807f), TypeParam(-1.265941f), TypeParam(-0.772822f), TypeParam(0.375354f), TypeParam(0.159706f), TypeParam(1.190603f), TypeParam(0.217497f), TypeParam(-0.622331f), TypeParam(-0.640623f), TypeParam(-1.324261f), TypeParam(-0.126419f), TypeParam(0.497220f), TypeParam(-0.421485f), TypeParam(-0.512049f), TypeParam(0.218454f), TypeParam(-0.680520f), TypeParam(0.432900f), TypeParam(0.292848f), TypeParam(0.338349f), TypeParam(0.787015f), TypeParam(0.977495f), TypeParam(0.494135f), TypeParam(0.649655f), TypeParam(0.367739f), TypeParam(0.766775f), TypeParam(0.652040f), TypeParam(1.018832f), TypeParam(0.738819f), TypeParam(0.107251f), TypeParam(0.287288f), TypeParam(0.515065f), TypeParam(0.300961f), TypeParam(-0.279154f), TypeParam(0.866776f), TypeParam(0.738188f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_reflection_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_reflection_align_corners) { OpTester test("GridSample", 20); std::string mode = "linear"; 
std::string padding_mode = "reflection"; int64_t align_corners = 1; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{-0.599439f, 0.317612f, -0.294302f, -0.530613f, 0.754687f, 0.092241f, -1.009405f, -1.155944f, 0.336327f, 0.159353f, -1.134330f, 0.510271f, 0.271972f, 1.301884f, 1.027400f, 1.193876f, 0.304363f, 1.027256f, 0.186801f, 0.719412f, -0.310900f, -1.123812f, -0.312771f, 2.729156f}; + std::initializer_list X_data{TypeParam(-0.599439f), TypeParam(0.317612f), TypeParam(-0.294302f), TypeParam(-0.530613f), TypeParam(0.754687f), TypeParam(0.092241f), TypeParam(-1.009405f), TypeParam(-1.155944f), TypeParam(0.336327f), TypeParam(0.159353f), TypeParam(-1.134330f), TypeParam(0.510271f), TypeParam(0.271972f), TypeParam(1.301884f), TypeParam(1.027400f), TypeParam(1.193876f), TypeParam(0.304363f), TypeParam(1.027256f), TypeParam(0.186801f), TypeParam(0.719412f), TypeParam(-0.310900f), TypeParam(-1.123812f), TypeParam(-0.312771f), TypeParam(2.729156f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{0.853801f, 0.833200f, -0.477474f, 0.131677f, 0.571825f, 0.858708f, -1.120796f, 1.194690f, -0.301706f, 0.488934f, -0.745307f, -0.923452f, -0.812682f, 0.707226f, -0.591920f, 0.697573f, 0.362777f, 0.477332f, -0.266909f, -0.379588f, -0.561456f, -0.670762f, 1.106438f, -0.065215f}; + std::initializer_list Grid_data{TypeParam(0.853801f), TypeParam(0.833200f), TypeParam(-0.477474f), TypeParam(0.131677f), TypeParam(0.571825f), TypeParam(0.858708f), TypeParam(-1.120796f), TypeParam(1.194690f), TypeParam(-0.301706f), TypeParam(0.488934f), TypeParam(-0.745307f), TypeParam(-0.923452f), TypeParam(-0.812682f), TypeParam(0.707226f), TypeParam(-0.591920f), TypeParam(0.697573f), TypeParam(0.362777f), TypeParam(0.477332f), TypeParam(-0.266909f), TypeParam(-0.379588f), TypeParam(-0.561456f), TypeParam(-0.670762f), TypeParam(1.106438f), TypeParam(-0.065215f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{0.031577f, -0.232574f, 0.133168f, 0.515460f, 0.063332f, -0.470541f, 0.353729f, 0.159106f, 0.163701f, -0.770097f, -0.133556f, -0.925350f, 0.568498f, 0.636194f, 0.976680f, 0.921805f, 0.684184f, 1.189063f, -0.133022f, 0.070598f, 0.388079f, -0.232737f, 0.042589f, -0.965013f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(0.031577f), TypeParam(-0.232574f), TypeParam(0.133168f), TypeParam(0.515460f), TypeParam(0.063332f), TypeParam(-0.470541f), TypeParam(0.353729f), TypeParam(0.159106f), TypeParam(0.163701f), TypeParam(-0.770097f), TypeParam(-0.133556f), TypeParam(-0.925350f), TypeParam(0.568498f), TypeParam(0.636194f), TypeParam(0.976680f), TypeParam(0.921805f), TypeParam(0.684184f), TypeParam(1.189063f), TypeParam(-0.133022f), TypeParam(0.070598f), TypeParam(0.388079f), TypeParam(-0.232737f), TypeParam(0.042589f), TypeParam(-0.965013f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_5D_bilinear_reflection_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_bilinear_reflection_align_corners) { OpTester test("GridSample", 20); std::string mode = "linear"; std::string padding_mode = "reflection"; int64_t 
align_corners = 1; std::initializer_list X_shape{2, 2, 3, 3, 2}; - std::initializer_list X_data{-0.441629f, 0.199148f, 1.214051f, -0.000869f, 0.863692f, -0.067719f, -0.621662f, 0.235179f, 0.691041f, 0.176564f, 0.036477f, -0.085879f, 0.785440f, -1.837889f, -0.300151f, -1.710413f, 0.484432f, 2.160478f, -0.049246f, 0.372475f, -1.060470f, -1.000841f, -0.473439f, 0.963055f, 0.174518f, 0.932434f, 0.039338f, -0.343549f, -1.446623f, -0.673622f, 0.520395f, -0.279228f, -0.367065f, -0.871085f, 0.649273f, -0.835047f, 1.063542f, -1.829784f, 1.476173f, -1.048210f, -1.127299f, 1.204756f, -0.998390f, -1.014054f, -1.032717f, 0.977184f, 0.959897f, -0.749289f, 0.784492f, 1.343993f, 1.291144f, 0.099496f, 2.086763f, 0.529948f, -2.296640f, 0.570701f, 0.491216f, -0.003836f, -0.591929f, -0.076994f, 1.239698f, -0.888840f, 0.623497f, 0.769879f, 2.240972f, -2.081689f, 0.798466f, 1.207944f, -0.486804f, -0.488222f, -0.746382f, -0.220282f}; + std::initializer_list X_data{TypeParam(-0.441629f), TypeParam(0.199148f), TypeParam(1.214051f), TypeParam(-0.000869f), TypeParam(0.863692f), TypeParam(-0.067719f), TypeParam(-0.621662f), TypeParam(0.235179f), TypeParam(0.691041f), TypeParam(0.176564f), TypeParam(0.036477f), TypeParam(-0.085879f), TypeParam(0.785440f), TypeParam(-1.837889f), TypeParam(-0.300151f), TypeParam(-1.710413f), TypeParam(0.484432f), TypeParam(2.160478f), TypeParam(-0.049246f), TypeParam(0.372475f), TypeParam(-1.060470f), TypeParam(-1.000841f), TypeParam(-0.473439f), TypeParam(0.963055f), TypeParam(0.174518f), TypeParam(0.932434f), TypeParam(0.039338f), TypeParam(-0.343549f), TypeParam(-1.446623f), TypeParam(-0.673622f), TypeParam(0.520395f), TypeParam(-0.279228f), TypeParam(-0.367065f), TypeParam(-0.871085f), TypeParam(0.649273f), TypeParam(-0.835047f), TypeParam(1.063542f), TypeParam(-1.829784f), TypeParam(1.476173f), TypeParam(-1.048210f), TypeParam(-1.127299f), TypeParam(1.204756f), TypeParam(-0.998390f), TypeParam(-1.014054f), TypeParam(-1.032717f), TypeParam(0.977184f), TypeParam(0.959897f), TypeParam(-0.749289f), TypeParam(0.784492f), TypeParam(1.343993f), TypeParam(1.291144f), TypeParam(0.099496f), TypeParam(2.086763f), TypeParam(0.529948f), TypeParam(-2.296640f), TypeParam(0.570701f), TypeParam(0.491216f), TypeParam(-0.003836f), TypeParam(-0.591929f), TypeParam(-0.076994f), TypeParam(1.239698f), TypeParam(-0.888840f), TypeParam(0.623497f), TypeParam(0.769879f), TypeParam(2.240972f), TypeParam(-2.081689f), TypeParam(0.798466f), TypeParam(1.207944f), TypeParam(-0.486804f), TypeParam(-0.488222f), TypeParam(-0.746382f), TypeParam(-0.220282f)}; std::initializer_list Grid_shape{2, 3, 3, 2, 3}; - std::initializer_list Grid_data{-0.169044f, 0.178997f, 1.112567f, -0.825642f, -0.359793f, 0.170758f, -0.081412f, 0.319486f, 0.630993f, -0.493702f, 0.093438f, 1.085657f, -0.679024f, -0.813753f, -0.920282f, 0.717311f, -1.100678f, -0.583561f, 0.810473f, -0.719377f, 0.975857f, -0.560957f, 0.189840f, 0.157082f, -0.029434f, 0.747413f, 1.019186f, -0.749235f, 0.673000f, 0.320624f, -0.022362f, -0.839050f, 0.355966f, 0.871005f, -1.030007f, -1.108265f, -1.179701f, 0.277273f, -0.344802f, -0.372753f, 1.117390f, -0.306079f, -0.762057f, 0.107942f, -0.658634f, -0.351593f, 0.633875f, 0.276953f, -0.823465f, 1.142446f, 0.811875f, -0.818022f, 0.522699f, 0.493103f, -0.861061f, -0.843352f, -0.993629f, 0.534540f, 0.209070f, 0.507143f, -0.527071f, 0.902309f, 0.153227f, -0.957513f, -0.302041f, 0.612404f, 0.263859f, -0.183579f, -0.838388f, -0.746482f, 1.035039f, -0.687403f, 0.850371f, -0.401659f, 0.011995f, -1.168548f, -0.390077f, 
1.011575f, -1.077360f, 0.603794f, -1.009901f, 0.175023f, -1.087964f, -0.949961f, -0.968757f, -0.416100f, 0.163389f, -0.879807f, 0.304124f, 0.722748f, 0.978239f, 1.062535f, 0.790067f, -0.353356f, -0.110591f, 1.061730f, 0.596951f, -0.318231f, 0.905999f, -1.048710f, 1.027042f, 0.671407f, -0.880154f, -0.978736f, 0.938431f, 1.183815f, 0.104716f, -0.468883f}; + std::initializer_list Grid_data{TypeParam(-0.169044f), TypeParam(0.178997f), TypeParam(1.112567f), TypeParam(-0.825642f), TypeParam(-0.359793f), TypeParam(0.170758f), TypeParam(-0.081412f), TypeParam(0.319486f), TypeParam(0.630993f), TypeParam(-0.493702f), TypeParam(0.093438f), TypeParam(1.085657f), TypeParam(-0.679024f), TypeParam(-0.813753f), TypeParam(-0.920282f), TypeParam(0.717311f), TypeParam(-1.100678f), TypeParam(-0.583561f), TypeParam(0.810473f), TypeParam(-0.719377f), TypeParam(0.975857f), TypeParam(-0.560957f), TypeParam(0.189840f), TypeParam(0.157082f), TypeParam(-0.029434f), TypeParam(0.747413f), TypeParam(1.019186f), TypeParam(-0.749235f), TypeParam(0.673000f), TypeParam(0.320624f), TypeParam(-0.022362f), TypeParam(-0.839050f), TypeParam(0.355966f), TypeParam(0.871005f), TypeParam(-1.030007f), TypeParam(-1.108265f), TypeParam(-1.179701f), TypeParam(0.277273f), TypeParam(-0.344802f), TypeParam(-0.372753f), TypeParam(1.117390f), TypeParam(-0.306079f), TypeParam(-0.762057f), TypeParam(0.107942f), TypeParam(-0.658634f), TypeParam(-0.351593f), TypeParam(0.633875f), TypeParam(0.276953f), TypeParam(-0.823465f), TypeParam(1.142446f), TypeParam(0.811875f), TypeParam(-0.818022f), TypeParam(0.522699f), TypeParam(0.493103f), TypeParam(-0.861061f), TypeParam(-0.843352f), TypeParam(-0.993629f), TypeParam(0.534540f), TypeParam(0.209070f), TypeParam(0.507143f), TypeParam(-0.527071f), TypeParam(0.902309f), TypeParam(0.153227f), TypeParam(-0.957513f), TypeParam(-0.302041f), TypeParam(0.612404f), TypeParam(0.263859f), TypeParam(-0.183579f), TypeParam(-0.838388f), TypeParam(-0.746482f), TypeParam(1.035039f), TypeParam(-0.687403f), TypeParam(0.850371f), TypeParam(-0.401659f), TypeParam(0.011995f), TypeParam(-1.168548f), TypeParam(-0.390077f), TypeParam(1.011575f), TypeParam(-1.077360f), TypeParam(0.603794f), TypeParam(-1.009901f), TypeParam(0.175023f), TypeParam(-1.087964f), TypeParam(-0.949961f), TypeParam(-0.968757f), TypeParam(-0.416100f), TypeParam(0.163389f), TypeParam(-0.879807f), TypeParam(0.304124f), TypeParam(0.722748f), TypeParam(0.978239f), TypeParam(1.062535f), TypeParam(0.790067f), TypeParam(-0.353356f), TypeParam(-0.110591f), TypeParam(1.061730f), TypeParam(0.596951f), TypeParam(-0.318231f), TypeParam(0.905999f), TypeParam(-1.048710f), TypeParam(1.027042f), TypeParam(0.671407f), TypeParam(-0.880154f), TypeParam(-0.978736f), TypeParam(0.938431f), TypeParam(1.183815f), TypeParam(0.104716f), TypeParam(-0.468883f)}; std::initializer_list Y_shape{2, 2, 3, 3, 2}; - std::initializer_list Y_data{-0.414201f, 0.167816f, -0.042305f, -0.423495f, -0.101419f, 0.120192f, -1.543294f, 0.344146f, 0.709278f, 0.248721f, -0.269138f, 0.158159f, 0.659876f, 0.226329f, 0.874509f, 0.240959f, 0.412611f, 0.225904f, -0.448580f, 0.057703f, -0.426538f, -0.401142f, -0.147435f, 0.401852f, -0.355426f, -0.286018f, -0.219687f, -0.564205f, 0.282723f, 0.363522f, -0.543706f, -0.787722f, -0.692217f, -0.594894f, 0.091005f, -0.328214f, 0.919003f, 0.408116f, 0.631220f, 0.303619f, -0.197801f, -0.308153f, 0.094457f, 1.027881f, -0.077622f, -0.597219f, -0.661449f, 0.947805f, 0.279352f, 0.828246f, 0.571205f, 1.646163f, 0.714257f, 0.049881f, -1.680014f, -0.056047f, 0.892393f, 
0.250564f, 0.138843f, 0.178706f, 0.161286f, 0.036891f, -0.141908f, -0.510903f, 0.733949f, -0.112944f, -0.581858f, -0.269439f, 0.056781f, 0.200325f, 0.814038f, 0.277386f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-0.414201f), TypeParam(0.167816f), TypeParam(-0.042305f), TypeParam(-0.423495f), TypeParam(-0.101419f), TypeParam(0.120192f), TypeParam(-1.543294f), TypeParam(0.344146f), TypeParam(0.709278f), TypeParam(0.248721f), TypeParam(-0.269138f), TypeParam(0.158159f), TypeParam(0.659876f), TypeParam(0.226329f), TypeParam(0.874509f), TypeParam(0.240959f), TypeParam(0.412611f), TypeParam(0.225904f), TypeParam(-0.448580f), TypeParam(0.057703f), TypeParam(-0.426538f), TypeParam(-0.401142f), TypeParam(-0.147435f), TypeParam(0.401852f), TypeParam(-0.355426f), TypeParam(-0.286018f), TypeParam(-0.219687f), TypeParam(-0.564205f), TypeParam(0.282723f), TypeParam(0.363522f), TypeParam(-0.543706f), TypeParam(-0.787722f), TypeParam(-0.692217f), TypeParam(-0.594894f), TypeParam(0.091005f), TypeParam(-0.328214f), TypeParam(0.919003f), TypeParam(0.408116f), TypeParam(0.631220f), TypeParam(0.303619f), TypeParam(-0.197801f), TypeParam(-0.308153f), TypeParam(0.094457f), TypeParam(1.027881f), TypeParam(-0.077622f), TypeParam(-0.597219f), TypeParam(-0.661449f), TypeParam(0.947805f), TypeParam(0.279352f), TypeParam(0.828246f), TypeParam(0.571205f), TypeParam(1.646163f), TypeParam(0.714257f), TypeParam(0.049881f), TypeParam(-1.680014f), TypeParam(-0.056047f), TypeParam(0.892393f), TypeParam(0.250564f), TypeParam(0.138843f), TypeParam(0.178706f), TypeParam(0.161286f), TypeParam(0.036891f), TypeParam(-0.141908f), TypeParam(-0.510903f), TypeParam(0.733949f), TypeParam(-0.112944f), TypeParam(-0.581858f), TypeParam(-0.269439f), TypeParam(0.056781f), TypeParam(0.200325f), TypeParam(0.814038f), TypeParam(0.277386f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_reflection_no_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_reflection_no_align_corners) { OpTester test("GridSample", 20); std::string mode = "linear"; std::string padding_mode = "reflection"; int64_t align_corners = 0; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{-0.173652f, -1.513725f, -0.704586f, -1.952375f, -0.699404f, -0.806298f, 1.640852f, -0.138969f, -0.695411f, -1.352111f, 0.568797f, -0.564294f, -0.056468f, 0.641604f, -0.438370f, 0.450167f, -1.091401f, 1.669729f, -0.908544f, 0.244467f, 0.172109f, 1.156741f, -0.617128f, 1.155460f}; + std::initializer_list X_data{TypeParam(-0.173652f), TypeParam(-1.513725f), TypeParam(-0.704586f), TypeParam(-1.952375f), TypeParam(-0.699404f), TypeParam(-0.806298f), TypeParam(1.640852f), TypeParam(-0.138969f), TypeParam(-0.695411f), TypeParam(-1.352111f), TypeParam(0.568797f), TypeParam(-0.564294f), TypeParam(-0.056468f), TypeParam(0.641604f), TypeParam(-0.438370f), TypeParam(0.450167f), TypeParam(-1.091401f), TypeParam(1.669729f), TypeParam(-0.908544f), TypeParam(0.244467f), TypeParam(0.172109f), TypeParam(1.156741f), TypeParam(-0.617128f), TypeParam(1.155460f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list 
Grid_data{0.252250f, -0.151452f, 0.824706f, -0.588292f, -0.591147f, -0.155082f, -0.732938f, 0.457493f, -0.439559f, 0.492330f, 0.696447f, 0.700722f, -0.220298f, 0.654884f, -0.635434f, -1.195619f, -0.114204f, -0.870080f, -0.929674f, 0.305035f, 1.025429f, -0.472240f, -0.067881f, -0.869393f}; + std::initializer_list Grid_data{TypeParam(0.252250f), TypeParam(-0.151452f), TypeParam(0.824706f), TypeParam(-0.588292f), TypeParam(-0.591147f), TypeParam(-0.155082f), TypeParam(-0.732938f), TypeParam(0.457493f), TypeParam(-0.439559f), TypeParam(0.492330f), TypeParam(0.696447f), TypeParam(0.700722f), TypeParam(-0.220298f), TypeParam(0.654884f), TypeParam(-0.635434f), TypeParam(-1.195619f), TypeParam(-0.114204f), TypeParam(-0.870080f), TypeParam(-0.929674f), TypeParam(0.305035f), TypeParam(1.025429f), TypeParam(-0.472240f), TypeParam(-0.067881f), TypeParam(-0.869393f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{-1.538390f, -1.565293f, -0.581079f, -0.701030f, -0.725252f, -0.806298f, -0.850602f, -0.281588f, -0.151944f, 0.172138f, 0.177246f, -0.564294f, -0.316822f, -0.056468f, 0.212846f, -0.737167f, 0.585773f, 0.245182f, -0.111277f, -0.908544f, -0.463717f, -0.189009f, 0.510522f, -0.410307f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-1.538390f), TypeParam(-1.565293f), TypeParam(-0.581079f), TypeParam(-0.701030f), TypeParam(-0.725252f), TypeParam(-0.806298f), TypeParam(-0.850602f), TypeParam(-0.281588f), TypeParam(-0.151944f), TypeParam(0.172138f), TypeParam(0.177246f), TypeParam(-0.564294f), TypeParam(-0.316822f), TypeParam(-0.056468f), TypeParam(0.212846f), TypeParam(-0.737167f), TypeParam(0.585773f), TypeParam(0.245182f), TypeParam(-0.111277f), TypeParam(-0.908544f), TypeParam(-0.463717f), TypeParam(-0.189009f), TypeParam(0.510522f), TypeParam(-0.410307f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_5D_bilinear_reflection_no_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_bilinear_reflection_no_align_corners) { OpTester test("GridSample", 20); std::string mode = "linear"; std::string padding_mode = "reflection"; int64_t align_corners = 0; std::initializer_list X_shape{2, 2, 3, 3, 2}; - std::initializer_list X_data{1.179856f, 1.432512f, 1.016210f, -0.661096f, 0.335863f, 0.565957f, -0.517555f, 2.232456f, -0.615173f, -0.073628f, -0.260768f, -1.952025f, 0.304237f, 0.902323f, -0.485170f, 0.781595f, -1.777093f, -0.274107f, -1.030698f, 0.181435f, 1.947646f, 1.007702f, -0.100718f, 0.154090f, -0.483193f, 1.565921f, -0.932274f, 0.313820f, -0.439116f, -0.411861f, -0.821795f, -1.685022f, -0.013518f, 0.519914f, -0.175407f, -0.507962f, 0.050913f, 0.981904f, 1.087165f, 1.758657f, 0.075954f, -0.481552f, 0.085590f, 0.537831f, -0.419622f, -1.756791f, 1.324879f, -0.267061f, -0.683518f, 0.605393f, 0.041004f, -0.756742f, 0.744950f, -0.508619f, -0.594679f, -1.165646f, -0.699604f, -0.271502f, 0.437731f, -2.206233f, 1.088781f, -0.629873f, -0.904741f, -1.233533f, 2.466710f, -0.117309f, -0.684130f, 0.598811f, 0.288846f, -1.195569f, 0.935300f, 0.962852f}; + std::initializer_list X_data{TypeParam(1.179856f), TypeParam(1.432512f), TypeParam(1.016210f), 
TypeParam(-0.661096f), TypeParam(0.335863f), TypeParam(0.565957f), TypeParam(-0.517555f), TypeParam(2.232456f), TypeParam(-0.615173f), TypeParam(-0.073628f), TypeParam(-0.260768f), TypeParam(-1.952025f), TypeParam(0.304237f), TypeParam(0.902323f), TypeParam(-0.485170f), TypeParam(0.781595f), TypeParam(-1.777093f), TypeParam(-0.274107f), TypeParam(-1.030698f), TypeParam(0.181435f), TypeParam(1.947646f), TypeParam(1.007702f), TypeParam(-0.100718f), TypeParam(0.154090f), TypeParam(-0.483193f), TypeParam(1.565921f), TypeParam(-0.932274f), TypeParam(0.313820f), TypeParam(-0.439116f), TypeParam(-0.411861f), TypeParam(-0.821795f), TypeParam(-1.685022f), TypeParam(-0.013518f), TypeParam(0.519914f), TypeParam(-0.175407f), TypeParam(-0.507962f), TypeParam(0.050913f), TypeParam(0.981904f), TypeParam(1.087165f), TypeParam(1.758657f), TypeParam(0.075954f), TypeParam(-0.481552f), TypeParam(0.085590f), TypeParam(0.537831f), TypeParam(-0.419622f), TypeParam(-1.756791f), TypeParam(1.324879f), TypeParam(-0.267061f), TypeParam(-0.683518f), TypeParam(0.605393f), TypeParam(0.041004f), TypeParam(-0.756742f), TypeParam(0.744950f), TypeParam(-0.508619f), TypeParam(-0.594679f), TypeParam(-1.165646f), TypeParam(-0.699604f), TypeParam(-0.271502f), TypeParam(0.437731f), TypeParam(-2.206233f), TypeParam(1.088781f), TypeParam(-0.629873f), TypeParam(-0.904741f), TypeParam(-1.233533f), TypeParam(2.466710f), TypeParam(-0.117309f), TypeParam(-0.684130f), TypeParam(0.598811f), TypeParam(0.288846f), TypeParam(-1.195569f), TypeParam(0.935300f), TypeParam(0.962852f)}; std::initializer_list Grid_shape{2, 3, 3, 2, 3}; - std::initializer_list Grid_data{0.625842f, 0.210304f, -0.725943f, -0.553764f, -0.182412f, -0.296478f, -0.254040f, -0.820211f, 0.869312f, 0.622346f, 0.236815f, 0.271706f, 0.140482f, 0.897281f, 0.271537f, 0.182799f, -0.659653f, 0.400310f, -1.122656f, 0.378466f, -1.040147f, -0.496646f, 0.633526f, -0.714734f, 0.955528f, -0.663024f, 1.136629f, 0.369854f, -0.520025f, 0.731855f, -1.062711f, -0.760189f, -0.751812f, 0.157968f, 0.117892f, -1.032129f, 1.157953f, -0.001147f, -0.640796f, 0.028663f, -0.515104f, 0.331070f, 0.434411f, -0.340393f, 0.069958f, 0.714010f, -0.780518f, -0.267586f, -0.177029f, -0.793935f, 0.097737f, 0.044103f, -0.969274f, 0.246164f, 1.145360f, 0.638273f, -0.650926f, 1.098440f, -0.824873f, -0.610135f, 0.529312f, 0.954650f, 1.145143f, 1.033109f, -0.660775f, 0.274592f, -0.753497f, 0.026500f, 0.994206f, 0.590870f, -1.108049f, -0.516447f, -1.012489f, 0.565286f, -0.152334f, -0.877228f, -0.383453f, 0.393797f, 0.111096f, 1.125969f, -0.015932f, 0.377468f, -0.363512f, 0.143194f, 0.042988f, 1.030777f, 0.502813f, -0.683870f, -1.066269f, -1.141727f, -0.435790f, 0.155118f, 1.128919f, -0.117905f, 0.469189f, 0.609870f, -0.919201f, -0.992659f, 0.454699f, 0.559331f, -0.558762f, 0.188050f, -1.174933f, 0.015126f, 0.294147f, 0.011359f, -0.190476f, 0.499476f}; + std::initializer_list Grid_data{TypeParam(0.625842f), TypeParam(0.210304f), TypeParam(-0.725943f), TypeParam(-0.553764f), TypeParam(-0.182412f), TypeParam(-0.296478f), TypeParam(-0.254040f), TypeParam(-0.820211f), TypeParam(0.869312f), TypeParam(0.622346f), TypeParam(0.236815f), TypeParam(0.271706f), TypeParam(0.140482f), TypeParam(0.897281f), TypeParam(0.271537f), TypeParam(0.182799f), TypeParam(-0.659653f), TypeParam(0.400310f), TypeParam(-1.122656f), TypeParam(0.378466f), TypeParam(-1.040147f), TypeParam(-0.496646f), TypeParam(0.633526f), TypeParam(-0.714734f), TypeParam(0.955528f), TypeParam(-0.663024f), TypeParam(1.136629f), TypeParam(0.369854f), 
TypeParam(-0.520025f), TypeParam(0.731855f), TypeParam(-1.062711f), TypeParam(-0.760189f), TypeParam(-0.751812f), TypeParam(0.157968f), TypeParam(0.117892f), TypeParam(-1.032129f), TypeParam(1.157953f), TypeParam(-0.001147f), TypeParam(-0.640796f), TypeParam(0.028663f), TypeParam(-0.515104f), TypeParam(0.331070f), TypeParam(0.434411f), TypeParam(-0.340393f), TypeParam(0.069958f), TypeParam(0.714010f), TypeParam(-0.780518f), TypeParam(-0.267586f), TypeParam(-0.177029f), TypeParam(-0.793935f), TypeParam(0.097737f), TypeParam(0.044103f), TypeParam(-0.969274f), TypeParam(0.246164f), TypeParam(1.145360f), TypeParam(0.638273f), TypeParam(-0.650926f), TypeParam(1.098440f), TypeParam(-0.824873f), TypeParam(-0.610135f), TypeParam(0.529312f), TypeParam(0.954650f), TypeParam(1.145143f), TypeParam(1.033109f), TypeParam(-0.660775f), TypeParam(0.274592f), TypeParam(-0.753497f), TypeParam(0.026500f), TypeParam(0.994206f), TypeParam(0.590870f), TypeParam(-1.108049f), TypeParam(-0.516447f), TypeParam(-1.012489f), TypeParam(0.565286f), TypeParam(-0.152334f), TypeParam(-0.877228f), TypeParam(-0.383453f), TypeParam(0.393797f), TypeParam(0.111096f), TypeParam(1.125969f), TypeParam(-0.015932f), TypeParam(0.377468f), TypeParam(-0.363512f), TypeParam(0.143194f), TypeParam(0.042988f), TypeParam(1.030777f), TypeParam(0.502813f), TypeParam(-0.683870f), TypeParam(-1.066269f), TypeParam(-1.141727f), TypeParam(-0.435790f), TypeParam(0.155118f), TypeParam(1.128919f), TypeParam(-0.117905f), TypeParam(0.469189f), TypeParam(0.609870f), TypeParam(-0.919201f), TypeParam(-0.992659f), TypeParam(0.454699f), TypeParam(0.559331f), TypeParam(-0.558762f), TypeParam(0.188050f), TypeParam(-1.174933f), TypeParam(0.015126f), TypeParam(0.294147f), TypeParam(0.011359f), TypeParam(-0.190476f), TypeParam(0.499476f)}; std::initializer_list Y_shape{2, 2, 3, 3, 2}; - std::initializer_list Y_data{-0.274014f, 0.145076f, 0.451342f, -0.273219f, -1.128307f, 0.962473f, 0.629978f, 0.370138f, 0.901663f, 0.778787f, 1.179856f, 0.014218f, -0.634683f, 0.585419f, 0.972130f, 1.911376f, 0.389205f, 0.849839f, 0.738424f, 0.054296f, -1.034114f, 0.096287f, -0.408114f, -0.474491f, 0.784791f, 0.001762f, -1.672976f, -1.127656f, -1.030698f, 1.105979f, 0.979492f, -0.258014f, 0.693543f, 1.010218f, -0.008927f, -0.078404f, -0.384825f, 0.944247f, -0.508619f, 0.548774f, 0.068986f, 0.881841f, 0.869967f, -0.274754f, 0.337312f, -0.374188f, 0.161655f, 0.050913f, 0.146763f, 0.119233f, -0.438980f, 0.228062f, -0.187221f, -0.376543f, -2.077576f, -1.120214f, 0.962852f, -0.133462f, 0.314542f, -1.044921f, 1.568017f, -0.060947f, 0.838264f, -0.652863f, 0.978122f, -0.594679f, 0.366536f, 0.596221f, -0.120431f, -0.435362f, -0.328892f, -0.434798f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-0.274014f), TypeParam(0.145076f), TypeParam(0.451342f), TypeParam(-0.273219f), TypeParam(-1.128307f), TypeParam(0.962473f), TypeParam(0.629978f), TypeParam(0.370138f), TypeParam(0.901663f), TypeParam(0.778787f), TypeParam(1.179856f), TypeParam(0.014218f), TypeParam(-0.634683f), TypeParam(0.585419f), TypeParam(0.972130f), TypeParam(1.911376f), TypeParam(0.389205f), TypeParam(0.849839f), TypeParam(0.738424f), TypeParam(0.054296f), TypeParam(-1.034114f), TypeParam(0.096287f), TypeParam(-0.408114f), TypeParam(-0.474491f), TypeParam(0.784791f), TypeParam(0.001762f), TypeParam(-1.672976f), TypeParam(-1.127656f), TypeParam(-1.030698f), TypeParam(1.105979f), TypeParam(0.979492f), TypeParam(-0.258014f), TypeParam(0.693543f), 
TypeParam(1.010218f), TypeParam(-0.008927f), TypeParam(-0.078404f), TypeParam(-0.384825f), TypeParam(0.944247f), TypeParam(-0.508619f), TypeParam(0.548774f), TypeParam(0.068986f), TypeParam(0.881841f), TypeParam(0.869967f), TypeParam(-0.274754f), TypeParam(0.337312f), TypeParam(-0.374188f), TypeParam(0.161655f), TypeParam(0.050913f), TypeParam(0.146763f), TypeParam(0.119233f), TypeParam(-0.438980f), TypeParam(0.228062f), TypeParam(-0.187221f), TypeParam(-0.376543f), TypeParam(-2.077576f), TypeParam(-1.120214f), TypeParam(0.962852f), TypeParam(-0.133462f), TypeParam(0.314542f), TypeParam(-1.044921f), TypeParam(1.568017f), TypeParam(-0.060947f), TypeParam(0.838264f), TypeParam(-0.652863f), TypeParam(0.978122f), TypeParam(-0.594679f), TypeParam(0.366536f), TypeParam(0.596221f), TypeParam(-0.120431f), TypeParam(-0.435362f), TypeParam(-0.328892f), TypeParam(-0.434798f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_4D_bicubic_zeros_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bicubic_zeros_align_corners) { OpTester test("GridSample", 20); std::string mode = "cubic"; std::string padding_mode = "zeros"; int64_t align_corners = 1; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{0.741614f, -1.612838f, 0.274100f, -0.685296f, -0.032079f, -0.246424f, 0.089412f, -0.776545f, -0.152179f, 0.312533f, -1.503701f, -0.720829f, 0.877575f, 0.407229f, -0.889951f, 0.603605f, -0.140859f, 2.032775f, -0.520668f, 1.063163f, -1.008883f, 0.194195f, -0.303240f, -0.967884f}; + std::initializer_list X_data{TypeParam(0.741614f), TypeParam(-1.612838f), TypeParam(0.274100f), TypeParam(-0.685296f), TypeParam(-0.032079f), TypeParam(-0.246424f), TypeParam(0.089412f), TypeParam(-0.776545f), TypeParam(-0.152179f), TypeParam(0.312533f), TypeParam(-1.503701f), TypeParam(-0.720829f), TypeParam(0.877575f), TypeParam(0.407229f), TypeParam(-0.889951f), TypeParam(0.603605f), TypeParam(-0.140859f), TypeParam(2.032775f), TypeParam(-0.520668f), TypeParam(1.063163f), TypeParam(-1.008883f), TypeParam(0.194195f), TypeParam(-0.303240f), TypeParam(-0.967884f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{-0.932019f, -0.034394f, 0.554511f, 0.484230f, 0.141120f, 0.485083f, -0.836516f, 0.999462f, 0.026764f, 0.775689f, 0.265464f, -0.133497f, 0.514005f, 1.139161f, 1.183700f, -1.010095f, 0.072779f, -0.862052f, 0.699178f, 0.861473f, -0.842637f, -0.069355f, 0.830374f, 0.793568f}; + std::initializer_list Grid_data{TypeParam(-0.932019f), TypeParam(-0.034394f), TypeParam(0.554511f), TypeParam(0.484230f), TypeParam(0.141120f), TypeParam(0.485083f), TypeParam(-0.836516f), TypeParam(0.999462f), TypeParam(0.026764f), TypeParam(0.775689f), TypeParam(0.265464f), TypeParam(-0.133497f), TypeParam(0.514005f), TypeParam(1.139161f), TypeParam(1.183700f), TypeParam(-1.010095f), TypeParam(0.072779f), TypeParam(-0.862052f), TypeParam(0.699178f), TypeParam(0.861473f), TypeParam(-0.842637f), TypeParam(-0.069355f), TypeParam(0.830374f), TypeParam(0.793568f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{0.274192f, -0.348792f, -0.238780f, -0.048938f, -0.195915f, -0.488976f, -0.104505f, -0.351103f, -0.583059f, -1.533095f, 
-1.141282f, 0.187052f, 1.668728f, 0.345182f, 0.682750f, 1.893112f, -0.775917f, 1.920082f, -0.889375f, 1.071508f, 0.336517f, -0.933740f, -0.981629f, -0.893789f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(0.274192f), TypeParam(-0.348792f), TypeParam(-0.238780f), TypeParam(-0.048938f), TypeParam(-0.195915f), TypeParam(-0.488976f), TypeParam(-0.104505f), TypeParam(-0.351103f), TypeParam(-0.583059f), TypeParam(-1.533095f), TypeParam(-1.141282f), TypeParam(0.187052f), TypeParam(1.668728f), TypeParam(0.345182f), TypeParam(0.682750f), TypeParam(1.893112f), TypeParam(-0.775917f), TypeParam(1.920082f), TypeParam(-0.889375f), TypeParam(1.071508f), TypeParam(0.336517f), TypeParam(-0.933740f), TypeParam(-0.981629f), TypeParam(-0.893789f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_4D_bicubic_zeros_no_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bicubic_zeros_no_align_corners) { OpTester test("GridSample", 20); std::string mode = "cubic"; std::string padding_mode = "zeros"; int64_t align_corners = 0; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{0.333395f, 0.977190f, 0.214232f, 0.363731f, -1.352515f, -0.980304f, -0.354887f, -0.481711f, -0.607915f, -0.309748f, 2.262781f, 0.963363f, 1.997079f, 0.987449f, -0.537662f, 1.011585f, 0.822184f, 0.567108f, 0.135401f, -0.943315f, -0.614181f, 0.030652f, 0.914757f, 0.971777f}; + std::initializer_list X_data{TypeParam(0.333395f), TypeParam(0.977190f), TypeParam(0.214232f), TypeParam(0.363731f), TypeParam(-1.352515f), TypeParam(-0.980304f), TypeParam(-0.354887f), TypeParam(-0.481711f), TypeParam(-0.607915f), TypeParam(-0.309748f), TypeParam(2.262781f), TypeParam(0.963363f), TypeParam(1.997079f), TypeParam(0.987449f), TypeParam(-0.537662f), TypeParam(1.011585f), TypeParam(0.822184f), TypeParam(0.567108f), TypeParam(0.135401f), TypeParam(-0.943315f), TypeParam(-0.614181f), TypeParam(0.030652f), TypeParam(0.914757f), TypeParam(0.971777f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{-0.487111f, 0.913573f, 0.641905f, -0.093110f, 0.512522f, 0.358369f, 0.655341f, -0.964320f, 0.370929f, -1.136512f, -0.789199f, -0.447185f, -0.116915f, -1.132446f, 0.029865f, 0.191588f, -0.476239f, 0.389224f, 1.048588f, -0.204978f, -0.639094f, -1.062994f, -0.876243f, -0.663705f}; + std::initializer_list Grid_data{TypeParam(-0.487111f), TypeParam(0.913573f), TypeParam(0.641905f), TypeParam(-0.093110f), TypeParam(0.512522f), TypeParam(0.358369f), TypeParam(0.655341f), TypeParam(-0.964320f), TypeParam(0.370929f), TypeParam(-1.136512f), TypeParam(-0.789199f), TypeParam(-0.447185f), TypeParam(-0.116915f), TypeParam(-1.132446f), TypeParam(0.029865f), TypeParam(0.191588f), TypeParam(-0.476239f), TypeParam(0.389224f), TypeParam(1.048588f), TypeParam(-0.204978f), TypeParam(-0.639094f), TypeParam(-1.062994f), TypeParam(-0.876243f), TypeParam(-0.663705f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{-1.051920f, 0.501832f, -0.508839f, 0.563480f, 0.297178f, 0.246571f, 1.781955f, -0.353574f, 0.481200f, -0.258839f, -0.145200f, -0.469558f, 0.624262f, 0.351267f, 0.180256f, 
0.571859f, 0.903895f, 1.383745f, -0.081406f, 0.133665f, 0.348401f, -0.164219f, 0.138237f, 0.203282f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-1.051920f), TypeParam(0.501832f), TypeParam(-0.508839f), TypeParam(0.563480f), TypeParam(0.297178f), TypeParam(0.246571f), TypeParam(1.781955f), TypeParam(-0.353574f), TypeParam(0.481200f), TypeParam(-0.258839f), TypeParam(-0.145200f), TypeParam(-0.469558f), TypeParam(0.624262f), TypeParam(0.351267f), TypeParam(0.180256f), TypeParam(0.571859f), TypeParam(0.903895f), TypeParam(1.383745f), TypeParam(-0.081406f), TypeParam(0.133665f), TypeParam(0.348401f), TypeParam(-0.164219f), TypeParam(0.138237f), TypeParam(0.203282f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_4D_bicubic_border_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bicubic_border_align_corners) { OpTester test("GridSample", 20); std::string mode = "cubic"; std::string padding_mode = "border"; int64_t align_corners = 1; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{-0.480448f, 0.682093f, 0.237716f, -1.234307f, 2.139750f, 2.410321f, 0.491472f, -0.553422f, 0.032129f, -0.162503f, 0.144036f, -1.889875f, -0.293944f, -1.390146f, -1.552136f, 1.604720f, -1.707202f, 0.182427f, -0.631000f, 0.196649f, 0.427711f, -0.014224f, -1.319834f, -2.703346f}; + std::initializer_list X_data{TypeParam(-0.480448f), TypeParam(0.682093f), TypeParam(0.237716f), TypeParam(-1.234307f), TypeParam(2.139750f), TypeParam(2.410321f), TypeParam(0.491472f), TypeParam(-0.553422f), TypeParam(0.032129f), TypeParam(-0.162503f), TypeParam(0.144036f), TypeParam(-1.889875f), TypeParam(-0.293944f), TypeParam(-1.390146f), TypeParam(-1.552136f), TypeParam(1.604720f), TypeParam(-1.707202f), TypeParam(0.182427f), TypeParam(-0.631000f), TypeParam(0.196649f), TypeParam(0.427711f), TypeParam(-0.014224f), TypeParam(-1.319834f), TypeParam(-2.703346f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{0.503717f, 0.572989f, 0.179517f, -0.060398f, 0.503876f, 0.288627f, -1.148268f, 0.194010f, -0.532910f, -0.636357f, 0.464076f, 0.245386f, 0.203212f, -0.569260f, 0.554489f, 1.126118f, 0.146805f, 0.493232f, -1.052794f, 0.713394f, 0.416866f, 0.540634f, 0.500415f, -0.315629f}; + std::initializer_list Grid_data{TypeParam(0.503717f), TypeParam(0.572989f), TypeParam(0.179517f), TypeParam(-0.060398f), TypeParam(0.503876f), TypeParam(0.288627f), TypeParam(-1.148268f), TypeParam(0.194010f), TypeParam(-0.532910f), TypeParam(-0.636357f), TypeParam(0.464076f), TypeParam(0.245386f), TypeParam(0.203212f), TypeParam(-0.569260f), TypeParam(0.554489f), TypeParam(1.126118f), TypeParam(0.146805f), TypeParam(0.493232f), TypeParam(-1.052794f), TypeParam(0.713394f), TypeParam(0.416866f), TypeParam(0.540634f), TypeParam(0.500415f), TypeParam(-0.315629f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{0.885659f, -0.722912f, -0.180469f, 0.697015f, -0.322127f, -0.292851f, -0.867861f, -0.047527f, -0.447720f, 0.028100f, 0.191874f, -0.378776f, -0.321888f, -0.277691f, -0.037604f, -1.766707f, 0.320836f, 0.415106f, 0.179209f, -2.609096f, -0.929794f, 
-0.788240f, -1.212243f, 0.337704f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(0.885659f), TypeParam(-0.722912f), TypeParam(-0.180469f), TypeParam(0.697015f), TypeParam(-0.322127f), TypeParam(-0.292851f), TypeParam(-0.867861f), TypeParam(-0.047527f), TypeParam(-0.447720f), TypeParam(0.028100f), TypeParam(0.191874f), TypeParam(-0.378776f), TypeParam(-0.321888f), TypeParam(-0.277691f), TypeParam(-0.037604f), TypeParam(-1.766707f), TypeParam(0.320836f), TypeParam(0.415106f), TypeParam(0.179209f), TypeParam(-2.609096f), TypeParam(-0.929794f), TypeParam(-0.788240f), TypeParam(-1.212243f), TypeParam(0.337704f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_4D_bicubic_border_no_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bicubic_border_no_align_corners) { OpTester test("GridSample", 20); std::string mode = "cubic"; std::string padding_mode = "border"; int64_t align_corners = 0; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{-0.924256f, -2.309784f, 1.272769f, 0.548427f, -1.478527f, -3.472946f, -1.252325f, 0.268589f, 0.326270f, 0.105016f, 0.515184f, -0.951158f, -0.658693f, -2.018776f, 0.981625f, -0.401504f, 1.560519f, -0.129836f, -1.876357f, 0.511516f, -1.825582f, 0.358958f, -0.805392f, -1.409127f}; + std::initializer_list X_data{TypeParam(-0.924256f), TypeParam(-2.309784f), TypeParam(1.272769f), TypeParam(0.548427f), TypeParam(-1.478527f), TypeParam(-3.472946f), TypeParam(-1.252325f), TypeParam(0.268589f), TypeParam(0.326270f), TypeParam(0.105016f), TypeParam(0.515184f), TypeParam(-0.951158f), TypeParam(-0.658693f), TypeParam(-2.018776f), TypeParam(0.981625f), TypeParam(-0.401504f), TypeParam(1.560519f), TypeParam(-0.129836f), TypeParam(-1.876357f), TypeParam(0.511516f), TypeParam(-1.825582f), TypeParam(0.358958f), TypeParam(-0.805392f), TypeParam(-1.409127f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{0.874856f, -1.090775f, 1.169192f, 0.447098f, 0.583418f, 0.267395f, 0.788144f, 1.129706f, -0.102229f, -0.984624f, 1.101916f, -0.253070f, -0.578731f, 0.738703f, 0.669694f, 0.160659f, -0.075327f, -0.229561f, 1.100291f, 0.731142f, 0.714643f, 0.765214f, -0.628031f, 0.250554f}; + std::initializer_list Grid_data{TypeParam(0.874856f), TypeParam(-1.090775f), TypeParam(1.169192f), TypeParam(0.447098f), TypeParam(0.583418f), TypeParam(0.267395f), TypeParam(0.788144f), TypeParam(1.129706f), TypeParam(-0.102229f), TypeParam(-0.984624f), TypeParam(1.101916f), TypeParam(-0.253070f), TypeParam(-0.578731f), TypeParam(0.738703f), TypeParam(0.669694f), TypeParam(0.160659f), TypeParam(-0.075327f), TypeParam(-0.229561f), TypeParam(1.100291f), TypeParam(0.731142f), TypeParam(0.714643f), TypeParam(0.765214f), TypeParam(-0.628031f), TypeParam(0.250554f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{-2.647128f, -2.154235f, -0.768645f, -3.893546f, -1.698376f, -0.114530f, 0.458115f, -0.696657f, -0.370692f, -1.169692f, -0.754730f, 0.320002f, 1.683550f, -0.301499f, -0.176003f, -0.236653f, -0.278257f, 1.480160f, -0.700350f, 0.095525f, -0.891605f, -1.569065f, -1.633715f, -1.535763f}; - 
test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-2.647128f), TypeParam(-2.154235f), TypeParam(-0.768645f), TypeParam(-3.893546f), TypeParam(-1.698376f), TypeParam(-0.114530f), TypeParam(0.458115f), TypeParam(-0.696657f), TypeParam(-0.370692f), TypeParam(-1.169692f), TypeParam(-0.754730f), TypeParam(0.320002f), TypeParam(1.683550f), TypeParam(-0.301499f), TypeParam(-0.176003f), TypeParam(-0.236653f), TypeParam(-0.278257f), TypeParam(1.480160f), TypeParam(-0.700350f), TypeParam(0.095525f), TypeParam(-0.891605f), TypeParam(-1.569065f), TypeParam(-1.633715f), TypeParam(-1.535763f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_4D_bicubic_reflection_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bicubic_reflection_align_corners) { OpTester test("GridSample", 20); std::string mode = "cubic"; std::string padding_mode = "reflection"; int64_t align_corners = 1; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{-0.328038f, -0.658850f, -0.054298f, 0.012663f, -0.077366f, 0.644305f, -1.262985f, 0.922028f, 0.189962f, 0.518836f, 1.168413f, -0.286220f, 0.431207f, -0.295352f, -0.357675f, -0.311715f, 0.839514f, -0.651820f, -0.283934f, 0.430508f, 0.206334f, 0.765966f, -1.144732f, -0.507045f}; + std::initializer_list X_data{TypeParam(-0.328038f), TypeParam(-0.658850f), TypeParam(-0.054298f), TypeParam(0.012663f), TypeParam(-0.077366f), TypeParam(0.644305f), TypeParam(-1.262985f), TypeParam(0.922028f), TypeParam(0.189962f), TypeParam(0.518836f), TypeParam(1.168413f), TypeParam(-0.286220f), TypeParam(0.431207f), TypeParam(-0.295352f), TypeParam(-0.357675f), TypeParam(-0.311715f), TypeParam(0.839514f), TypeParam(-0.651820f), TypeParam(-0.283934f), TypeParam(0.430508f), TypeParam(0.206334f), TypeParam(0.765966f), TypeParam(-1.144732f), TypeParam(-0.507045f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{-0.372000f, -1.056863f, -0.360826f, -0.268314f, 0.691035f, -0.595044f, 0.720198f, 0.166462f, -0.201118f, -1.069416f, 1.184721f, -0.213980f, 0.755038f, -0.620722f, -1.168597f, -0.956522f, -0.614982f, -0.382162f, -0.169456f, 1.000817f, -1.106710f, 0.598940f, 1.009714f, 0.007723f}; + std::initializer_list Grid_data{TypeParam(-0.372000f), TypeParam(-1.056863f), TypeParam(-0.360826f), TypeParam(-0.268314f), TypeParam(0.691035f), TypeParam(-0.595044f), TypeParam(0.720198f), TypeParam(0.166462f), TypeParam(-0.201118f), TypeParam(-1.069416f), TypeParam(1.184721f), TypeParam(-0.213980f), TypeParam(0.755038f), TypeParam(-0.620722f), TypeParam(-1.168597f), TypeParam(-0.956522f), TypeParam(-0.614982f), TypeParam(-0.382162f), TypeParam(-0.169456f), TypeParam(1.000817f), TypeParam(-1.106710f), TypeParam(0.598940f), TypeParam(1.009714f), TypeParam(0.007723f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{-0.403118f, -0.158055f, -0.496030f, 0.161379f, -0.440603f, -0.193607f, -0.746082f, -0.076433f, 0.751030f, 0.360851f, -0.488453f, 0.664305f, -0.259139f, 0.411796f, -0.156648f, 0.281569f, 0.437515f, -0.313812f, 0.573781f, -0.265706f, 0.200380f, -0.906155f, -0.724311f, 0.760352f}; - test.AddInput("X", X_shape, 
X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-0.403118f), TypeParam(-0.158055f), TypeParam(-0.496030f), TypeParam(0.161379f), TypeParam(-0.440603f), TypeParam(-0.193607f), TypeParam(-0.746082f), TypeParam(-0.076433f), TypeParam(0.751030f), TypeParam(0.360851f), TypeParam(-0.488453f), TypeParam(0.664305f), TypeParam(-0.259139f), TypeParam(0.411796f), TypeParam(-0.156648f), TypeParam(0.281569f), TypeParam(0.437515f), TypeParam(-0.313812f), TypeParam(0.573781f), TypeParam(-0.265706f), TypeParam(0.200380f), TypeParam(-0.906155f), TypeParam(-0.724311f), TypeParam(0.760352f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_4D_bicubic_reflection_no_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bicubic_reflection_no_align_corners) { OpTester test("GridSample", 20); std::string mode = "cubic"; std::string padding_mode = "reflection"; int64_t align_corners = 0; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{-0.290962f, 0.867797f, -0.085436f, -1.597520f, 0.695524f, 0.838739f, 0.513032f, 0.166242f, -0.546135f, -0.780313f, -0.512993f, -0.449479f, 1.594718f, 0.953375f, 0.692587f, -0.798364f, -0.128799f, -0.456210f, 2.098909f, -1.561220f, 1.713821f, -0.701970f, -0.287280f, -1.708048f}; + std::initializer_list X_data{TypeParam(-0.290962f), TypeParam(0.867797f), TypeParam(-0.085436f), TypeParam(-1.597520f), TypeParam(0.695524f), TypeParam(0.838739f), TypeParam(0.513032f), TypeParam(0.166242f), TypeParam(-0.546135f), TypeParam(-0.780313f), TypeParam(-0.512993f), TypeParam(-0.449479f), TypeParam(1.594718f), TypeParam(0.953375f), TypeParam(0.692587f), TypeParam(-0.798364f), TypeParam(-0.128799f), TypeParam(-0.456210f), TypeParam(2.098909f), TypeParam(-1.561220f), TypeParam(1.713821f), TypeParam(-0.701970f), TypeParam(-0.287280f), TypeParam(-1.708048f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{0.934471f, 0.728362f, -0.458301f, -1.040800f, 0.157908f, 0.753451f, -0.122762f, 0.100970f, 0.889432f, 0.495471f, 0.897108f, 0.176205f, 0.134514f, -0.287037f, -0.202498f, -0.637759f, 0.802292f, 1.094459f, 0.445338f, 0.034096f, -0.396126f, -1.184798f, -0.222199f, -0.851887f}; + std::initializer_list Grid_data{TypeParam(0.934471f), TypeParam(0.728362f), TypeParam(-0.458301f), TypeParam(-1.040800f), TypeParam(0.157908f), TypeParam(0.753451f), TypeParam(-0.122762f), TypeParam(0.100970f), TypeParam(0.889432f), TypeParam(0.495471f), TypeParam(0.897108f), TypeParam(0.176205f), TypeParam(0.134514f), TypeParam(-0.287037f), TypeParam(-0.202498f), TypeParam(-0.637759f), TypeParam(0.802292f), TypeParam(1.094459f), TypeParam(0.445338f), TypeParam(0.034096f), TypeParam(-0.396126f), TypeParam(-1.184798f), TypeParam(-0.222199f), TypeParam(-0.851887f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{1.037788f, -0.275160f, 0.953595f, -0.518196f, 0.118127f, -1.525148f, -0.413483f, 0.696689f, -0.450182f, -0.696169f, -0.561886f, -0.828986f, 0.343953f, 1.379632f, -0.417260f, -0.781500f, 1.666511f, 1.599268f, 0.106200f, 1.088396f, -2.079140f, -0.612122f, 1.822402f, 1.173807f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", 
Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(1.037788f), TypeParam(-0.275160f), TypeParam(0.953595f), TypeParam(-0.518196f), TypeParam(0.118127f), TypeParam(-1.525148f), TypeParam(-0.413483f), TypeParam(0.696689f), TypeParam(-0.450182f), TypeParam(-0.696169f), TypeParam(-0.561886f), TypeParam(-0.828986f), TypeParam(0.343953f), TypeParam(1.379632f), TypeParam(-0.417260f), TypeParam(-0.781500f), TypeParam(1.666511f), TypeParam(1.599268f), TypeParam(0.106200f), TypeParam(1.088396f), TypeParam(-2.079140f), TypeParam(-0.612122f), TypeParam(1.822402f), TypeParam(1.173807f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } diff --git a/onnxruntime/test/providers/cpu/tensor/grid_sample_test_gen.py b/onnxruntime/test/providers/cpu/tensor/grid_sample_test_gen.py index c7e263ca3f654..bf58a5d3fc1d5 100644 --- a/onnxruntime/test/providers/cpu/tensor/grid_sample_test_gen.py +++ b/onnxruntime/test/providers/cpu/tensor/grid_sample_test_gen.py @@ -14,6 +14,17 @@ padding_modes = ["zeros", "border", "reflection"] align_corners_options = [True, False] +print( + """ +template +class GridSampleTest : public ::testing::Test { +}; + +using GridSampleTestTypes = ::testing::Types; +TYPED_TEST_SUITE(GridSampleTest, GridSampleTestTypes); + +""" +) # Loop over the combinations of parameters torch.manual_seed(0) for opset_version in [16, 20]: @@ -42,11 +53,15 @@ input_tensor, grid_tensor, mode=mode, padding_mode=padding_mode, align_corners=align_corners ) - X_data_str = "{" + ", ".join([f"{x:.6f}f" for x in input_tensor.numpy().flatten()]) + "}" - Grid_data_str = "{" + ", ".join([f"{x:.6f}f" for x in grid_tensor.numpy().flatten()]) + "}" + X_data_str = "{" + ", ".join([f"TypeParam({x:.6f}f)" for x in input_tensor.numpy().flatten()]) + "}" + Grid_data_str = ( + "{" + ", ".join([f"TypeParam({x:.6f}f)" for x in grid_tensor.numpy().flatten()]) + "}" + ) Y_shape = output_tensor.shape - Y_data_str = "{" + ", ".join([f"{x:.6f}f" for x in output_tensor.numpy().flatten()]) + "}" + Y_data_str = ( + "{" + ", ".join([f"TypeParam({x:.6f}f)" for x in output_tensor.numpy().flatten()]) + "}" + ) onnx_mode = mode if opset_version >= 20: @@ -58,24 +73,25 @@ onnx_align_corners = 1 if align_corners else 0 test_name = f"test_grid_sample_{opset_version}_{ndim}D_{mode}_{padding_mode}_{'align_corners' if align_corners else 'no_align_corners'}" - print(f"TEST(GridSampleTest, {test_name}) {{") - print(f'OpTester test("GridSample", {opset_version});') - print(f'std::string mode = "{onnx_mode}";') - print(f'std::string padding_mode = "{padding_mode}";') - print(f"int64_t align_corners = {onnx_align_corners};") - print(f"std::initializer_list X_shape {{ {', '.join(map(str, input_shape))} }};") - print(f"std::initializer_list X_data { X_data_str };") - print(f"std::initializer_list Grid_shape {{ {', '.join(map(str, grid_shape))} }};") - print(f"std::initializer_list Grid_data { Grid_data_str };") - print(f"std::initializer_list Y_shape {{ {', '.join(map(str, Y_shape))} }};") - print(f"std::initializer_list Y_data { Y_data_str };") - - print('test.AddInput("X", X_shape, X_data);') - print('test.AddInput("Grid", Grid_shape, Grid_data);') - print('test.AddAttribute("mode", mode);') - print('test.AddAttribute("padding_mode", padding_mode);') 
- print('test.AddAttribute("align_corners", align_corners);') - print('test.AddOutput("Y", Y_shape, Y_data);') - print(f"RunTests(test, GetExecutionProviders({opset_version}));") + spaces = " " + print(f"TYPED_TEST(GridSampleTest, {test_name}) {{") + print(f'{spaces}OpTester test("GridSample", {opset_version});') + print(f'{spaces}std::string mode = "{onnx_mode}";') + print(f'{spaces}std::string padding_mode = "{padding_mode}";') + print(f"{spaces}int64_t align_corners = {onnx_align_corners};") + print(f"{spaces}std::initializer_list X_shape {{ {', '.join(map(str, input_shape))} }};") + print(f"{spaces}std::initializer_list X_data { X_data_str };") + print(f"{spaces}std::initializer_list Grid_shape {{ {', '.join(map(str, grid_shape))} }};") + print(f"{spaces}std::initializer_list Grid_data { Grid_data_str };") + print(f"{spaces}std::initializer_list Y_shape {{ {', '.join(map(str, Y_shape))} }};") + print(f"{spaces}std::initializer_list Y_data { Y_data_str };") + + print(f'{spaces}test.AddInput("X", X_shape, X_data);') + print(f'{spaces}test.AddInput("Grid", Grid_shape, Grid_data);') + print(f'{spaces}test.AddAttribute("mode", mode);') + print(f'{spaces}test.AddAttribute("padding_mode", padding_mode);') + print(f'{spaces}test.AddAttribute("align_corners", align_corners);') + print(f'{spaces}test.AddOutput("Y", Y_shape, Y_data);') + print(f"{spaces}RunTests(test, GetExecutionProviders({opset_version}));") print("}") print("\n") diff --git a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc index 111520ef03e26..fc95764345710 100644 --- a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc @@ -10,6 +10,13 @@ namespace onnxruntime { namespace test { +template +class ResizeOpTest : public ::testing::Test { +}; + +using ResizeOpTestTypes = ::testing::Types; +TYPED_TEST_SUITE(ResizeOpTest, ResizeOpTestTypes); + TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_tf_crop_and_resize) { // TODO: Unskip when fixed #41968513 if (DefaultDmlExecutionProvider().get() != nullptr) { @@ -226,26 +233,26 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_without_e test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kDmlExecutionProvider}); } -TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_4DBilinear) { +TYPED_TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_4DBilinear) { auto run_test = [](bool scales_in_initializer) { OpTester test("Resize", 13); - std::vector roi{}; - std::vector scales{1.0f, 1.0f, 0.6f, 0.6f}; + std::vector roi{}; + std::vector scales{(1.0f), (1.0f), (0.6f), (0.6f)}; test.AddAttribute("mode", "linear"); constexpr int64_t N = 1, C = 1, H = 2, W = 4; - std::vector X = { - 1.0f, 2.0f, 3.0f, 4.0f, - 5.0f, 6.0f, 7.0f, 8.0f}; + std::vector X = { + TypeParam(1.0f), TypeParam(2.0f), TypeParam(3.0f), TypeParam(4.0f), + TypeParam(5.0f), TypeParam(6.0f), TypeParam(7.0f), TypeParam(8.0f)}; - test.AddInput("X", {N, C, H, W}, X); - test.AddInput("roi", {0}, roi); + test.AddInput("X", {N, C, H, W}, X); + test.AddInput("roi", {0}, roi); test.AddInput("scales", {4}, scales, scales_in_initializer); - std::vector Y = {2.66666651f, 4.3333331f}; + std::vector Y = {TypeParam(2.66666651f), TypeParam(4.3333331f)}; - test.AddOutput("Y", {N, C, static_cast(H * scales[2]), static_cast(W * scales[3])}, Y); + test.AddOutput("Y", {N, C, static_cast(H * scales[2]), static_cast(W * scales[3])}, Y); // QNN: result diff // TRT: Segmentation fault in A100 std::unordered_set 
excluded_providers({kQnnExecutionProvider}); From 753efb69f09c1de4c474766ea5f9c438b39bb0fe Mon Sep 17 00:00:00 2001 From: wejoncy Date: Fri, 27 Sep 2024 00:17:41 -0700 Subject: [PATCH 26/39] d --- .../core/providers/coreml/builders/impl/clip_op_builder.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/coreml/builders/impl/clip_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/clip_op_builder.cc index 53a0bd405e8fc..f9d9cf92db423 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/clip_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/clip_op_builder.cc @@ -60,8 +60,6 @@ Status ClipOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const auto& output_name = output.Name(); float min, max; ORT_RETURN_IF_NOT(GetClipMinMax(model_builder.GetGraphViewer(), node, min, max, logger), "GetClipMinMax failed"); - // we already checked it and dtype must be existed. - auto input_dtype = node.InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); bool has_min = min != std::numeric_limits::lowest(); bool has_max = max != std::numeric_limits::max(); @@ -94,6 +92,8 @@ Status ClipOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, Operation& clip_op = *op; AddOperationInput(clip_op, "x", input_name); + // we already checked it and dtype must be existed. + auto input_dtype = node.InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); // if min and max were attributes we need to add initializers. otherwise we use the existing inputs const bool min_max_attribs = node.SinceVersion() < 11; std::string_view min_name; From ea70f1cdb4805a12fd9ddd57ec89311661fb5512 Mon Sep 17 00:00:00 2001 From: wejoncy Date: Fri, 27 Sep 2024 01:36:27 -0700 Subject: [PATCH 27/39] more uts --- .../cpu/activation/activation_op_test.cc | 2 +- .../test/providers/cpu/nn/conv_fp16_test.cc | 8 +-- .../providers/cpu/nn/pool_fp16_op_test.cc | 4 +- .../test/providers/cpu/nn/pool_op_test.cc | 46 ++++++++++++--- .../cpu/tensor/space_depth_ops_test.cc | 59 +++++++++++++++---- 5 files changed, 93 insertions(+), 26 deletions(-) diff --git a/onnxruntime/test/providers/cpu/activation/activation_op_test.cc b/onnxruntime/test/providers/cpu/activation/activation_op_test.cc index 395f17ad59a9d..724118d7419d2 100644 --- a/onnxruntime/test/providers/cpu/activation/activation_op_test.cc +++ b/onnxruntime/test/providers/cpu/activation/activation_op_test.cc @@ -125,7 +125,7 @@ TEST_F(ActivationOpTest, Relu) { {}, {}, /*is_tensorrt_supported=*/false, /*opset_version= */ 14); -#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || defined(USE_QNN) || defined(COREML_ENABLE_MLPROGRAM) +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || defined(COREML_ENABLE_MLPROGRAM) TestActivationOp( "Relu", input_values_fp16, diff --git a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc index 285f9ad05fef5..911e346772e82 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc @@ -3,7 +3,7 @@ #include "core/mlas/inc/mlas.h" -#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || defined(COREML_ENABLE_MLPROGRAM) || defined(USE_QNN) +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || defined(COREML_ENABLE_MLPROGRAM) #include "gtest/gtest.h" #include "test/providers/provider_test_utils.h" @@ -30,7 +30,7 @@ struct ConvOpAndTestAttributes { /* Please notice that, we have predefined macros in the head of the file -#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || 
defined(COREML_ENABLE_MLPROGRAM)|| defined(USE_QNN). +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || defined(COREML_ENABLE_MLPROGRAM) When we have these two macro defines, this UT will turn into green light and work. `NhwcFusedConv` in FP16 dtype is a contribe op and not well support by basic CPU ep. @@ -93,9 +93,7 @@ void TestConvFp16Op(const ConvOpAndTestAttributes& attributes, // Disable TensorRT because weight as input is not supported excluded_providers.insert(kTensorrtExecutionProvider); // QNN have issue with dynamic weight, auto pad with SAME_UPPER, SAME_LOWER - if (!weight_is_initializer || attributes.auto_pad == "SAME_UPPER" || - attributes.auto_pad == "SAME_LOWER" || - !attributes.activation.empty()) { + if (!weight_is_initializer || attributes.auto_pad == "SAME_UPPER" || attributes.auto_pad == "SAME_LOWER") { excluded_providers.insert(kQnnExecutionProvider); } if (!weight_is_initializer || !attributes.activation.empty()) { diff --git a/onnxruntime/test/providers/cpu/nn/pool_fp16_op_test.cc b/onnxruntime/test/providers/cpu/nn/pool_fp16_op_test.cc index b033ddbca23d6..7d736d41e804b 100644 --- a/onnxruntime/test/providers/cpu/nn/pool_fp16_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/pool_fp16_op_test.cc @@ -3,7 +3,7 @@ #include "core/mlas/inc/mlas.h" -#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || defined(COREML_ENABLE_MLPROGRAM) #include "core/providers/cpu/nn/pool.h" #include "gtest/gtest.h" @@ -567,4 +567,4 @@ TEST(PoolFp16Test, GlobalAveragePool) { } // namespace test } // namespace onnxruntime -#endif \ No newline at end of file +#endif diff --git a/onnxruntime/test/providers/cpu/nn/pool_op_test.cc b/onnxruntime/test/providers/cpu/nn/pool_op_test.cc index 885fb11c6e999..12f8e1694b29f 100644 --- a/onnxruntime/test/providers/cpu/nn/pool_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/pool_op_test.cc @@ -9,6 +9,14 @@ using namespace std; namespace onnxruntime { namespace test { + +template +class PoolTest : public ::testing::Test { +}; + +using PoolTestTypes = ::testing::Types; +TYPED_TEST_SUITE(PoolTest, PoolTestTypes); + // Disable TensorRT on some of the tests because "pads" attribute is not supported TEST(PoolTest, MaxPool) { @@ -63,13 +71,15 @@ TEST(PoolTest, MaxPool) { // Only CUDA kernel has float 16 support // Disable for now, still investigating the issue with cudnn lib -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(COREML_ENABLE_MLPROGRAM) TEST(PoolTest, MaxPool_F16) { +#if defined(USE_CUDA) int min_cuda_architecture = 530; if (!HasCudaEnvironment(min_cuda_architecture)) { LOGS_DEFAULT(WARNING) << "Hardware NOT support FP16"; return; } +#endif OpTester test("MaxPool"); test.AddAttribute("auto_pad", ""); @@ -672,7 +682,7 @@ TEST(PoolTest, MaxPool_10_DilationPadding_3d) { {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider}); } -TEST(PoolTest, GlobalMaxPool) { +TYPED_TEST(PoolTest, GlobalMaxPool) { OpTester test("GlobalMaxPool"); std::vector x_vals = {0.19151945412158966, 0.6221087574958801, 0.43772774934768677, @@ -743,12 +753,23 @@ TEST(PoolTest, GlobalMaxPool) { std::vector expected_dims = {1, 3, 1, 1}; std::vector expected_vals = {0.9920814633369446, 0.9820047616958618, 0.9946538209915161}; - test.AddInput("X", x_dims, x_vals); - test.AddOutput("Y", expected_dims, expected_vals); + if constexpr (std::is_same::value) { + test.AddInput("X", x_dims, x_vals); + test.AddOutput("Y", expected_dims, expected_vals); + } else { + std::vector x_vals_fp16(x_vals.size()); 
+ std::vector expected_vals_fp16(expected_vals.size()); + + ConvertFloatToMLFloat16(x_vals.data(), x_vals_fp16.data(), x_vals.size()); + ConvertFloatToMLFloat16(expected_vals.data(), expected_vals_fp16.data(), expected_vals.size()); + test.AddInput("X", x_dims, x_vals_fp16); + test.AddOutput("Y", expected_dims, expected_vals_fp16); + } + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}); } -TEST(PoolTest, GlobalMaxPool3D) { +TYPED_TEST(PoolTest, GlobalMaxPool3D) { OpTester test("GlobalMaxPool"); std::vector x_vals = {0.19151945412158966, 0.6221087574958801, 0.43772774934768677, @@ -819,8 +840,19 @@ TEST(PoolTest, GlobalMaxPool3D) { std::vector expected_dims = {1, 3, 1, 1, 1}; std::vector expected_vals = {0.9920814633369446, 0.9820047616958618, 0.9946538209915161}; - test.AddInput("X", x_dims, x_vals); - test.AddOutput("Y", expected_dims, expected_vals); + if constexpr (std::is_same::value) { + test.AddInput("X", x_dims, x_vals); + test.AddOutput("Y", expected_dims, expected_vals); + } else { + std::vector x_vals_fp16(x_vals.size()); + std::vector expected_vals_fp16(expected_vals.size()); + + ConvertFloatToMLFloat16(x_vals.data(), x_vals_fp16.data(), x_vals.size()); + ConvertFloatToMLFloat16(expected_vals.data(), expected_vals_fp16.data(), expected_vals.size()); + test.AddInput("X", x_dims, x_vals_fp16); + test.AddOutput("Y", expected_dims, expected_vals_fp16); + } + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } diff --git a/onnxruntime/test/providers/cpu/tensor/space_depth_ops_test.cc b/onnxruntime/test/providers/cpu/tensor/space_depth_ops_test.cc index a0c1d675f506f..effb3e19933f3 100644 --- a/onnxruntime/test/providers/cpu/tensor/space_depth_ops_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/space_depth_ops_test.cc @@ -4,10 +4,18 @@ #include "gtest/gtest.h" #include "test/providers/provider_test_utils.h" #include "core/providers/cpu/tensor/space_depth_ops.h" +#include "core/mlas/inc/mlas.h" namespace onnxruntime { namespace test { +template +class TensorOpTest : public ::testing::Test { +}; + +using TensorOpTestTypes = ::testing::Types; +TYPED_TEST_SUITE(TensorOpTest, TensorOpTestTypes); + TEST(TensorOpTest, SpaceToDepthTest_1) { OpTester test("SpaceToDepth"); constexpr int64_t blocksize = 2; @@ -259,7 +267,7 @@ TEST(TensorOpTest, DepthToSpaceTest_2) { test.Run(); } -TEST(TensorOpTest, DepthToSpaceTest_3) { +TYPED_TEST(TensorOpTest, DepthToSpaceTest_3) { OpTester test("DepthToSpace", 11); // create an opset 11 model with missing default attribute constexpr int64_t blocksize = 2; test.AddAttribute("blocksize", blocksize); @@ -281,8 +289,6 @@ TEST(TensorOpTest, DepthToSpaceTest_3) { 132., 133., 134., 135., 136., 137., 138., 139., 140., 141., 142., 143.}; - test.AddInput("input", {N, C, H, W}, X); - const std::vector result = { 0., 18., 1., 19., 36., 54., 37., 55., 2., 20., 3., 21., 38., 56., 39., 57., 4., 22., 5., 23., 40., 58., @@ -298,11 +304,22 @@ TEST(TensorOpTest, DepthToSpaceTest_3) { 102., 85., 103., 120., 138., 121., 139., 86., 104., 87., 105., 122., 140., 123., 141., 88., 106., 89., 107., 124., 142., 125., 143.}; - test.AddOutput("output", {2, 3, 6, 4}, result); + + if constexpr (std::is_same::value) { + test.AddInput("input", {N, C, H, W}, X); + test.AddOutput("output", {2, 3, 6, 4}, result); + } else { + std::vector X_fp16(X.size()); + std::vector result_fp16(result.size()); + ConvertFloatToMLFloat16(result.data(), result_fp16.data(), result.size()); + ConvertFloatToMLFloat16(X.data(), X_fp16.data(), X.size()); + 
test.AddOutput("output", {2, 3, 6, 4}, result_fp16); + test.AddInput("input", {N, C, H, W}, X_fp16); + } test.Run(); } -TEST(TensorOpTest, DepthToSpaceTest_4) { +TYPED_TEST(TensorOpTest, DepthToSpaceTest_4) { OpTester test("DepthToSpace", 11); // create an opset 11 model with attribute present = "DCR" mode constexpr int64_t blocksize = 2; test.AddAttribute("blocksize", blocksize); @@ -325,7 +342,6 @@ TEST(TensorOpTest, DepthToSpaceTest_4) { 132., 133., 134., 135., 136., 137., 138., 139., 140., 141., 142., 143.}; - test.AddInput("input", {N, C, H, W}, X); const std::vector result = { 0., 18., 1., 19., 36., 54., 37., 55., 2., 20., 3., @@ -342,11 +358,23 @@ TEST(TensorOpTest, DepthToSpaceTest_4) { 102., 85., 103., 120., 138., 121., 139., 86., 104., 87., 105., 122., 140., 123., 141., 88., 106., 89., 107., 124., 142., 125., 143.}; - test.AddOutput("output", {2, 3, 6, 4}, result); + + if constexpr (std::is_same::value) { + test.AddInput("input", {N, C, H, W}, X); + test.AddOutput("output", {2, 3, 6, 4}, result); + } else { + std::vector X_fp16(X.size()); + std::vector result_fp16(result.size()); + ConvertFloatToMLFloat16(X.data(), X_fp16.data(), X.size()); + ConvertFloatToMLFloat16(result.data(), result_fp16.data(), result.size()); + test.AddInput("input", {N, C, H, W}, X_fp16); + test.AddOutput("output", {2, 3, 6, 4}, result_fp16); + } + test.Run(); } -TEST(TensorOpTest, DepthToSpaceTest_5) { +TYPED_TEST(TensorOpTest, DepthToSpaceTest_5) { OpTester test("DepthToSpace", 11); // create an opset 11 model with attribute present = "CRD" mode constexpr int64_t blocksize = 2; test.AddAttribute("blocksize", blocksize); @@ -362,14 +390,23 @@ TEST(TensorOpTest, DepthToSpaceTest_5) { 27., 28., 29., 30., 31., 32.}; - test.AddInput("input", {N, C, H, W}, X); - const std::vector result = {0., 9., 1., 10., 2., 11., 18., 27., 19., 28., 20., 29., 3., 12., 4., 13., 5., 14., 21., 30., 22., 31., 23., 32.}; - test.AddOutput("output", {1, 1, 4, 6}, result); + + if constexpr (std::is_same::value) { + test.AddInput("input", {N, C, H, W}, X); + test.AddOutput("output", {1, 1, 4, 6}, result); + } else { + std::vector X_fp16(X.size()); + std::vector result_fp16(result.size()); + ConvertFloatToMLFloat16(X.data(), X_fp16.data(), X.size()); + ConvertFloatToMLFloat16(result.data(), result_fp16.data(), result.size()); + test.AddInput("input", {N, C, H, W}, X_fp16); + test.AddOutput("output", {1, 1, 4, 6}, result_fp16); + } test.Run(); } From 46090adacc0a1443e60ba6c83a155722569e3517 Mon Sep 17 00:00:00 2001 From: wejoncy Date: Fri, 27 Sep 2024 02:25:12 -0700 Subject: [PATCH 28/39] address comments --- .../coreml/builders/impl/binary_op_builder.cc | 17 +++++++++++-- .../coreml/builders/impl/builder_utils.cc | 4 ++-- .../coreml/builders/impl/clip_op_builder.cc | 2 +- .../coreml/builders/impl/gemm_op_builder.cc | 6 ++--- .../core/providers/coreml/model/model.mm | 10 ++++---- .../cpu/math/element_wise_ops_test.cc | 4 +++- .../test/providers/cpu/math/gemm_test.cc | 24 +++++++++++++------ .../test/providers/cpu/nn/conv_fp16_test.cc | 10 ++++---- .../test/providers/cpu/nn/pool_op_test.cc | 1 - .../providers/cpu/tensor/resize_op_test.cc | 2 +- .../cpu/tensor/space_depth_ops_test.cc | 2 -- 11 files changed, 51 insertions(+), 31 deletions(-) diff --git a/onnxruntime/core/providers/coreml/builders/impl/binary_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/binary_op_builder.cc index bc1eed8c1920a..8aa2dbae2531c 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/binary_op_builder.cc +++ 
b/onnxruntime/core/providers/coreml/builders/impl/binary_op_builder.cc @@ -139,8 +139,21 @@ bool BinaryOpBuilder::HasSupportedInputsImpl(const Node& node, const OpBuilderIn // Add/Sub/Mul/Div spec says inputs must be of the same type. // Pow spec says inputs can be different types. // We support float/float16 for all of these inputs. - if (!IsInputDtypeSupport(node, 0, input_params, logger) || - ((node.OpType() == "Pow") && !IsInputDtypeSupport(node, 1, input_params, logger))) { + + if (node.OpType() == "Pow") { + const auto& input0 = *node.InputDefs()[0]; + const auto& input1 = *node.InputDefs()[1]; + int32_t input_type0 = ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED; + int32_t input_type1 = ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED; + if (!GetType(input0, input_type0, logger)) { + return false; + } + if (!GetType(input1, input_type1, logger) || input_type1 != input_type0) { + return false; + } + } + + if (!IsInputDtypeSupport(node, 0, input_params, logger)) { return false; } diff --git a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc index d053fc5b9496d..6f9bb35c27d80 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc @@ -118,7 +118,7 @@ void CreateCoreMLWeight(CoreML::Specification::WeightParams& weight, gsl::span data) { - const char* data_byte_ptr = (const char*)(data.data()); + const char* data_byte_ptr = reinterpret_cast(data.data()); weight.mutable_float16value()->assign(data_byte_ptr, data_byte_ptr + data.size_bytes()); } @@ -137,7 +137,7 @@ void CreateCoreMLWeightConvertingDataToFloat16s(CoreML::Specification::WeightPar std::vector weight_float16s{}; weight_float16s.reserve(data.size()); std::transform(data.begin(), data.end(), std::back_inserter(weight_float16s), - [](T v) { return MLFloat16(narrow(v)); }); + [](T v) { return MLFloat16(float(v)); }); CreateCoreMLWeight(weight, weight_float16s); } } // namespace diff --git a/onnxruntime/core/providers/coreml/builders/impl/clip_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/clip_op_builder.cc index f9d9cf92db423..bc9e2f10296ed 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/clip_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/clip_op_builder.cc @@ -120,7 +120,7 @@ Status ClipOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, } } } - std::cout << "3444444444444444444444444444444444444444444\n\n"; + AddOperationOutput(*op, output); model_builder.AddOperation(std::move(op)); } else diff --git a/onnxruntime/core/providers/coreml/builders/impl/gemm_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/gemm_op_builder.cc index 71a4fe9b12035..e685c09ef43ca 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/gemm_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/gemm_op_builder.cc @@ -141,12 +141,12 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N std::vector weight_nk_shape = {N, K}; // transpose from {K, N} to {N, K} if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { - std::vector weight_nk; // use bytes to store the type-erased data, could be any data-type + std::vector weight_nk; ORT_RETURN_IF_ERROR(GetTensorDataTransposed(*b_initializer, weight_nk)); AddOperationInput(*gemm_op, "weight", model_builder.AddConstant(gemm_op->type(), b.Name() + "_t", weight_nk, weight_nk_shape)); - } else { // 
TensorProto_DataType_FLOAT16 - std::vector weight_nk; // use bytes to store the type-erased data, could be any data-type + } else { // TensorProto_DataType_FLOAT16 + std::vector weight_nk; ORT_RETURN_IF_ERROR(GetTensorDataTransposed(*b_initializer, weight_nk)); AddOperationInput(*gemm_op, "weight", model_builder.AddConstant(gemm_op->type(), b.Name() + "_t", weight_nk, weight_nk_shape)); diff --git a/onnxruntime/core/providers/coreml/model/model.mm b/onnxruntime/core/providers/coreml/model/model.mm index 97e157d738371..1401cbe95fd56 100644 --- a/onnxruntime/core/providers/coreml/model/model.mm +++ b/onnxruntime/core/providers/coreml/model/model.mm @@ -179,8 +179,8 @@ Status CreateInputFeatureProvider(const std::unordered_map -void StrideCopy(const T* src_buffer, T* dst_buffer, size_t block_size, - size_t num_blocks, size_t src_stride, size_t dst_stride) { +void StridedCopy(const T* src_buffer, T* dst_buffer, size_t block_size, + size_t num_blocks, size_t src_stride, size_t dst_stride) { for (size_t idx = 0; idx < num_blocks; ++idx) { std::copy_n(src_buffer, block_size, dst_buffer); src_buffer += src_stride; @@ -210,21 +210,21 @@ Status CopyMLMultiArrayBuffer(const void* mlmultiarray_buffer, void* tensor_buff case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: { const auto* src_buffer = static_cast(mlmultiarray_buffer); auto* dst_buffer = static_cast(tensor_buffer); - StrideCopy(src_buffer, dst_buffer, block_size, num_blocks, stride, block_size); + StridedCopy(src_buffer, dst_buffer, block_size, num_blocks, stride, block_size); break; } case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: { const auto* src_buffer = static_cast(mlmultiarray_buffer); auto* dst_buffer = static_cast(tensor_buffer); - StrideCopy(src_buffer, dst_buffer, block_size, num_blocks, stride, block_size); + StridedCopy(src_buffer, dst_buffer, block_size, num_blocks, stride, block_size); break; } case ONNX_NAMESPACE::TensorProto_DataType_INT32: { const auto* src_buffer = static_cast(mlmultiarray_buffer); auto* dst_buffer = static_cast(tensor_buffer); - StrideCopy(src_buffer, dst_buffer, block_size, num_blocks, stride, block_size); + StridedCopy(src_buffer, dst_buffer, block_size, num_blocks, stride, block_size); break; } diff --git a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc index ffb7d92a794d4..50f324f620a53 100644 --- a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc +++ b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc @@ -1099,7 +1099,7 @@ TEST(MathOpTest, Pow_float16_float16) { dims, {1.0f, 256.0f, 2.0f, 1.0f}, false); } -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(COREML_ENABLE_MLPROGRAM) TEST(MathOpTest, Pow_float_float16) { OpTester test("Pow", 12); std::vector dims{4}; @@ -1113,6 +1113,8 @@ TEST(MathOpTest, Pow_float_float16) { execution_providers.push_back(DefaultCudaExecutionProvider()); #elif USE_ROCM execution_providers.push_back(DefaultRocmExecutionProvider()); +#elif COREML_ENABLE_MLPROGRAM + execution_providers.push_back(DefaultCoreMLExecutionProvider(true)); #endif test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } diff --git a/onnxruntime/test/providers/cpu/math/gemm_test.cc b/onnxruntime/test/providers/cpu/math/gemm_test.cc index c21e353ca2fbb..42b728c8154b5 100644 --- a/onnxruntime/test/providers/cpu/math/gemm_test.cc +++ b/onnxruntime/test/providers/cpu/math/gemm_test.cc @@ -37,8 +37,11 @@ TEST(GemmOpTest, 
GemmNoTrans_f16) { std::vector A{1.0f, 2.0f, 3.0f, 4.0f, -1.0f, -2.0f, -3.0f, -4.0f}; - std::vector B = {0.5f, 2.1f, 1.2f, -0.3f, -1.2f, 0.2f, 1.0f, -2.1f, 1.3f, 4.1f, 1.3f, -8.1f}; - std::vector C = {0.5f, 2.1f, 1.2f, -0.3f, -1.2f, 0.2f}; + std::vector B = {0.5f, 2.1f, 1.2f, -0.3f, + -1.2f, 0.2f, 1.0f, -2.1f, + 1.3f, 4.1f, 1.3f, -8.1f}; + std::vector C = {0.5f, 2.1f, 1.2f, + -0.3f, -1.2f, 0.2f}; std::vector f_A(8); std::vector f_B(12); @@ -48,7 +51,8 @@ TEST(GemmOpTest, GemmNoTrans_f16) { { // bias has same shape as output std::vector f_Y(6); - std::vector Y{19.8f, 0.7f, -25.7f, -19.6f, 0.2f, 27.1f}; + std::vector Y{19.8f, 0.7f, -25.7f, + -19.6f, 0.2f, 27.1f}; ConvertFloatToMLFloat16(Y.data(), f_Y.data(), 6); std::vector f_C(6); @@ -64,6 +68,7 @@ TEST(GemmOpTest, GemmNoTrans_f16) { test.AddInput("B", {4, 3}, f_B); test.AddInput("C", {2, 3}, f_C); test.AddOutput("Y", {2, 3}, f_Y); + // we used float data with decimal instead of only integer, increase Tolerance to make test pass test.SetOutputTolerance(0.005f); test.ConfigExcludeEps({kTensorrtExecutionProvider}) // TensorRT: fp16 is not supported .Config(run_with_tunable_op) @@ -72,7 +77,8 @@ TEST(GemmOpTest, GemmNoTrans_f16) { { // bias has shape {1, output_features} std::vector f_Y(6); - std::vector Y{19.8f, 0.7f, -25.7f, -18.8f, 3.5f, 28.1f}; + std::vector Y{19.8f, 0.7f, -25.7f, + -18.8f, 3.5f, 28.1f}; ConvertFloatToMLFloat16(Y.data(), f_Y.data(), 6); std::vector f_C(3); @@ -96,7 +102,8 @@ TEST(GemmOpTest, GemmNoTrans_f16) { { // bias is a scalar std::vector f_Y(6); - std::vector Y{19.8f, -0.9f, -26.4f, -18.8f, 1.9f, 27.4f}; + std::vector Y{19.8f, -0.9f, -26.4f, + -18.8f, 1.9f, 27.4f}; ConvertFloatToMLFloat16(Y.data(), f_Y.data(), 6); std::vector f_C(1); @@ -130,8 +137,11 @@ TEST(GemmOpTest, GemmTransB_f16) { std::vector A{1.0f, 2.0f, 3.0f, 4.0f, -1.0f, -2.0f, -3.0f, -4.0f}; - std::vector B = {0.5f, 2.1f, 1.2f, -0.3f, -1.2f, 0.2f, 1.0f, -2.1f, 1.3f, 4.1f, 1.3f, -8.1f}; - std::vector C = {0.5f, 2.1f, 1.2f, -0.3f, -1.2f, 0.2f}; + std::vector B = {0.5f, 2.1f, 1.2f, -0.3f, + -1.2f, 0.2f, 1.0f, -2.1f, + 1.3f, 4.1f, 1.3f, -8.1f}; + std::vector C = {0.5f, 2.1f, 1.2f, + -0.3f, -1.2f, 0.2f}; std::vector f_A(8); std::vector f_B(12); diff --git a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc index 911e346772e82..ce1ac7591ec34 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc @@ -33,11 +33,9 @@ Please notice that, we have predefined macros in the head of the file #if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || defined(COREML_ENABLE_MLPROGRAM) When we have these two macro defines, this UT will turn into green light and work. -`NhwcFusedConv` in FP16 dtype is a contribe op and not well support by basic CPU ep. -Once your EP can't satisfy all the conditions and capture the op, UT will crash as there -is no appropriate ep can handle this node. -As What CoreML did, if attributes has activation fused in, we should exclude CoreML ep -to let the test pass. +If attributes.activation is set the NhwcFusedConv contrib op is used. +If you are adding support for a new EP to the test and the EP does not support NhwcFusedConv +please add the EP to the excluded_providers list. 
*/ void TestConvFp16Op(const ConvOpAndTestAttributes& attributes, const vector>& inputs, @@ -92,7 +90,7 @@ void TestConvFp16Op(const ConvOpAndTestAttributes& attributes, std::unordered_set excluded_providers(attributes.excluded_providers); // Disable TensorRT because weight as input is not supported excluded_providers.insert(kTensorrtExecutionProvider); - // QNN have issue with dynamic weight, auto pad with SAME_UPPER, SAME_LOWER + // QNN has issue with dynamic weight, auto pad with SAME_UPPER, SAME_LOWER if (!weight_is_initializer || attributes.auto_pad == "SAME_UPPER" || attributes.auto_pad == "SAME_LOWER") { excluded_providers.insert(kQnnExecutionProvider); } diff --git a/onnxruntime/test/providers/cpu/nn/pool_op_test.cc b/onnxruntime/test/providers/cpu/nn/pool_op_test.cc index 12f8e1694b29f..a340f975ec91a 100644 --- a/onnxruntime/test/providers/cpu/nn/pool_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/pool_op_test.cc @@ -9,7 +9,6 @@ using namespace std; namespace onnxruntime { namespace test { - template class PoolTest : public ::testing::Test { }; diff --git a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc index fc95764345710..84fb6157b8884 100644 --- a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc @@ -237,7 +237,7 @@ TYPED_TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_4DBilinear) { auto run_test = [](bool scales_in_initializer) { OpTester test("Resize", 13); std::vector roi{}; - std::vector scales{(1.0f), (1.0f), (0.6f), (0.6f)}; + std::vector scales{1.0f, 1.0f, 0.6f, 0.6f}; test.AddAttribute("mode", "linear"); diff --git a/onnxruntime/test/providers/cpu/tensor/space_depth_ops_test.cc b/onnxruntime/test/providers/cpu/tensor/space_depth_ops_test.cc index effb3e19933f3..3841641264102 100644 --- a/onnxruntime/test/providers/cpu/tensor/space_depth_ops_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/space_depth_ops_test.cc @@ -342,7 +342,6 @@ TYPED_TEST(TensorOpTest, DepthToSpaceTest_4) { 132., 133., 134., 135., 136., 137., 138., 139., 140., 141., 142., 143.}; - const std::vector result = { 0., 18., 1., 19., 36., 54., 37., 55., 2., 20., 3., 21., 38., 56., 39., 57., 4., 22., 5., 23., 40., 58., @@ -395,7 +394,6 @@ TYPED_TEST(TensorOpTest, DepthToSpaceTest_5) { 3., 12., 4., 13., 5., 14., 21., 30., 22., 31., 23., 32.}; - if constexpr (std::is_same::value) { test.AddInput("input", {N, C, H, W}, X); test.AddOutput("output", {1, 1, 4, 6}, result); From a66f7b577775b5407e2b58d29893a631af0523d9 Mon Sep 17 00:00:00 2001 From: wejoncy Date: Fri, 27 Sep 2024 03:57:23 -0700 Subject: [PATCH 29/39] fix --- .../providers/coreml/builders/impl/depthtospace_op_builder.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/coreml/builders/impl/depthtospace_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/depthtospace_op_builder.cc index bec2461ffbc52..ddaa19c7fab18 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/depthtospace_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/depthtospace_op_builder.cc @@ -67,7 +67,9 @@ Status DepthToSpaceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, // we checked shape was static in IsOpSupportedImpl so this should never fail std::vector input_shape; ORT_RETURN_IF_NOT(GetStaticShape(*input_defs[0], input_shape, logger), "Failed to get input shape"); - const int32_t elem_type = 
static_cast(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + auto input_dtype = input_defs[0]->TypeAsProto()->tensor_type().elem_type(); + + const int32_t elem_type = static_cast(input_dtype); // reshape to [b * c // (blocksize ** 2), blocksize, blocksize, h, w] auto reshape1 = model_builder.CreateOperation(node, "reshape", "pre"); From 531d564abae04bb36398cd9b6d66aa58ce2fb251 Mon Sep 17 00:00:00 2001 From: wejoncy Date: Fri, 27 Sep 2024 05:46:00 -0700 Subject: [PATCH 30/39] fix --- .../providers/cpu/tensor/space_depth_ops_test.cc | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/onnxruntime/test/providers/cpu/tensor/space_depth_ops_test.cc b/onnxruntime/test/providers/cpu/tensor/space_depth_ops_test.cc index 3841641264102..f139d24993ed8 100644 --- a/onnxruntime/test/providers/cpu/tensor/space_depth_ops_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/space_depth_ops_test.cc @@ -316,7 +316,9 @@ TYPED_TEST(TensorOpTest, DepthToSpaceTest_3) { test.AddOutput("output", {2, 3, 6, 4}, result_fp16); test.AddInput("input", {N, C, H, W}, X_fp16); } - test.Run(); + // TODO: Test is flaky on QNN EP (CPU backend). Reneable when the QnnCPUBackendTests.DISABLED_SpaceToDepth_Flaky test + // is fixed. + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kQnnExecutionProvider}); } TYPED_TEST(TensorOpTest, DepthToSpaceTest_4) { @@ -370,7 +372,9 @@ TYPED_TEST(TensorOpTest, DepthToSpaceTest_4) { test.AddOutput("output", {2, 3, 6, 4}, result_fp16); } - test.Run(); + // TODO: Test is flaky on QNN EP (CPU backend). Reneable when the QnnCPUBackendTests.DISABLED_SpaceToDepth_Flaky test + // is fixed. + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kQnnExecutionProvider}); } TYPED_TEST(TensorOpTest, DepthToSpaceTest_5) { @@ -405,7 +409,9 @@ TYPED_TEST(TensorOpTest, DepthToSpaceTest_5) { test.AddInput("input", {N, C, H, W}, X_fp16); test.AddOutput("output", {1, 1, 4, 6}, result_fp16); } - test.Run(); + // TODO: Test is flaky on QNN EP (CPU backend). Reneable when the QnnCPUBackendTests.DISABLED_SpaceToDepth_Flaky test + // is fixed. 
+ test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kQnnExecutionProvider}); } TEST(TensorOpTest, DepthToSpaceTest_CRD_Batched) { From d88606d07d97b53e5173509f374c32ffc5e95403 Mon Sep 17 00:00:00 2001 From: wejoncy Date: Sun, 29 Sep 2024 02:39:58 -0700 Subject: [PATCH 31/39] convtranspose ut --- .../cpu/nn/conv_transpose_op_test.cc | 493 ++++++++++-------- 1 file changed, 267 insertions(+), 226 deletions(-) diff --git a/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc b/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc index 2bf53ce5b5986..96dd7ad63040a 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc @@ -22,10 +22,11 @@ struct ConvTransposeOpAttributes { string auto_pad; }; +template void TestConvTransposeOpInitializer(const ConvTransposeOpAttributes& attributes, - const vector>& inputs, + const vector>& inputs, const vector>& input_shapes, - const std::initializer_list& expected_output, + const std::vector& expected_output, const vector& expected_output_shape, bool is_weight_and_bias_initializer = false, OpTester::ExpectResult expect_result = OpTester::ExpectResult::kExpectSuccess, @@ -61,17 +62,18 @@ void TestConvTransposeOpInitializer(const ConvTransposeOpAttributes& attributes, const char* input_names[] = {"X", "W", "B"}; bool is_initializers[] = {false, is_weight_and_bias_initializer, is_weight_and_bias_initializer}; for (size_t i = 0; i < inputs.size(); i++) { - test.AddInput(input_names[i], input_shapes[i], inputs[i], is_initializers[i]); + test.AddInput(input_names[i], input_shapes[i], inputs[i], is_initializers[i]); } - test.AddOutput("Y", expected_output_shape, expected_output); + test.AddOutput("Y", expected_output_shape, expected_output); test.Run(expect_result, err_str, excluded_provider_types); // Disable TensorRT because weight as input is not supported } +template void TestConvTransposeOp(const ConvTransposeOpAttributes& attributes, - const vector>& inputs, + const vector>& inputs, const vector>& input_shapes, - const std::initializer_list& expected_output, + const std::vector& expected_output, const vector& expected_output_shape, OpTester::ExpectResult expect_result = OpTester::ExpectResult::kExpectSuccess, const std::string& err_str = "", @@ -87,6 +89,13 @@ void TestConvTransposeOp(const ConvTransposeOpAttributes& attributes, } // namespace +template +class ConvTransposeTest : public ::testing::Test { +}; + +using ConvTransposeTestTypes = ::testing::Types; +TYPED_TEST_SUITE(ConvTransposeTest, ConvTransposeTestTypes); + TEST(ConvTransposeTest, ConvTranspose_1D) { ConvTransposeOpAttributes attrs{ vector{3}, // kernel_shape @@ -108,13 +117,13 @@ TEST(ConvTransposeTest, ConvTranspose_1D) { 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f}; vector Y_shape = {1, 2, 5}; - auto expected_vals = {18.1f, 40.2f, 66.3f, 48.f, 26.f, - 9.4f, 22.5f, 39.6f, 30.f, 17.f}; + vector expected_vals = {18.1f, 40.2f, 66.3f, 48.f, 26.f, + 9.4f, 22.5f, 39.6f, 30.f, 17.f}; TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); } -TEST(ConvTransposeTest, ConvTranspose_2D_outputpadding_strides2) { +TYPED_TEST(ConvTransposeTest, ConvTranspose_2D_outputpadding_strides2) { ConvTransposeOpAttributes attrs = { vector{3, 3}, // kernel_shape vector{1, 1}, // output_padding @@ -137,17 +146,27 @@ TEST(ConvTransposeTest, ConvTranspose_2D_outputpadding_strides2) { 0.04118127f, -0.44696793f, 0.06373066f}; vector Y_shape = {1, 1, 6, 6}; - auto expected_vals = {0.07368518f, -0.08925839f, 
-0.06627201f, 0.06301362f, 0.03732984f, -0.01919658f, - -0.00628807f, -0.02817563f, -0.01472169f, 0.04392925f, -0.00689478f, -0.01549204f, - 0.07957941f, -0.11459791f, -0.09505399f, 0.07681622f, 0.03604182f, -0.01853423f, - -0.0270785f, -0.00680824f, -0.06650258f, 0.08004665f, 0.07918708f, -0.0724144f, - 0.06256775f, -0.17838378f, -0.18863615f, 0.20064656f, 0.133717f, -0.06876295f, - -0.06398046f, -0.00864975f, 0.19289537f, -0.01490572f, -0.13673618f, 0.01949645f}; - TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); + vector expected_vals = {0.07368518f, -0.08925839f, -0.06627201f, 0.06301362f, 0.03732984f, -0.01919658f, + -0.00628807f, -0.02817563f, -0.01472169f, 0.04392925f, -0.00689478f, -0.01549204f, + 0.07957941f, -0.11459791f, -0.09505399f, 0.07681622f, 0.03604182f, -0.01853423f, + -0.0270785f, -0.00680824f, -0.06650258f, 0.08004665f, 0.07918708f, -0.0724144f, + 0.06256775f, -0.17838378f, -0.18863615f, 0.20064656f, 0.133717f, -0.06876295f, + -0.06398046f, -0.00864975f, 0.19289537f, -0.01490572f, -0.13673618f, 0.01949645f}; + if constexpr (std::is_same::value) { + TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); + } else { + vector X_fp16(X.size()); + ConvertFloatToMLFloat16(X.data(), X_fp16.data(), X.size()); + vector W_fp16(W.size()); + ConvertFloatToMLFloat16(W.data(), W_fp16.data(), W.size()); + std::vector expected_vals_fp16(expected_vals.size()); + ConvertFloatToMLFloat16(expected_vals.data(), expected_vals_fp16.data(), expected_vals.size()); + TestConvTransposeOp(attrs, {X_fp16, W_fp16}, {X_shape, W_shape}, expected_vals_fp16, Y_shape); + } } // 2D input with C > 1 -TEST(ConvTransposeTest, ConvTranspose_2D_C2) { +TYPED_TEST(ConvTransposeTest, ConvTranspose_2D_C2) { ConvTransposeOpAttributes attrs = { vector{2, 2}, // kernel_shape {}, // output_padding @@ -176,16 +195,26 @@ TEST(ConvTransposeTest, ConvTranspose_2D_C2) { 0.44524362f, 0.6056068f}; vector Y_shape = {1, 1, 4, 4}; - auto expected_vals = { + vector expected_vals = { 0.50678771f, 1.10413539f, 0.74340409f, 0.14989006f, 0.34063845f, 1.19294512f, 1.85030293f, 0.63518577f, 0.58575004f, 1.25774109f, 1.23472511f, 0.77670550f, 0.25844323f, 0.88953220f, 0.77098041f, 0.27468451f}; - TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); + if constexpr (std::is_same::value) { + TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); + } else { + vector X_fp16(X.size()); + ConvertFloatToMLFloat16(X.data(), X_fp16.data(), X.size()); + vector W_fp16(W.size()); + ConvertFloatToMLFloat16(W.data(), W_fp16.data(), W.size()); + std::vector expected_vals_fp16(expected_vals.size()); + ConvertFloatToMLFloat16(expected_vals.data(), expected_vals_fp16.data(), expected_vals.size()); + TestConvTransposeOp(attrs, {X_fp16, W_fp16}, {X_shape, W_shape}, expected_vals_fp16, Y_shape); + } } -TEST(ConvTransposeTest, ConvTranspose_2D_Bias_1) { +TYPED_TEST(ConvTransposeTest, ConvTranspose_2D_Bias_1) { ConvTransposeOpAttributes attrs = { vector{3, 3}, // kernel_shape vector{0, 0}, // output_padding @@ -209,12 +238,24 @@ TEST(ConvTransposeTest, ConvTranspose_2D_Bias_1) { vector B = {0.04676145f}; vector B_shape = {1}; vector Y_shape = {1, 1, 5, 5}; - auto expected_vals = {-0.03781903f, -0.09041066f, 0.14239404f, 0.09704495f, -0.03399426f, - 0.08749044f, 0.35613984f, 0.07240347f, -0.27841991f, -0.00337578f, - 0.07770107f, -0.09561026f, 0.13388641f, 0.30945939f, 0.14015588f, - 0.13079405f, -0.00488365f, -0.06758944f, 0.45621645f, 0.01566098f, - 
0.00703105f, 0.12956856f, 0.0103332f, 0.04221053f, -0.21318194f}; - TestConvTransposeOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape); + vector expected_vals = {-0.03781903f, -0.09041066f, 0.14239404f, 0.09704495f, -0.03399426f, + 0.08749044f, 0.35613984f, 0.07240347f, -0.27841991f, -0.00337578f, + 0.07770107f, -0.09561026f, 0.13388641f, 0.30945939f, 0.14015588f, + 0.13079405f, -0.00488365f, -0.06758944f, 0.45621645f, 0.01566098f, + 0.00703105f, 0.12956856f, 0.0103332f, 0.04221053f, -0.21318194f}; + if constexpr (std::is_same::value) { + TestConvTransposeOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape); + } else { + vector X_fp16(X.size()); + ConvertFloatToMLFloat16(X.data(), X_fp16.data(), X.size()); + vector W_fp16(W.size()); + ConvertFloatToMLFloat16(W.data(), W_fp16.data(), W.size()); + std::vector B_fp16(B.size()); + ConvertFloatToMLFloat16(B.data(), B_fp16.data(), B.size()); + std::vector expected_vals_fp16(expected_vals.size()); + ConvertFloatToMLFloat16(expected_vals.data(), expected_vals_fp16.data(), expected_vals.size()); + TestConvTransposeOp(attrs, {X_fp16, W_fp16, B_fp16}, {X_shape, W_shape}, expected_vals_fp16, Y_shape); + } } TEST(ConvTransposeTest, ConvTranspose_2D_Bias_2) { @@ -247,22 +288,22 @@ TEST(ConvTransposeTest, ConvTranspose_2D_Bias_2) { vector B = {0.17402864f}; vector B_shape = {1}; vector Y_shape = {1, 1, 8, 8}; - auto expected_vals = {0.1695925f, 0.14171794f, 0.31368554f, 0.16113512f, - 0.15653302f, 0.033998f, 0.38345876f, 0.12173492f, - 0.05362644f, 0.35481372f, 0.09013268f, -0.06378071f, - 0.24394518f, 0.00222442f, 0.50842237f, -0.07341707f, - 0.17984779f, 0.35392997f, 0.03631867f, 0.16350585f, - 0.30338728f, 0.2088346f, 0.47435546f, 0.0147884f, - 0.20821247f, 0.08664516f, 0.03569011f, 0.16659322f, - 0.47522858f, 0.19675478f, -0.10781619f, 0.02401161f, - 0.0965334f, 0.1788421f, 0.36887163f, 0.2512877f, - 0.00254938f, 0.04799958f, 0.11982619f, 0.31525785f, - 0.12701407f, 0.19566584f, 0.31214368f, -0.10558143f, - 0.18591091f, 0.46830338f, 0.05418756f, 0.20530567f, - 0.07357728f, 0.39731777f, 0.1872202f, 0.08253923f, - 0.11266428f, 0.17892915f, 0.32709083f, 0.1860041f, - 0.16902491f, 0.3129794f, -0.01718347f, 0.28917417f, - 0.07588299f, 0.32025051f, 0.39891475f, -0.04581133f}; + vector expected_vals = {0.1695925f, 0.14171794f, 0.31368554f, 0.16113512f, + 0.15653302f, 0.033998f, 0.38345876f, 0.12173492f, + 0.05362644f, 0.35481372f, 0.09013268f, -0.06378071f, + 0.24394518f, 0.00222442f, 0.50842237f, -0.07341707f, + 0.17984779f, 0.35392997f, 0.03631867f, 0.16350585f, + 0.30338728f, 0.2088346f, 0.47435546f, 0.0147884f, + 0.20821247f, 0.08664516f, 0.03569011f, 0.16659322f, + 0.47522858f, 0.19675478f, -0.10781619f, 0.02401161f, + 0.0965334f, 0.1788421f, 0.36887163f, 0.2512877f, + 0.00254938f, 0.04799958f, 0.11982619f, 0.31525785f, + 0.12701407f, 0.19566584f, 0.31214368f, -0.10558143f, + 0.18591091f, 0.46830338f, 0.05418756f, 0.20530567f, + 0.07357728f, 0.39731777f, 0.1872202f, 0.08253923f, + 0.11266428f, 0.17892915f, 0.32709083f, 0.1860041f, + 0.16902491f, 0.3129794f, -0.01718347f, 0.28917417f, + 0.07588299f, 0.32025051f, 0.39891475f, -0.04581133f}; TestConvTransposeOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape); } @@ -292,18 +333,18 @@ TEST(ConvTransposeTest, ConvTranspose_2D_OutputShape_1) { vector W_shape = {3, 3, 3, 3}; vector Y_shape = {1, 3, 4, 4}; - auto expected_vals = {12.0f, 18.0f, 18.0f, 12.0f, - 18.0f, 27.0f, 27.0f, 18.0f, - 18.0f, 27.0f, 27.0f, 18.0f, - 12.0f, 18.0f, 18.0f, 
12.0f, - 12.0f, 18.0f, 18.0f, 12.0f, - 18.0f, 27.0f, 27.0f, 18.0f, - 18.0f, 27.0f, 27.0f, 18.0f, - 12.0f, 18.0f, 18.0f, 12.0f, - 12.0f, 18.0f, 18.0f, 12.0f, - 18.0f, 27.0f, 27.0f, 18.0f, - 18.0f, 27.0f, 27.0f, 18.0f, - 12.0f, 18.0f, 18.0f, 12.0f}; + vector expected_vals = {12.0f, 18.0f, 18.0f, 12.0f, + 18.0f, 27.0f, 27.0f, 18.0f, + 18.0f, 27.0f, 27.0f, 18.0f, + 12.0f, 18.0f, 18.0f, 12.0f, + 12.0f, 18.0f, 18.0f, 12.0f, + 18.0f, 27.0f, 27.0f, 18.0f, + 18.0f, 27.0f, 27.0f, 18.0f, + 12.0f, 18.0f, 18.0f, 12.0f, + 12.0f, 18.0f, 18.0f, 12.0f, + 18.0f, 27.0f, 27.0f, 18.0f, + 18.0f, 27.0f, 27.0f, 18.0f, + 12.0f, 18.0f, 18.0f, 12.0f}; TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kQnnExecutionProvider}); @@ -338,12 +379,12 @@ TEST(ConvTransposeTest, ConvTranspose_1D_OutputShape_1_group_2_for_transpose_pat vector W_shape = {6, 3, 3}; vector Y_shape = {1, 6, 4}; - auto expected_vals = {6.0f, 9.0f, 9.0f, 6.0f, - 6.0f, 9.0f, 9.0f, 6.0f, - 6.0f, 9.0f, 9.0f, 6.0f, - 6.0f, 9.0f, 9.0f, 6.0f, - 6.0f, 9.0f, 9.0f, 6.0f, - 6.0f, 9.0f, 9.0f, 6.0f}; + vector expected_vals = {6.0f, 9.0f, 9.0f, 6.0f, + 6.0f, 9.0f, 9.0f, 6.0f, + 6.0f, 9.0f, 9.0f, 6.0f, + 6.0f, 9.0f, 9.0f, 6.0f, + 6.0f, 9.0f, 9.0f, 6.0f, + 6.0f, 9.0f, 9.0f, 6.0f}; TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, OpTester::ExpectResult::kExpectSuccess, "", @@ -376,30 +417,30 @@ TEST(ConvTransposeTest, ConvTranspose_2D_OutputShape_1_group_2_for_transpose_pat vector W_shape = {6, 3, 3, 3}; vector Y_shape = {1, 6, 4, 4}; - auto expected_vals = {12.0f, 18.0f, 18.0f, 12.0f, - 18.0f, 27.0f, 27.0f, 18.0f, - 18.0f, 27.0f, 27.0f, 18.0f, - 12.0f, 18.0f, 18.0f, 12.0f, - 12.0f, 18.0f, 18.0f, 12.0f, - 18.0f, 27.0f, 27.0f, 18.0f, - 18.0f, 27.0f, 27.0f, 18.0f, - 12.0f, 18.0f, 18.0f, 12.0f, - 12.0f, 18.0f, 18.0f, 12.0f, - 18.0f, 27.0f, 27.0f, 18.0f, - 18.0f, 27.0f, 27.0f, 18.0f, - 12.0f, 18.0f, 18.0f, 12.0f, // duplicate below - 12.0f, 18.0f, 18.0f, 12.0f, - 18.0f, 27.0f, 27.0f, 18.0f, - 18.0f, 27.0f, 27.0f, 18.0f, - 12.0f, 18.0f, 18.0f, 12.0f, - 12.0f, 18.0f, 18.0f, 12.0f, - 18.0f, 27.0f, 27.0f, 18.0f, - 18.0f, 27.0f, 27.0f, 18.0f, - 12.0f, 18.0f, 18.0f, 12.0f, - 12.0f, 18.0f, 18.0f, 12.0f, - 18.0f, 27.0f, 27.0f, 18.0f, - 18.0f, 27.0f, 27.0f, 18.0f, - 12.0f, 18.0f, 18.0f, 12.0f}; + vector expected_vals = {12.0f, 18.0f, 18.0f, 12.0f, + 18.0f, 27.0f, 27.0f, 18.0f, + 18.0f, 27.0f, 27.0f, 18.0f, + 12.0f, 18.0f, 18.0f, 12.0f, + 12.0f, 18.0f, 18.0f, 12.0f, + 18.0f, 27.0f, 27.0f, 18.0f, + 18.0f, 27.0f, 27.0f, 18.0f, + 12.0f, 18.0f, 18.0f, 12.0f, + 12.0f, 18.0f, 18.0f, 12.0f, + 18.0f, 27.0f, 27.0f, 18.0f, + 18.0f, 27.0f, 27.0f, 18.0f, + 12.0f, 18.0f, 18.0f, 12.0f, // duplicate below + 12.0f, 18.0f, 18.0f, 12.0f, + 18.0f, 27.0f, 27.0f, 18.0f, + 18.0f, 27.0f, 27.0f, 18.0f, + 12.0f, 18.0f, 18.0f, 12.0f, + 12.0f, 18.0f, 18.0f, 12.0f, + 18.0f, 27.0f, 27.0f, 18.0f, + 18.0f, 27.0f, 27.0f, 18.0f, + 12.0f, 18.0f, 18.0f, 12.0f, + 12.0f, 18.0f, 18.0f, 12.0f, + 18.0f, 27.0f, 27.0f, 18.0f, + 18.0f, 27.0f, 27.0f, 18.0f, + 12.0f, 18.0f, 18.0f, 12.0f}; TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, OpTester::ExpectResult::kExpectSuccess, "", @@ -424,7 +465,7 @@ TEST(ConvTransposeTest, ConvTranspose_2D_OutputShape_2) { vector B = {1.0f}; vector B_shape = {1}; vector Y_shape = {1, 1, 1, 14}; - auto expected_vals = {1.0f, 2.0f, 5.0f, 11.0f, 19.0f, 28.0f, 37.0f, 46.0f, 55.0f, 64.0f, 63.0f, 51.0f, 
27.0f, 10.0f}; + vector expected_vals = {1.0f, 2.0f, 5.0f, 11.0f, 19.0f, 28.0f, 37.0f, 46.0f, 55.0f, 64.0f, 63.0f, 51.0f, 27.0f, 10.0f}; TestConvTransposeOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider, kCudaNHWCExecutionProvider, kQnnExecutionProvider}); @@ -449,8 +490,8 @@ TEST(ConvTransposeTest, ConvTranspose_2D_OutputShapeWithBatchSize) { vector B = {1.0f}; vector B_shape = {1}; vector Y_shape = {2, 1, 1, 14}; - auto expected_vals = {1.0f, 2.0f, 5.0f, 11.0f, 19.0f, 28.0f, 37.0f, 46.0f, 55.0f, 64.0f, 63.0f, 51.0f, 27.0f, 10.0f, - 11.0f, 32.0f, 65.0f, 91.0f, 109.0f, 118.0f, 127.0f, 136.0f, 145.0f, 154.0f, 143.0f, 111.0f, 57.0f, 20.0f}; + vector expected_vals = {1.0f, 2.0f, 5.0f, 11.0f, 19.0f, 28.0f, 37.0f, 46.0f, 55.0f, 64.0f, 63.0f, 51.0f, 27.0f, 10.0f, + 11.0f, 32.0f, 65.0f, 91.0f, 109.0f, 118.0f, 127.0f, 136.0f, 145.0f, 154.0f, 143.0f, 111.0f, 57.0f, 20.0f}; TestConvTransposeOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider, kCudaNHWCExecutionProvider, kQnnExecutionProvider}); @@ -475,8 +516,8 @@ TEST(ConvTransposeTest, ConvTranspose_InvalidKernelShape) { vector B = {1.0f}; vector B_shape = {1}; vector Y_shape = {2, 1, 1, 14}; - auto expected_vals = {1.0f, 2.0f, 5.0f, 11.0f, 19.0f, 28.0f, 37.0f, 46.0f, 55.0f, 64.0f, 63.0f, 51.0f, 27.0f, 10.0f, - 11.0f, 32.0f, 65.0f, 91.0f, 109.0f, 118.0f, 127.0f, 136.0f, 145.0f, 154.0f, 143.0f, 111.0f, 57.0f, 20.0f}; + vector expected_vals = {1.0f, 2.0f, 5.0f, 11.0f, 19.0f, 28.0f, 37.0f, 46.0f, 55.0f, 64.0f, 63.0f, 51.0f, 27.0f, 10.0f, + 11.0f, 32.0f, 65.0f, 91.0f, 109.0f, 118.0f, 127.0f, 136.0f, 145.0f, 154.0f, 143.0f, 111.0f, 57.0f, 20.0f}; TestConvTransposeOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, OpTester::ExpectResult::kExpectFailure, // error message will end in "W: {1,1,1,5}" or "W: {1,1,5,1} depending on whether NCHW or NHWC, @@ -502,7 +543,7 @@ TEST(ConvTransposeTest, ConvTranspose_onnx) { vector W = {0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17.}; vector W_shape = {1, 2, 3, 3}; vector Y_shape = {1, 2, 5, 5}; - auto expected_vals = { + vector expected_vals = { 0.f, 0.f, 1.f, 4.f, 4.f, 0.f, 6.f, 20.f, 26.f, 20.f, 9.f, 36.f, 84.f, 84.f, 57.f, @@ -533,7 +574,7 @@ TEST(ConvTransposeTest, ConvTranspose_onnx2) { vector W = {0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23.}; vector W_shape = {2, 3, 2, 2}; // this requires weight transpose vector Y_shape = {1, 3, 4, 4}; - auto expected_vals = { + vector expected_vals = { 108.f, 237.f, 263.f, 145.f, 270.f, 592.f, 652.f, 358.f, 354.f, 772.f, 832.f, 454.f, @@ -566,7 +607,7 @@ TEST(ConvTransposeTest, ConvTranspose_onnx_group) { vector W = {0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 31.0f}; vector W_shape = {16, 2, 1, 1}; vector Y_shape = {1, 8, 1, 1}; - auto expected_vals = {28.f, 34.f, 252.f, 274.f, 732.f, 770.f, 1468.f, 1522.f}; + vector expected_vals = {28.f, 34.f, 252.f, 274.f, 732.f, 770.f, 1468.f, 1522.f}; TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); } @@ -586,10 +627,10 @@ TEST(ConvTransposeTest, ConvTranspose_2D_Dilation_1) { vector W = {1.0f, 1.0f, 1.0f, 1.0f}; vector W_shape = {1, 1, 2, 2}; vector Y_shape = {1, 1, 4, 
4}; - auto expected_vals = {11.0f, 12.0f, 11.0f, 12.0f, - 21.0f, 22.0f, 21.0f, 22.0f, - 11.0f, 12.0f, 11.0f, 12.0f, - 21.0f, 22.0f, 21.0f, 22.0f}; + vector expected_vals = {11.0f, 12.0f, 11.0f, 12.0f, + 21.0f, 22.0f, 21.0f, 22.0f, + 11.0f, 12.0f, 11.0f, 12.0f, + 21.0f, 22.0f, 21.0f, 22.0f}; TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); } @@ -609,11 +650,11 @@ TEST(ConvTransposeTest, ConvTranspose_2D_Dilation_2) { vector W = {1.0f, 1.0f, 1.0f, 1.0f}; vector W_shape = {1, 1, 2, 2}; vector Y_shape = {1, 1, 5, 5}; - auto expected_vals = {11.0f, 12.0f, 0.0f, 11.0f, 12.0f, - 21.0f, 22.0f, 0.0f, 21.0f, 22.0f, - 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, - 11.0f, 12.0f, 0.0f, 11.0f, 12.0f, - 21.0f, 22.0f, 0.0f, 21.0f, 22.0f}; + vector expected_vals = {11.0f, 12.0f, 0.0f, 11.0f, 12.0f, + 21.0f, 22.0f, 0.0f, 21.0f, 22.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 11.0f, 12.0f, 0.0f, 11.0f, 12.0f, + 21.0f, 22.0f, 0.0f, 21.0f, 22.0f}; TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); } @@ -633,11 +674,11 @@ TEST(ConvTransposeTest, ConvTranspose_2D_Dilation_3) { vector W = {7.0f, 2.0f, 1.0f, 9.0f}; vector W_shape = {1, 1, 2, 2}; vector Y_shape = {1, 1, 5, 5}; - auto expected_vals = {21.0f, 56.0f, 13.0f, 16.0f, 2.0f, - 63.0f, 35.0f, 67.0f, 10.0f, 14.0f, - 24.0f, 22.0f, 76.0f, 76.0f, 21.0f, - 9.0f, 5.0f, 88.0f, 45.0f, 63.0f, - 3.0f, 2.0f, 33.0f, 18.0f, 54.0f}; + vector expected_vals = {21.0f, 56.0f, 13.0f, 16.0f, 2.0f, + 63.0f, 35.0f, 67.0f, 10.0f, 14.0f, + 24.0f, 22.0f, 76.0f, 76.0f, 21.0f, + 9.0f, 5.0f, 88.0f, 45.0f, 63.0f, + 3.0f, 2.0f, 33.0f, 18.0f, 54.0f}; TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); } @@ -658,12 +699,12 @@ TEST(ConvTransposeTest, ConvTranspose_2D_Dilation_4) { vector W = {7.0f, 2.0f, 1.0f, 9.0f}; vector W_shape = {1, 1, 2, 2}; vector Y_shape = {1, 1, 6, 6}; - auto expected_vals = {21.0f, 56.0f, 7.0f, 6.0f, 16.0f, 2.0f, - 63.0f, 35.0f, 49.0f, 18.0f, 10.0f, 14.0f, - 21.0f, 14.0f, 42.0f, 6.0f, 4.0f, 12.0f, - 3.0f, 8.0f, 1.0f, 27.0f, 72.0f, 9.0f, - 9.0f, 5.0f, 7.0f, 81.0f, 45.0f, 63.0f, - 3.0f, 2.0f, 6.0f, 27.0f, 18.0f, 54.0f}; + vector expected_vals = {21.0f, 56.0f, 7.0f, 6.0f, 16.0f, 2.0f, + 63.0f, 35.0f, 49.0f, 18.0f, 10.0f, 14.0f, + 21.0f, 14.0f, 42.0f, 6.0f, 4.0f, 12.0f, + 3.0f, 8.0f, 1.0f, 27.0f, 72.0f, 9.0f, + 9.0f, 5.0f, 7.0f, 81.0f, 45.0f, 63.0f, + 3.0f, 2.0f, 6.0f, 27.0f, 18.0f, 54.0f}; TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); } @@ -684,9 +725,9 @@ TEST(ConvTransposeTest, ConvTranspose_2D_Dilation_AsymmetricPads_1) { vector W = {7.0f, 2.0f, 1.0f, 9.0f}; vector W_shape = {1, 1, 2, 2}; vector Y_shape = {1, 1, 3, 3}; - auto expected_vals = {42.0f, 6.0f, 4.0f, - 1.0f, 27.0f, 72.0f, - 7.0f, 81.0f, 45.0f}; + vector expected_vals = {42.0f, 6.0f, 4.0f, + 1.0f, 27.0f, 72.0f, + 7.0f, 81.0f, 45.0f}; TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); } @@ -707,9 +748,9 @@ TEST(ConvTransposeTest, ConvTranspose_2D_Dilation_AsymmetricPads_2) { vector W = {7.0f, 2.0f, 1.0f, 9.0f}; vector W_shape = {1, 1, 2, 2}; vector Y_shape = {1, 1, 3, 3}; - auto expected_vals = {35.0f, 49.0f, 18.0f, - 14.0f, 42.0f, 6.0f, - 8.0f, 1.0f, 27.0f}; + vector expected_vals = {35.0f, 49.0f, 18.0f, + 14.0f, 42.0f, 6.0f, + 8.0f, 1.0f, 27.0f}; TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); } @@ -730,10 +771,10 @@ TEST(ConvTransposeTest, ConvTranspose_2D_Dilation_AsymmetricPads_3) { vector W = {7.0f, 2.0f, 1.0f, 9.0f}; vector W_shape = {1, 1, 
2, 2}; vector Y_shape = {1, 1, 4, 4}; - auto expected_vals = {42.0f, 6.0f, 4.0f, 12.0f, - 1.0f, 27.0f, 72.0f, 9.0f, - 7.0f, 81.0f, 45.0f, 63.0f, - 6.0f, 27.0f, 18.0f, 54.0f}; + vector expected_vals = {42.0f, 6.0f, 4.0f, 12.0f, + 1.0f, 27.0f, 72.0f, 9.0f, + 7.0f, 81.0f, 45.0f, 63.0f, + 6.0f, 27.0f, 18.0f, 54.0f}; TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); } @@ -754,10 +795,10 @@ TEST(ConvTransposeTest, ConvTranspose_2D_Dilation_AsymmetricPads_4) { vector W = {7.0f, 2.0f, 1.0f, 9.0f}; vector W_shape = {1, 1, 2, 2}; vector Y_shape = {1, 1, 4, 4}; - auto expected_vals = {21.0f, 56.0f, 7.0f, 6.0f, - 63.0f, 35.0f, 49.0f, 18.0f, - 21.0f, 14.0f, 42.0f, 6.0f, - 3.0f, 8.0f, 1.0f, 27.0f}; + vector expected_vals = {21.0f, 56.0f, 7.0f, 6.0f, + 63.0f, 35.0f, 49.0f, 18.0f, + 21.0f, 14.0f, 42.0f, 6.0f, + 3.0f, 8.0f, 1.0f, 27.0f}; TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); } @@ -778,16 +819,16 @@ TEST(ConvTransposeTest, ConvTranspose_2D_Dilation_Group_1) { vector W = {9.0f, 3.0f, 1.0f, 2.0f, 3.0f, 7.0f, 0.0f, 8.0f}; vector W_shape = {2, 1, 2, 2}; vector Y_shape = {1, 2, 5, 5}; - auto expected_vals = {27.0f, 72.0f, 18.0f, 24.0f, 3.0f, - 81.0f, 45.0f, 90.0f, 15.0f, 21.0f, - 30.0f, 26.0f, 43.0f, 22.0f, 11.0f, - 9.0f, 5.0f, 25.0f, 10.0f, 14.0f, - 3.0f, 2.0f, 9.0f, 4.0f, 6.0f, - 21.0f, 27.0f, 52.0f, 63.0f, 7.0f, - 15.0f, 6.0f, 44.0f, 14.0f, 21.0f, - 27.0f, 0.0f, 125.0f, 72.0f, 22.0f, - 0.0f, 0.0f, 40.0f, 16.0f, 24.0f, - 0.0f, 0.0f, 72.0f, 0.0f, 16.0f}; + vector expected_vals = {27.0f, 72.0f, 18.0f, 24.0f, 3.0f, + 81.0f, 45.0f, 90.0f, 15.0f, 21.0f, + 30.0f, 26.0f, 43.0f, 22.0f, 11.0f, + 9.0f, 5.0f, 25.0f, 10.0f, 14.0f, + 3.0f, 2.0f, 9.0f, 4.0f, 6.0f, + 21.0f, 27.0f, 52.0f, 63.0f, 7.0f, + 15.0f, 6.0f, 44.0f, 14.0f, 21.0f, + 27.0f, 0.0f, 125.0f, 72.0f, 22.0f, + 0.0f, 0.0f, 40.0f, 16.0f, 24.0f, + 0.0f, 0.0f, 72.0f, 0.0f, 16.0f}; TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); } @@ -808,7 +849,7 @@ TEST(ConvTransposeTest, ConvTranspose_DefaultStridesAndDilations) { vector W = {0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23.}; vector W_shape = {2, 3, 2, 2}; // this requires weight transpose vector Y_shape = {1, 3, 4, 4}; - auto expected_vals = { + vector expected_vals = { 108.f, 237.f, 263.f, 145.f, 270.f, 592.f, 652.f, 358.f, 354.f, 772.f, 832.f, 454.f, @@ -841,7 +882,7 @@ TEST(ConvTransposeTest, ConvTranspose_2D_NonDefaultStridesAndDilations) { vector W = {1., 1., 1., 1.}; vector W_shape = {1, 1, 1, 4}; vector Y_shape = {1, 1, 1, 12}; - auto expected_vals = {1.f, 0.f, 2.f, 1.f, 0.f, 2.f, 1.f, 0.f, 2.f, 1.f, 0.f, 2.f}; + vector expected_vals = {1.f, 0.f, 2.f, 1.f, 0.f, 2.f, 1.f, 0.f, 2.f, 1.f, 0.f, 2.f}; TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); } @@ -862,7 +903,7 @@ TEST(ConvTransposeTest, ConvTranspose_2D_NonDefaultStridesAndDilations_T) { vector W = {1., 1., 1., 1.}; vector W_shape = {1, 1, 4, 1}; vector Y_shape = {1, 1, 12, 1}; - auto expected_vals = {1.f, 0.f, 2.f, 1.f, 0.f, 2.f, 1.f, 0.f, 2.f, 1.f, 0.f, 2.f}; + vector expected_vals = {1.f, 0.f, 2.f, 1.f, 0.f, 2.f, 1.f, 0.f, 2.f, 1.f, 0.f, 2.f}; TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); } @@ -885,7 +926,7 @@ TEST(ConvTransposeTest, DimWithZero) { 0.04118127f, -0.44696793f, 0.06373066f}; vector W_shape = {1, 1, 3, 3}; vector Y_shape = {0, 1, 6, 6}; - initializer_list expected_vals = {}; + vector expected_vals = {}; 
TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, OpTester::ExpectResult::kExpectSuccess, "", @@ -952,75 +993,75 @@ TEST(ConvTransposeTest, ConvTranspose_3D) { vector B = {-0.11784090101718903f, -0.060990236699581146f}; vector Y_shape = {1, 2, 5, 6, 7}; - auto expected_vals = {-0.08241813f, -0.06676699f, -0.13411677f, -0.15724352f, -0.18772511f, -0.11080553f, -0.114930674f, - -0.0953398f, -0.111061305f, -0.0413035f, -0.10902196f, -0.071916685f, -0.102583766f, -0.13639182f, - -0.21214074f, -0.18799849f, -0.15122052f, 0.00434383f, -0.011207409f, -0.11604968f, -0.08378546f, - -0.1722928f, -0.044016793f, -0.1914465f, -0.16952308f, -0.39505655f, 0.080385f, -0.15767722f, - -0.060116887f, -0.16235165f, -0.075614765f, -0.14631891f, 0.05837299f, -0.31712085f, -0.13272354f, - -0.08320008f, -0.1967324f, -0.033198006f, -0.06718128f, -0.2568521f, 0.0314174f, -0.15864298f, - - -0.13070306f, -0.09003539f, -0.29147533f, -0.024966106f, 0.079442084f, -0.096389435f, -0.09941827f, - -0.3365072f, -0.4451772f, -0.13154466f, -0.08992967f, -0.16572365f, 0.06494926f, -0.21230686f, - -0.11307171f, -0.056943115f, -0.35291147f, -0.317253f, -0.070464894f, -0.6300395f, -0.031246513f, - 0.19395588f, 0.011135533f, 0.096916616f, -0.3942836f, -0.29872403f, 0.16881491f, -0.24881886f, - -0.038873613f, -0.032735735f, -0.21593677f, 0.088557824f, 0.13849314f, -0.30753696f, -0.07219358f, - -0.15177673f, -0.09156879f, -0.2286228f, 0.080623806f, -0.39201033f, 0.07819712f, -0.19924995f, - - -0.3376814f, -0.033524483f, 0.230105f, -0.0377952f, -0.12315659f, -0.28858358f, -0.13848148f, - -0.16134796f, 0.012239918f, 0.27276647f, 0.020731017f, -0.4651906f, -0.14341736f, -0.07956973f, - 0.1342433f, -0.16956037f, 0.310399f, 0.34338957f, -0.040192716f, 0.12504166f, -0.21490449f, - -0.15410437f, -0.1338158f, -0.39244395f, 0.29117042f, -0.26415867f, -0.4450379f, 0.0699404f, - 0.042872816f, -0.14961651f, -0.17582522f, -0.6919577f, -0.13723494f, -0.0681901f, -0.16183335f, - -0.0021959245f, -0.0418434f, -0.32134426f, 0.16967098f, -0.08680786f, -0.32077473f, 0.0066963434f, - - -0.114091426f, -0.041066267f, -0.080250874f, -0.72594404f, -0.30254412f, -0.03862554f, -0.27475363f, - 0.15282185f, -0.22887689f, -0.72043663f, -0.47111863f, -0.3755179f, -0.20074406f, 0.16101281f, - -0.20939936f, -0.21245953f, 0.11726546f, -0.8030824f, -0.5866715f, 0.20001571f, -0.26259118f, - 0.17054747f, 0.061063558f, -0.6348493f, 0.2620284f, -0.782919f, -0.31278569f, 0.2926497f, - -0.08745579f, 0.20646049f, -0.050303012f, -0.13460274f, 0.060659587f, -0.037006564f, -0.1292249f, - -0.11211421f, -0.038967483f, -0.21644044f, -0.24912538f, 0.08591288f, -0.40798867f, 0.006527111f, - - -0.049734667f, -0.3685795f, -0.11538547f, 0.27292788f, 0.025990233f, 0.119311824f, 0.0700129f, - -0.156443f, -0.13340846f, 0.10764159f, -0.014803357f, 0.046525866f, 0.015691683f, -0.1869241f, - 0.1004442f, -0.4885978f, -0.7585998f, -0.047841772f, -0.07570776f, 0.0471261f, 0.24483289f, - -0.16554686f, -0.1250152f, -0.15132052f, -0.08515984f, 0.14412321f, -0.1030291f, -0.2780918f, - 0.05803944f, -0.10257156f, -0.4341917f, -0.13150966f, -0.53996617f, -0.15628646f, 0.059058204f, - -0.11976162f, -0.022163756f, -0.13519828f, -0.20148787f, 0.16934697f, -0.14327072f, -0.2129095f, - - -0.107836396f, -0.0819309f, -0.06148723f, -0.0063935146f, -0.02425649f, -0.056219954f, -0.06095987f, - -0.14403576f, -0.025357183f, -0.15828207f, 0.012748428f, -0.16061643f, -0.03419252f, -0.05130991f, - -0.109983265f, -0.08312916f, -0.07035978f, -0.008285124f, -0.10610263f, -0.01489019f, 
-0.106886685f, - -0.007659614f, -0.2947925f, -0.09132287f, -0.040577132f, 0.089866154f, -0.24528673f, -0.055424154f, - 0.13783869f, 0.023674607f, -0.10545369f, -0.20873478f, -0.4685722f, 0.09418375f, -0.06684458f, - 0.0410614f, 0.04018917f, -0.15845582f, 0.06580096f, 0.070554025f, -0.19462511f, -0.03526502f, - - -0.02956047f, -0.16035908f, -0.0638171f, -0.261022f, -0.022948403f, 0.08353848f, -0.041173913f, - 0.04770004f, 0.091520615f, 0.006987013f, -0.39962748f, 0.23266485f, -0.32719564f, -0.12885109f, - -0.29559937f, -0.08031146f, 0.76168066f, 0.0009028502f, -0.4091536f, -0.14801738f, -0.17058557f, - -0.05754847f, 0.2955231f, -0.089874476f, 0.17254886f, -0.13203058f, -0.007648442f, 0.010943003f, - 0.04123217f, 0.26074114f, -0.24313056f, 0.1008903f, -0.26472318f, 0.01998391f, -0.03422378f, - -0.024659738f, 0.033793047f, -0.1998924f, -0.110185415f, 0.10620246f, -0.3435271f, 0.019390412f, - - 0.21691665f, -0.26076952f, -0.5040901f, 0.28383943f, -0.34750903f, -0.32484284f, -0.01734912f, - -0.08909689f, -0.0466362f, 0.21648785f, 0.06733417f, 0.009496197f, 0.18728223f, -0.35110205f, - -0.04908372f, -0.36729553f, -0.346236f, -0.13589534f, -0.16435221f, -0.16853788f, 0.12264759f, - -0.019215636f, -0.38316554f, 0.35669535f, -0.56980205f, -0.059346225f, 0.15008381f, -0.1751053f, - 0.059508912f, 0.116622455f, -0.32607535f, -0.22282779f, -0.29149055f, -0.3829086f, 0.15905643f, - -0.077926554f, 0.06549884f, -0.09004557f, -0.15897253f, 0.26810864f, -0.08931713f, -0.047756508f, - - -0.14657992f, 0.43070868f, -0.021787114f, -0.4532621f, 0.092385404f, -0.30126676f, -0.24893704f, - -0.10896815f, -0.14514503f, -0.21353528f, 0.018723361f, 0.037694372f, 0.11514955f, 0.13013864f, - -0.25713888f, -0.056000195f, -0.3505367f, 0.0836427f, -0.032017898f, -0.26742116f, -0.14740711f, - -0.13330215f, -0.18958306f, -0.08968873f, 0.014723815f, -0.20343366f, 0.3098968f, 0.114284225f, - -0.026738256f, -0.14110464f, -0.054464605f, -0.17529932f, -0.0030034669f, -0.050670102f, -0.04016705f, - -0.062238634f, -0.04886609f, -0.042247344f, -0.12185234f, 0.0357792f, -0.10265522f, -0.116296895f, - - -0.1035416f, -0.09126053f, 0.20045105f, 0.12366664f, 0.05460281f, 0.09944453f, -0.055443168f, - -0.09767935f, -0.040166672f, -0.01716708f, 0.020299219f, 0.02864775f, -0.07159522f, -0.04354491f, - -0.1390779f, -0.13270372f, 0.02992779f, -0.025869183f, 0.12530258f, 0.05101595f, -0.07891131f, - -0.1051311f, -0.093200594f, -0.10368025f, 0.047598884f, -0.12069465f, -0.098738566f, -0.042393237f, - -0.08531736f, -0.051284637f, -0.04354899f, -0.06810297f, -0.083224006f, -0.11702064f, -0.08514082f, - -0.06071842f, -0.07496775f, -0.03626109f, -0.07785503f, -0.07243007f, -0.041736744f, -0.052593358f}; + vector expected_vals = {-0.08241813f, -0.06676699f, -0.13411677f, -0.15724352f, -0.18772511f, -0.11080553f, -0.114930674f, + -0.0953398f, -0.111061305f, -0.0413035f, -0.10902196f, -0.071916685f, -0.102583766f, -0.13639182f, + -0.21214074f, -0.18799849f, -0.15122052f, 0.00434383f, -0.011207409f, -0.11604968f, -0.08378546f, + -0.1722928f, -0.044016793f, -0.1914465f, -0.16952308f, -0.39505655f, 0.080385f, -0.15767722f, + -0.060116887f, -0.16235165f, -0.075614765f, -0.14631891f, 0.05837299f, -0.31712085f, -0.13272354f, + -0.08320008f, -0.1967324f, -0.033198006f, -0.06718128f, -0.2568521f, 0.0314174f, -0.15864298f, + + -0.13070306f, -0.09003539f, -0.29147533f, -0.024966106f, 0.079442084f, -0.096389435f, -0.09941827f, + -0.3365072f, -0.4451772f, -0.13154466f, -0.08992967f, -0.16572365f, 0.06494926f, -0.21230686f, + -0.11307171f, -0.056943115f, 
-0.35291147f, -0.317253f, -0.070464894f, -0.6300395f, -0.031246513f, + 0.19395588f, 0.011135533f, 0.096916616f, -0.3942836f, -0.29872403f, 0.16881491f, -0.24881886f, + -0.038873613f, -0.032735735f, -0.21593677f, 0.088557824f, 0.13849314f, -0.30753696f, -0.07219358f, + -0.15177673f, -0.09156879f, -0.2286228f, 0.080623806f, -0.39201033f, 0.07819712f, -0.19924995f, + + -0.3376814f, -0.033524483f, 0.230105f, -0.0377952f, -0.12315659f, -0.28858358f, -0.13848148f, + -0.16134796f, 0.012239918f, 0.27276647f, 0.020731017f, -0.4651906f, -0.14341736f, -0.07956973f, + 0.1342433f, -0.16956037f, 0.310399f, 0.34338957f, -0.040192716f, 0.12504166f, -0.21490449f, + -0.15410437f, -0.1338158f, -0.39244395f, 0.29117042f, -0.26415867f, -0.4450379f, 0.0699404f, + 0.042872816f, -0.14961651f, -0.17582522f, -0.6919577f, -0.13723494f, -0.0681901f, -0.16183335f, + -0.0021959245f, -0.0418434f, -0.32134426f, 0.16967098f, -0.08680786f, -0.32077473f, 0.0066963434f, + + -0.114091426f, -0.041066267f, -0.080250874f, -0.72594404f, -0.30254412f, -0.03862554f, -0.27475363f, + 0.15282185f, -0.22887689f, -0.72043663f, -0.47111863f, -0.3755179f, -0.20074406f, 0.16101281f, + -0.20939936f, -0.21245953f, 0.11726546f, -0.8030824f, -0.5866715f, 0.20001571f, -0.26259118f, + 0.17054747f, 0.061063558f, -0.6348493f, 0.2620284f, -0.782919f, -0.31278569f, 0.2926497f, + -0.08745579f, 0.20646049f, -0.050303012f, -0.13460274f, 0.060659587f, -0.037006564f, -0.1292249f, + -0.11211421f, -0.038967483f, -0.21644044f, -0.24912538f, 0.08591288f, -0.40798867f, 0.006527111f, + + -0.049734667f, -0.3685795f, -0.11538547f, 0.27292788f, 0.025990233f, 0.119311824f, 0.0700129f, + -0.156443f, -0.13340846f, 0.10764159f, -0.014803357f, 0.046525866f, 0.015691683f, -0.1869241f, + 0.1004442f, -0.4885978f, -0.7585998f, -0.047841772f, -0.07570776f, 0.0471261f, 0.24483289f, + -0.16554686f, -0.1250152f, -0.15132052f, -0.08515984f, 0.14412321f, -0.1030291f, -0.2780918f, + 0.05803944f, -0.10257156f, -0.4341917f, -0.13150966f, -0.53996617f, -0.15628646f, 0.059058204f, + -0.11976162f, -0.022163756f, -0.13519828f, -0.20148787f, 0.16934697f, -0.14327072f, -0.2129095f, + + -0.107836396f, -0.0819309f, -0.06148723f, -0.0063935146f, -0.02425649f, -0.056219954f, -0.06095987f, + -0.14403576f, -0.025357183f, -0.15828207f, 0.012748428f, -0.16061643f, -0.03419252f, -0.05130991f, + -0.109983265f, -0.08312916f, -0.07035978f, -0.008285124f, -0.10610263f, -0.01489019f, -0.106886685f, + -0.007659614f, -0.2947925f, -0.09132287f, -0.040577132f, 0.089866154f, -0.24528673f, -0.055424154f, + 0.13783869f, 0.023674607f, -0.10545369f, -0.20873478f, -0.4685722f, 0.09418375f, -0.06684458f, + 0.0410614f, 0.04018917f, -0.15845582f, 0.06580096f, 0.070554025f, -0.19462511f, -0.03526502f, + + -0.02956047f, -0.16035908f, -0.0638171f, -0.261022f, -0.022948403f, 0.08353848f, -0.041173913f, + 0.04770004f, 0.091520615f, 0.006987013f, -0.39962748f, 0.23266485f, -0.32719564f, -0.12885109f, + -0.29559937f, -0.08031146f, 0.76168066f, 0.0009028502f, -0.4091536f, -0.14801738f, -0.17058557f, + -0.05754847f, 0.2955231f, -0.089874476f, 0.17254886f, -0.13203058f, -0.007648442f, 0.010943003f, + 0.04123217f, 0.26074114f, -0.24313056f, 0.1008903f, -0.26472318f, 0.01998391f, -0.03422378f, + -0.024659738f, 0.033793047f, -0.1998924f, -0.110185415f, 0.10620246f, -0.3435271f, 0.019390412f, + + 0.21691665f, -0.26076952f, -0.5040901f, 0.28383943f, -0.34750903f, -0.32484284f, -0.01734912f, + -0.08909689f, -0.0466362f, 0.21648785f, 0.06733417f, 0.009496197f, 0.18728223f, -0.35110205f, + -0.04908372f, -0.36729553f, -0.346236f, 
-0.13589534f, -0.16435221f, -0.16853788f, 0.12264759f, + -0.019215636f, -0.38316554f, 0.35669535f, -0.56980205f, -0.059346225f, 0.15008381f, -0.1751053f, + 0.059508912f, 0.116622455f, -0.32607535f, -0.22282779f, -0.29149055f, -0.3829086f, 0.15905643f, + -0.077926554f, 0.06549884f, -0.09004557f, -0.15897253f, 0.26810864f, -0.08931713f, -0.047756508f, + + -0.14657992f, 0.43070868f, -0.021787114f, -0.4532621f, 0.092385404f, -0.30126676f, -0.24893704f, + -0.10896815f, -0.14514503f, -0.21353528f, 0.018723361f, 0.037694372f, 0.11514955f, 0.13013864f, + -0.25713888f, -0.056000195f, -0.3505367f, 0.0836427f, -0.032017898f, -0.26742116f, -0.14740711f, + -0.13330215f, -0.18958306f, -0.08968873f, 0.014723815f, -0.20343366f, 0.3098968f, 0.114284225f, + -0.026738256f, -0.14110464f, -0.054464605f, -0.17529932f, -0.0030034669f, -0.050670102f, -0.04016705f, + -0.062238634f, -0.04886609f, -0.042247344f, -0.12185234f, 0.0357792f, -0.10265522f, -0.116296895f, + + -0.1035416f, -0.09126053f, 0.20045105f, 0.12366664f, 0.05460281f, 0.09944453f, -0.055443168f, + -0.09767935f, -0.040166672f, -0.01716708f, 0.020299219f, 0.02864775f, -0.07159522f, -0.04354491f, + -0.1390779f, -0.13270372f, 0.02992779f, -0.025869183f, 0.12530258f, 0.05101595f, -0.07891131f, + -0.1051311f, -0.093200594f, -0.10368025f, 0.047598884f, -0.12069465f, -0.098738566f, -0.042393237f, + -0.08531736f, -0.051284637f, -0.04354899f, -0.06810297f, -0.083224006f, -0.11702064f, -0.08514082f, + -0.06071842f, -0.07496775f, -0.03626109f, -0.07785503f, -0.07243007f, -0.041736744f, -0.052593358f}; TestConvTransposeOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, OpTester::ExpectResult::kExpectSuccess, "", @@ -1045,7 +1086,7 @@ TEST(ConvTransposeTest, ConvTranspose_1D_AsymmetricPads) { vector W = {1.0f, 1.0f, 1.0f, 1.0f}; vector W_shape = {1, 2, 2}; vector Y_shape = {1, 2, 4}; - auto expected_vals = {3.0f, 5.0f, 7.0f, 4.0f, 3.0f, 5.0f, 7.0f, 4.0f}; + vector expected_vals = {3.0f, 5.0f, 7.0f, 4.0f, 3.0f, 5.0f, 7.0f, 4.0f}; TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kQnnExecutionProvider}); @@ -1068,7 +1109,7 @@ TEST(ConvTransposeTest, ConvTranspose_1D_AutoPad_SameUpper) { vector W = {1.0f, 1.0f, 1.0f, 1.0f}; vector W_shape = {1, 2, 2}; vector Y_shape = {1, 2, 4}; - auto expected_vals = {1.0f, 3.0f, 5.0f, 7.0f, 1.0f, 3.0f, 5.0f, 7.0f}; + vector expected_vals = {1.0f, 3.0f, 5.0f, 7.0f, 1.0f, 3.0f, 5.0f, 7.0f}; TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, OpTester::ExpectResult::kExpectSuccess, "", @@ -1092,7 +1133,7 @@ TEST(ConvTransposeTest, ConvTranspose_1D_AutoPad_SameLower) { vector W = {1.0f, 1.0f, 1.0f, 1.0f}; vector W_shape = {1, 2, 2}; vector Y_shape = {1, 2, 4}; - auto expected_vals = {3.0f, 5.0f, 7.0f, 4.0f, 3.0f, 5.0f, 7.0f, 4.0f}; + vector expected_vals = {3.0f, 5.0f, 7.0f, 4.0f, 3.0f, 5.0f, 7.0f, 4.0f}; TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, OpTester::ExpectResult::kExpectSuccess, "", @@ -1125,19 +1166,19 @@ TEST(ConvTransposeTest, ConvTranspose_AutoPad_with_non_default_strides) { 1.0f, 1.0f, 1.0f}; vector W_shape = {1, 2, 3, 3}; - auto expected_vals = {0.0f, 0.0f, 1.0f, 1.0f, 3.0f, 2.0f, - 0.0f, 0.0f, 1.0f, 1.0f, 3.0f, 2.0f, - 3.0f, 3.0f, 8.0f, 5.0f, 12.0f, 7.0f, - 3.0f, 3.0f, 7.0f, 4.0f, 9.0f, 5.0f, - 9.0f, 9.0f, 20.0f, 11.0f, 24.0f, 13.0f, - 6.0f, 6.0f, 13.0f, 7.0f, 15.0f, 8.0f, - - 0.0f, 0.0f, 1.0f, 1.0f, 3.0f, 2.0f, - 0.0f, 0.0f, 1.0f, 
1.0f, 3.0f, 2.0f, - 3.0f, 3.0f, 8.0f, 5.0f, 12.0f, 7.0f, - 3.0f, 3.0f, 7.0f, 4.0f, 9.0f, 5.0f, - 9.0f, 9.0f, 20.0f, 11.0f, 24.0f, 13.0f, - 6.0f, 6.0f, 13.0f, 7.0f, 15.0f, 8.0f}; + vector expected_vals = {0.0f, 0.0f, 1.0f, 1.0f, 3.0f, 2.0f, + 0.0f, 0.0f, 1.0f, 1.0f, 3.0f, 2.0f, + 3.0f, 3.0f, 8.0f, 5.0f, 12.0f, 7.0f, + 3.0f, 3.0f, 7.0f, 4.0f, 9.0f, 5.0f, + 9.0f, 9.0f, 20.0f, 11.0f, 24.0f, 13.0f, + 6.0f, 6.0f, 13.0f, 7.0f, 15.0f, 8.0f, + + 0.0f, 0.0f, 1.0f, 1.0f, 3.0f, 2.0f, + 0.0f, 0.0f, 1.0f, 1.0f, 3.0f, 2.0f, + 3.0f, 3.0f, 8.0f, 5.0f, 12.0f, 7.0f, + 3.0f, 3.0f, 7.0f, 4.0f, 9.0f, 5.0f, + 9.0f, 9.0f, 20.0f, 11.0f, 24.0f, 13.0f, + 6.0f, 6.0f, 13.0f, 7.0f, 15.0f, 8.0f}; vector Y_shape = {1, 2, 6, 6}; TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, @@ -1170,7 +1211,7 @@ TEST(ConvTransposeTest, SharedPrepackedWeights) { W.push_back(1.0f); test.AddInput("W", {6, 3, 3, 3}, W, true); // Trigger pre-packing - auto expected_vals = { + vector expected_vals = { 12.0f, 18.0f, 18.0f, From f70a934d1eaddf1380efa87591ce4a81d432f493 Mon Sep 17 00:00:00 2001 From: wejoncy Date: Sun, 29 Sep 2024 02:44:11 -0700 Subject: [PATCH 32/39] f --- .../providers/cpu/nn/conv_transpose_op_test.cc | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc b/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc index 96dd7ad63040a..99d814a0e2a15 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc @@ -214,7 +214,7 @@ TYPED_TEST(ConvTransposeTest, ConvTranspose_2D_C2) { } } -TYPED_TEST(ConvTransposeTest, ConvTranspose_2D_Bias_1) { +TEST(ConvTransposeTest, ConvTranspose_2D_Bias_1) { ConvTransposeOpAttributes attrs = { vector{3, 3}, // kernel_shape vector{0, 0}, // output_padding @@ -243,19 +243,7 @@ TYPED_TEST(ConvTransposeTest, ConvTranspose_2D_Bias_1) { 0.07770107f, -0.09561026f, 0.13388641f, 0.30945939f, 0.14015588f, 0.13079405f, -0.00488365f, -0.06758944f, 0.45621645f, 0.01566098f, 0.00703105f, 0.12956856f, 0.0103332f, 0.04221053f, -0.21318194f}; - if constexpr (std::is_same::value) { - TestConvTransposeOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape); - } else { - vector X_fp16(X.size()); - ConvertFloatToMLFloat16(X.data(), X_fp16.data(), X.size()); - vector W_fp16(W.size()); - ConvertFloatToMLFloat16(W.data(), W_fp16.data(), W.size()); - std::vector B_fp16(B.size()); - ConvertFloatToMLFloat16(B.data(), B_fp16.data(), B.size()); - std::vector expected_vals_fp16(expected_vals.size()); - ConvertFloatToMLFloat16(expected_vals.data(), expected_vals_fp16.data(), expected_vals.size()); - TestConvTransposeOp(attrs, {X_fp16, W_fp16, B_fp16}, {X_shape, W_shape}, expected_vals_fp16, Y_shape); - } + TestConvTransposeOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape); } TEST(ConvTransposeTest, ConvTranspose_2D_Bias_2) { From 2203294894eaa44c0ded009bc2599ee6c37f7f5e Mon Sep 17 00:00:00 2001 From: wejoncy Date: Mon, 30 Sep 2024 11:02:19 +0800 Subject: [PATCH 33/39] Update onnxruntime/test/providers/cpu/tensor/space_depth_ops_test.cc Co-authored-by: Scott McKay --- onnxruntime/test/providers/cpu/tensor/space_depth_ops_test.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/test/providers/cpu/tensor/space_depth_ops_test.cc b/onnxruntime/test/providers/cpu/tensor/space_depth_ops_test.cc index f139d24993ed8..6a2c8f46143e5 100644 --- 
a/onnxruntime/test/providers/cpu/tensor/space_depth_ops_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/space_depth_ops_test.cc @@ -316,8 +316,8 @@ TYPED_TEST(TensorOpTest, DepthToSpaceTest_3) { test.AddOutput("output", {2, 3, 6, 4}, result_fp16); test.AddInput("input", {N, C, H, W}, X_fp16); } - // TODO: Test is flaky on QNN EP (CPU backend). Reneable when the QnnCPUBackendTests.DISABLED_SpaceToDepth_Flaky test - // is fixed. + // TODO: Test is flaky on QNN EP (CPU backend). + // Re-enable when the QnnCPUBackendTests.DISABLED_SpaceToDepth_Flaky test is fixed. test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kQnnExecutionProvider}); } From 6d709d00a0316aa49eccaba4a8907e3a31a35249 Mon Sep 17 00:00:00 2001 From: wejoncy Date: Sun, 29 Sep 2024 21:37:24 -0700 Subject: [PATCH 34/39] more UTS --- .../coreml/builders/impl/slice_op_builder.cc | 14 ++-- .../providers/cpu/tensor/concat_op_test.cc | 52 +++++++++----- .../providers/cpu/tensor/slice_op.test.cc | 71 +++++++++++++------ .../providers/cpu/tensor/split_op_test.cc | 31 ++++++++ .../providers/cpu/tensor/transpose_test.cc | 32 ++++++--- 5 files changed, 151 insertions(+), 49 deletions(-) diff --git a/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc index 51fc3f2c11c73..4748430743fe3 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc @@ -144,7 +144,7 @@ Status SliceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const } } - // Only int32 and float are supported by CoreML slice_by_index. + // Int32, float and float16 are supported by CoreML slice_by_index. // We convert any int64 model input to int32 when running the CoreML model for the partition. // Any other integer data created at runtime is the output from CoreML operations, and should int32 not int64. 
// Based on that, we assume that the actual input when running will be int32, so we override the output data @@ -214,15 +214,21 @@ Status SliceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const return Status::OK(); } -bool SliceOpBuilder::HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& /*input_params*/, +bool SliceOpBuilder::HasSupportedInputsImpl(const Node& node, + [[maybe_unused]] const OpBuilderInputParams& input_params, const logging::Logger& logger) const { int32_t input_type; if (!GetType(*node.InputDefs()[0], input_type, logger)) { return false; } - if (input_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT && - input_type != ONNX_NAMESPACE::TensorProto_DataType_INT64) { +#ifdef COREML_ENABLE_MLPROGRAM + if (input_params.create_mlprogram && input_params.coreml_version >= 7 && + input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { + } else +#endif + if (input_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT && + input_type != ONNX_NAMESPACE::TensorProto_DataType_INT64) { LOGS(logger, VERBOSE) << "[" << node.OpType() << "] Input type: [" << input_type << "] is not supported"; return false; } diff --git a/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc b/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc index 98f57f4573540..48d6ad081ac19 100644 --- a/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc @@ -7,6 +7,13 @@ namespace onnxruntime { namespace test { +template +class ConcatOpTest : public ::testing::Test { +}; + +using ConcatOpTestTypes = ::testing::Types; +TYPED_TEST_SUITE(ConcatOpTest, ConcatOpTestTypes); + // Some of the tests can't run on TensorrtExecutionProvider because of unsupported data types or limits // in its parser: axis >=0 && axis < nbDims. 
Those Tests will fallback to other EPs @@ -68,34 +75,45 @@ TEST(ConcatOpTest, Concat1D_2) { kQnnExecutionProvider}); // QNN: not support dynamic shape tensor } -TEST(ConcatOpTest, Concat2D_1) { +template +std::vector GetTypedArray(std::vector inputs, T v = T(0.f)) { + if constexpr (std::is_same::value) { + return inputs; + } else { + std::vector inputs_fp16(inputs.size()); + ConvertFloatToMLFloat16(inputs.data(), inputs_fp16.data(), inputs.size()); + return inputs_fp16; + } +} + +TYPED_TEST(ConcatOpTest, Concat2D_1) { OpTester test("Concat"); test.AddAttribute("axis", int64_t{0}); std::vector dims{1, 4}; - test.AddInput("input1", dims, {11.0f, 12.0f, 13.0f, 14.0f}); - test.AddInput("input2", dims, {21.0f, 22.0f, 23.0f, 24.0f}); - test.AddInput("input3", dims, {31.0f, 32.0f, 33.0f, 34.0f}); - test.AddOutput("concat_result", {3, 4}, - {11.0f, 12.0f, 13.0f, 14.0f, - 21.0f, 22.0f, 23.0f, 24.0f, - 31.0f, 32.0f, 33.0f, 34.0f}); + test.AddInput("input1", dims, GetTypedArray({11.0f, 12.0f, 13.0f, 14.0f})); + test.AddInput("input2", dims, GetTypedArray({21.0f, 22.0f, 23.0f, 24.0f})); + test.AddInput("input3", dims, GetTypedArray({31.0f, 32.0f, 33.0f, 34.0f})); + test.AddOutput("concat_result", {3, 4}, + GetTypedArray({11.0f, 12.0f, 13.0f, 14.0f, + 21.0f, 22.0f, 23.0f, 24.0f, + 31.0f, 32.0f, 33.0f, 34.0f})); test.Run(); } -TEST(ConcatOpTest, Concat2D_2) { +TYPED_TEST(ConcatOpTest, Concat2D_2) { OpTester test("Concat"); test.AddAttribute("axis", int64_t{1}); std::vector dims{4, 1}; - test.AddInput("input1", dims, {11.0f, 21.0f, 31.0f, 41.0f}); - test.AddInput("input2", {4, 2}, {12.0f, 13.0f, 22.0f, 23.0f, 32.0f, 33.0f, 42.0f, 43.0f}); - test.AddInput("input3", dims, {14.0f, 24.0f, 34.0f, 44.0f}); - test.AddOutput("concat_result", {4, 4}, - {11.0f, 12.0f, 13.0f, 14.0f, - 21.0f, 22.0f, 23.0f, 24.0f, - 31.0f, 32.0f, 33.0f, 34.0f, - 41.0f, 42.0f, 43.0f, 44.0f}); + test.AddInput("input1", dims, GetTypedArray({11.0f, 21.0f, 31.0f, 41.0f})); + test.AddInput("input2", {4, 2}, GetTypedArray({12.0f, 13.0f, 22.0f, 23.0f, 32.0f, 33.0f, 42.0f, 43.0f})); + test.AddInput("input3", dims, GetTypedArray({14.0f, 24.0f, 34.0f, 44.0f})); + test.AddOutput("concat_result", {4, 4}, + GetTypedArray({11.0f, 12.0f, 13.0f, 14.0f, + 21.0f, 22.0f, 23.0f, 24.0f, + 31.0f, 32.0f, 33.0f, 34.0f, + 41.0f, 42.0f, 43.0f, 44.0f})); test.Run(); } diff --git a/onnxruntime/test/providers/cpu/tensor/slice_op.test.cc b/onnxruntime/test/providers/cpu/tensor/slice_op.test.cc index 83b308b57f26b..0285ec01d5869 100644 --- a/onnxruntime/test/providers/cpu/tensor/slice_op.test.cc +++ b/onnxruntime/test/providers/cpu/tensor/slice_op.test.cc @@ -263,17 +263,33 @@ TEST(SliceTest, Slice3D) { 332.0f, 333.0f}); } -template +template +static std::vector GetTypedArray(std::vector inputs, T v = T(0.f)) { + std::vector inputs_T(inputs.size()); + if constexpr (std::is_same::value) { + return inputs; + } else if constexpr (std::is_integral_v) { + for (size_t i = 0; i < inputs.size(); i++) { + inputs_T[i] = static_cast(inputs[i]); + } + return inputs_T; + } else { + ConvertFloatToMLFloat16(inputs.data(), inputs_T.data(), inputs.size()); + return inputs_T; + } +} + +template static void TestSlice1DIntData() { - static_assert(std::is_integral_v); - RunSliceTest({6}, - {0, 1, 2, 3, 4, 5}, - {2}, - {4}, - {0}, - {}, - {2}, - {2, 3}); + // static_assert(std::is_integral_v); + RunSliceTest({6}, + GetTypedArray({0.f, 1.f, 2.f, 3.f, 4.f, 5.f}), + {2}, + {4}, + {0}, + {}, + {2}, + GetTypedArray({2.f, 3.f})); } TEST(SliceTest, Slice1D_Int32) { @@ -284,6 +300,21 @@ 
TEST(SliceTest, Slice1D_Int64) { TestSlice1DIntData(); } +TEST(SliceTest, Slice1D_Float) { + TestSlice1DIntData(); +} + +TEST(SliceTest, Slice1D_Float16) { + TestSlice1DIntData(); +} + +template +class SliceTest : public ::testing::Test { +}; + +using SliceTestTypes = ::testing::Types; +TYPED_TEST_SUITE(SliceTest, SliceTestTypes); + TEST(SliceTest, Slice1D_String) { RunSliceTest({6}, {"0", "1", "2", "3", "4", "5"}, @@ -296,16 +327,16 @@ TEST(SliceTest, Slice1D_String) { } // Only Slice V10 can run the following tests -TEST(SliceTest, Slice1D_WithPositiveSteps) { - RunSliceTest({6}, - {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f}, - {0}, - {6}, - {0}, - {2}, - {3}, - {0.0f, 2.0f, 4.0f}, - true); +TYPED_TEST(SliceTest, Slice1D_WithPositiveSteps) { + RunSliceTest({6}, + GetTypedArray({0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f}), + {0}, + {6}, + {0}, + {2}, + {3}, + GetTypedArray({0.0f, 2.0f, 4.0f}), + true); } // In numpy: diff --git a/onnxruntime/test/providers/cpu/tensor/split_op_test.cc b/onnxruntime/test/providers/cpu/tensor/split_op_test.cc index 066302a4a37d1..35424fbdc90cf 100644 --- a/onnxruntime/test/providers/cpu/tensor/split_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/split_op_test.cc @@ -178,6 +178,17 @@ TEST(SplitOperatorTest, Axis0UnequalSplitFloat) { RunTest(axis, splits, input, outputs, {kTensorrtExecutionProvider}, false, true); } +template +std::vector GetTypedArray(std::vector inputs, T v = T(0.f)) { + if constexpr (std::is_same::value) { + return inputs; + } else { + std::vector inputs_fp16(inputs.size()); + ConvertFloatToMLFloat16(inputs.data(), inputs_fp16.data(), inputs.size()); + return inputs_fp16; + } +} + TEST(SplitOperatorTest, Axis0UnequalSplitString) { constexpr int64_t axis = 0; std::vector outputs; @@ -222,6 +233,26 @@ TEST(SplitOperatorTest, Axis1EqualSplitFloat) { RunTest(axis, {}, input, outputs, {kTensorrtExecutionProvider}); } +TEST(SplitOperatorTest, Axis1EqualSplitFloat16) { + constexpr int64_t axis = 1; + std::vector> outputs; + + // input shape and data + ShapeAndData input = {{2, 4}, + GetTypedArray({1.f, 2.f, 3.f, 4.f, + 5.f, 6.f, 7.f, 8.f})}; + + outputs.push_back({{2, 2}, + GetTypedArray({1.f, 2.f, + 5.f, 6.f})}); + + outputs.push_back({{2, 2}, + GetTypedArray({3.f, 4.f, + 7.f, 8.f})}); + RunTest(axis, {}, input, outputs, {kTensorrtExecutionProvider}, false, true); + RunTest(axis, {}, input, outputs, {kTensorrtExecutionProvider}); +} + TEST(SplitOperatorTest, Axis1EqualSplitString) { constexpr int64_t axis = 1; std::vector outputs; diff --git a/onnxruntime/test/providers/cpu/tensor/transpose_test.cc b/onnxruntime/test/providers/cpu/tensor/transpose_test.cc index 01dba55ceb8ed..28ee044df3f6a 100644 --- a/onnxruntime/test/providers/cpu/tensor/transpose_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/transpose_test.cc @@ -12,6 +12,13 @@ namespace onnxruntime { namespace test { +template +class TransposeOpTest : public ::testing::Test { +}; + +using TransposeOpTestTypes = ::testing::Types; +TYPED_TEST_SUITE(TransposeOpTest, TransposeOpTestTypes); + TEST(TransposeOpTest, IsTransposeReshapeTest) { std::vector input_dims{1, 2, 3, 4, 1}; std::vector perm{0, 1, 2, 3, 4}; @@ -62,18 +69,27 @@ void TransposeTest(const std::vector& input_shape, } } +template +std::vector GetTypedArray(std::vector inputs, T v = T(0.f)) { + if constexpr (std::is_same::value) { + return inputs; + } else { + std::vector inputs_fp16(inputs.size()); + ConvertFloatToMLFloat16(inputs.data(), inputs_fp16.data(), inputs.size()); + return inputs_fp16; + } +} + // Test 2 dimensional 
transpose, with no permutation attribute specified -TEST(TransposeOpTest, TwoDimNoAttr) { +TYPED_TEST(TransposeOpTest, TwoDimNoAttr) { std::vector input_shape({2, 3}); - std::vector input_vals = { - 1.0f, 2.0f, 3.0f, - 4.0f, 5.0f, 6.0f}; + std::vector input_vals = GetTypedArray({1.0f, 2.0f, 3.0f, + 4.0f, 5.0f, 6.0f}); std::vector expected_shape({3, 2}); - std::vector expected_vals = { - 1.0f, 4.0f, - 2.0f, 5.0f, - 3.0f, 6.0f}; + std::vector expected_vals = GetTypedArray({1.0f, 4.0f, + 2.0f, 5.0f, + 3.0f, 6.0f}); TransposeTest(input_shape, input_vals, nullptr, expected_shape, expected_vals, {kTensorrtExecutionProvider}, {7, 21}); // TensorRT: SegFault error From c80e312829ad532be2b4fc2862f0c169edea8a50 Mon Sep 17 00:00:00 2001 From: wejoncy Date: Sun, 29 Sep 2024 22:19:34 -0700 Subject: [PATCH 35/39] add doc for slice --- .../providers/coreml/builders/impl/slice_op_builder.cc | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc index 4748430743fe3..194697eb2e985 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc @@ -223,11 +223,16 @@ bool SliceOpBuilder::HasSupportedInputsImpl(const Node& node, } #ifdef COREML_ENABLE_MLPROGRAM +// The [Doc](https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.tensor_transformation.slice_by_index) +// says slice_by_index is support fp16 by ML Program. It's something wrong and it requires coreml version >= 7 otherwise +// only float is supported. +// refs 1:https://github.com/apple/coremltools/blob/89d058ffdcb0b39a03031782d8a448b6889ac425/coremltools/converters/mil/mil/ops/defs/tensor_transformation.py#L515 +// refs 2:https://github.com/apple/coremltools/blob/c3ea4cf56fef1176417246c1b85363417f3e713d/coremltools/converters/mil/mil/ops/defs/iOS15/tensor_transformation.py#L495 if (input_params.create_mlprogram && input_params.coreml_version >= 7 && input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { } else -#endif - if (input_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT && +#endif // nolint + if (input_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT && input_type != ONNX_NAMESPACE::TensorProto_DataType_INT64) { LOGS(logger, VERBOSE) << "[" << node.OpType() << "] Input type: [" << input_type << "] is not supported"; return false; From 8b2bf8e368426e8131d8565a67494d9f84757f21 Mon Sep 17 00:00:00 2001 From: wejoncy Date: Sun, 29 Sep 2024 22:25:35 -0700 Subject: [PATCH 36/39] code spell fix --- .../providers/cpu/tensor/space_depth_ops_test.cc | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/onnxruntime/test/providers/cpu/tensor/space_depth_ops_test.cc b/onnxruntime/test/providers/cpu/tensor/space_depth_ops_test.cc index 6a2c8f46143e5..4954b82690e0f 100644 --- a/onnxruntime/test/providers/cpu/tensor/space_depth_ops_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/space_depth_ops_test.cc @@ -44,8 +44,8 @@ TEST(TensorOpTest, SpaceToDepthTest_1) { 3.1f, 3.3f}; test.AddOutput("output", {N, C * blocksize * blocksize, H / blocksize, W / blocksize}, result); - // TODO: Test is flaky on QNN EP (CPU backend). Reneable when the QnnCPUBackendTests.DISABLED_SpaceToDepth_Flaky test - // is fixed. + // TODO: Test is flaky on QNN EP (CPU backend). 
+ // Re-enable when the QnnCPUBackendTests.DISABLED_SpaceToDepth_Flaky test is fixed. test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kQnnExecutionProvider}); } @@ -111,8 +111,8 @@ TEST(TensorOpTest, SpaceToDepthTest_2) { 88., 103., 106., 68., 71., 86., 89., 104., 107.}; test.AddOutput("output", {2, 27, 1, 2}, result); - // TODO: Test is flaky on QNN EP (CPU backend). Reneable when the QnnCPUBackendTests.DISABLED_SpaceToDepth_Flaky2 - // test is fixed. + // TODO: Test is flaky on QNN EP (CPU backend). + // Re-enable when the QnnCPUBackendTests.DISABLED_SpaceToDepth_Flaky2 test is fixed. test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kQnnExecutionProvider}); } @@ -372,8 +372,8 @@ TYPED_TEST(TensorOpTest, DepthToSpaceTest_4) { test.AddOutput("output", {2, 3, 6, 4}, result_fp16); } - // TODO: Test is flaky on QNN EP (CPU backend). Reneable when the QnnCPUBackendTests.DISABLED_SpaceToDepth_Flaky test - // is fixed. + // TODO: Test is flaky on QNN EP (CPU backend). + // Re-enable when the QnnCPUBackendTests.DISABLED_SpaceToDepth_Flaky2 test is fixed. test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kQnnExecutionProvider}); } @@ -409,8 +409,8 @@ TYPED_TEST(TensorOpTest, DepthToSpaceTest_5) { test.AddInput("input", {N, C, H, W}, X_fp16); test.AddOutput("output", {1, 1, 4, 6}, result_fp16); } - // TODO: Test is flaky on QNN EP (CPU backend). Reneable when the QnnCPUBackendTests.DISABLED_SpaceToDepth_Flaky test - // is fixed. + // TODO: Test is flaky on QNN EP (CPU backend). + // Re-enable when the QnnCPUBackendTests.DISABLED_SpaceToDepth_Flaky2 test is fixed. test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kQnnExecutionProvider}); } From 3326719deca547ef4bf2757ea763a702127b7cf1 Mon Sep 17 00:00:00 2001 From: wejoncy Date: Sun, 29 Sep 2024 22:33:15 -0700 Subject: [PATCH 37/39] format --- .../coreml/builders/impl/slice_op_builder.cc | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc index 194697eb2e985..f795feb886b38 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc @@ -223,20 +223,20 @@ bool SliceOpBuilder::HasSupportedInputsImpl(const Node& node, } #ifdef COREML_ENABLE_MLPROGRAM -// The [Doc](https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.tensor_transformation.slice_by_index) -// says slice_by_index is support fp16 by ML Program. It's something wrong and it requires coreml version >= 7 otherwise -// only float is supported. -// refs 1:https://github.com/apple/coremltools/blob/89d058ffdcb0b39a03031782d8a448b6889ac425/coremltools/converters/mil/mil/ops/defs/tensor_transformation.py#L515 -// refs 2:https://github.com/apple/coremltools/blob/c3ea4cf56fef1176417246c1b85363417f3e713d/coremltools/converters/mil/mil/ops/defs/iOS15/tensor_transformation.py#L495 + // The [Doc](https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.tensor_transformation.slice_by_index) + // says slice_by_index is support fp16 by ML Program. It's something wrong and it requires coreml version >= 7 otherwise + // only float is supported. 
+ // refs 1:https://github.com/apple/coremltools/blob/89d058ffdcb0b39a03031782d8a448b6889ac425/coremltools/converters/mil/mil/ops/defs/tensor_transformation.py#L515 + // refs 2:https://github.com/apple/coremltools/blob/c3ea4cf56fef1176417246c1b85363417f3e713d/coremltools/converters/mil/mil/ops/defs/iOS15/tensor_transformation.py#L495 if (input_params.create_mlprogram && input_params.coreml_version >= 7 && input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { } else -#endif // nolint - if (input_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT && - input_type != ONNX_NAMESPACE::TensorProto_DataType_INT64) { - LOGS(logger, VERBOSE) << "[" << node.OpType() << "] Input type: [" << input_type << "] is not supported"; - return false; - } +#endif // nolint + if (input_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT && + input_type != ONNX_NAMESPACE::TensorProto_DataType_INT64) { + LOGS(logger, VERBOSE) << "[" << node.OpType() << "] Input type: [" << input_type << "] is not supported"; + return false; + } return true; } From 0dccf6427b342a41de6f90d91c2352afee1c479f Mon Sep 17 00:00:00 2001 From: wejoncy Date: Mon, 30 Sep 2024 00:07:20 -0700 Subject: [PATCH 38/39] address review comments --- .../coreml/builders/impl/slice_op_builder.cc | 8 ++--- .../cpu/nn/conv_transpose_op_test.cc | 29 ++++++++++--------- .../providers/cpu/tensor/concat_op_test.cc | 2 +- .../providers/cpu/tensor/slice_op.test.cc | 2 +- .../providers/cpu/tensor/split_op_test.cc | 2 +- .../providers/cpu/tensor/transpose_test.cc | 2 +- 6 files changed, 24 insertions(+), 21 deletions(-) diff --git a/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc index f795feb886b38..cb0f7cc0208e9 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc @@ -224,10 +224,10 @@ bool SliceOpBuilder::HasSupportedInputsImpl(const Node& node, #ifdef COREML_ENABLE_MLPROGRAM // The [Doc](https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.tensor_transformation.slice_by_index) - // says slice_by_index is support fp16 by ML Program. It's something wrong and it requires coreml version >= 7 otherwise - // only float is supported. - // refs 1:https://github.com/apple/coremltools/blob/89d058ffdcb0b39a03031782d8a448b6889ac425/coremltools/converters/mil/mil/ops/defs/tensor_transformation.py#L515 - // refs 2:https://github.com/apple/coremltools/blob/c3ea4cf56fef1176417246c1b85363417f3e713d/coremltools/converters/mil/mil/ops/defs/iOS15/tensor_transformation.py#L495 + // says ML Program slice_by_index supports fp16 in CoreML 5 (iOS 15). + // It's incorrect and CoreML 6+ (iOS16, CoreML spec version >= 7) is required otherwise only float is supported. 
+ // CoreML 5:https://github.com/apple/coremltools/blob/89d058ffdcb0b39a03031782d8a448b6889ac425/coremltools/converters/mil/mil/ops/defs/tensor_transformation.py#L515 + // CoreML 6:https://github.com/apple/coremltools/blob/c3ea4cf56fef1176417246c1b85363417f3e713d/coremltools/converters/mil/mil/ops/defs/iOS15/tensor_transformation.py#L495 if (input_params.create_mlprogram && input_params.coreml_version >= 7 && input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { } else diff --git a/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc b/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc index 99d814a0e2a15..29525f89ef544 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc @@ -123,6 +123,17 @@ TEST(ConvTransposeTest, ConvTranspose_1D) { TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); } +template +static std::vector GetTypedArray(std::vector inputs, [[maybe_unused]] T v = T(0.f)) { + if constexpr (std::is_same::value) { + return inputs; + } else { + std::vector inputs_fp16(inputs.size()); + ConvertFloatToMLFloat16(inputs.data(), inputs_fp16.data(), inputs.size()); + return inputs_fp16; + } +} + TYPED_TEST(ConvTransposeTest, ConvTranspose_2D_outputpadding_strides2) { ConvTransposeOpAttributes attrs = { vector{3, 3}, // kernel_shape @@ -201,20 +212,11 @@ TYPED_TEST(ConvTransposeTest, ConvTranspose_2D_C2) { 0.58575004f, 1.25774109f, 1.23472511f, 0.77670550f, 0.25844323f, 0.88953220f, 0.77098041f, 0.27468451f}; - if constexpr (std::is_same::value) { - TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); - } else { - vector X_fp16(X.size()); - ConvertFloatToMLFloat16(X.data(), X_fp16.data(), X.size()); - vector W_fp16(W.size()); - ConvertFloatToMLFloat16(W.data(), W_fp16.data(), W.size()); - std::vector expected_vals_fp16(expected_vals.size()); - ConvertFloatToMLFloat16(expected_vals.data(), expected_vals_fp16.data(), expected_vals.size()); - TestConvTransposeOp(attrs, {X_fp16, W_fp16}, {X_shape, W_shape}, expected_vals_fp16, Y_shape); - } + TestConvTransposeOp(attrs, {GetTypedArray(X), GetTypedArray(W)}, + {X_shape, W_shape}, GetTypedArray(expected_vals), Y_shape); } -TEST(ConvTransposeTest, ConvTranspose_2D_Bias_1) { +TYPED_TEST(ConvTransposeTest, ConvTranspose_2D_Bias_1) { ConvTransposeOpAttributes attrs = { vector{3, 3}, // kernel_shape vector{0, 0}, // output_padding @@ -243,7 +245,8 @@ TEST(ConvTransposeTest, ConvTranspose_2D_Bias_1) { 0.07770107f, -0.09561026f, 0.13388641f, 0.30945939f, 0.14015588f, 0.13079405f, -0.00488365f, -0.06758944f, 0.45621645f, 0.01566098f, 0.00703105f, 0.12956856f, 0.0103332f, 0.04221053f, -0.21318194f}; - TestConvTransposeOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape); + TestConvTransposeOp(attrs, {GetTypedArray(X), GetTypedArray(W), GetTypedArray(B)}, + {X_shape, W_shape, B_shape}, GetTypedArray(expected_vals), Y_shape); } TEST(ConvTransposeTest, ConvTranspose_2D_Bias_2) { diff --git a/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc b/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc index 48d6ad081ac19..4a1888a5ca7d6 100644 --- a/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc @@ -76,7 +76,7 @@ TEST(ConcatOpTest, Concat1D_2) { } template -std::vector GetTypedArray(std::vector inputs, T v = T(0.f)) { +static std::vector GetTypedArray(std::vector inputs, [[maybe_unused]] T v = T(0.f)) { 
if constexpr (std::is_same::value) { return inputs; } else { diff --git a/onnxruntime/test/providers/cpu/tensor/slice_op.test.cc b/onnxruntime/test/providers/cpu/tensor/slice_op.test.cc index 0285ec01d5869..a32d43f296250 100644 --- a/onnxruntime/test/providers/cpu/tensor/slice_op.test.cc +++ b/onnxruntime/test/providers/cpu/tensor/slice_op.test.cc @@ -264,7 +264,7 @@ TEST(SliceTest, Slice3D) { } template -static std::vector GetTypedArray(std::vector inputs, T v = T(0.f)) { +static std::vector GetTypedArray(std::vector inputs, [[maybe_unused]] T v = T(0.f)) { std::vector inputs_T(inputs.size()); if constexpr (std::is_same::value) { return inputs; diff --git a/onnxruntime/test/providers/cpu/tensor/split_op_test.cc b/onnxruntime/test/providers/cpu/tensor/split_op_test.cc index 35424fbdc90cf..48872404f08bd 100644 --- a/onnxruntime/test/providers/cpu/tensor/split_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/split_op_test.cc @@ -179,7 +179,7 @@ TEST(SplitOperatorTest, Axis0UnequalSplitFloat) { } template -std::vector GetTypedArray(std::vector inputs, T v = T(0.f)) { +std::vector GetTypedArray(std::vector inputs, [[maybe_unused]] T v = T(0.f)) { if constexpr (std::is_same::value) { return inputs; } else { diff --git a/onnxruntime/test/providers/cpu/tensor/transpose_test.cc b/onnxruntime/test/providers/cpu/tensor/transpose_test.cc index 28ee044df3f6a..3b46dc3f5d6a2 100644 --- a/onnxruntime/test/providers/cpu/tensor/transpose_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/transpose_test.cc @@ -70,7 +70,7 @@ void TransposeTest(const std::vector& input_shape, } template -std::vector GetTypedArray(std::vector inputs, T v = T(0.f)) { +std::vector GetTypedArray(std::vector inputs, [[maybe_unused]] T v = T(0.f)) { if constexpr (std::is_same::value) { return inputs; } else { From 9ab55d4cafaa9dd56e6804a2ca86f286fb5b45da Mon Sep 17 00:00:00 2001 From: wejoncy Date: Mon, 30 Sep 2024 00:22:54 -0700 Subject: [PATCH 39/39] slice coreml version >= 6 --- .../core/providers/coreml/builders/impl/slice_op_builder.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc index cb0f7cc0208e9..6b3fe75fa592d 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc @@ -228,7 +228,7 @@ bool SliceOpBuilder::HasSupportedInputsImpl(const Node& node, // It's incorrect and CoreML 6+ (iOS16, CoreML spec version >= 7) is required otherwise only float is supported. // CoreML 5:https://github.com/apple/coremltools/blob/89d058ffdcb0b39a03031782d8a448b6889ac425/coremltools/converters/mil/mil/ops/defs/tensor_transformation.py#L515 // CoreML 6:https://github.com/apple/coremltools/blob/c3ea4cf56fef1176417246c1b85363417f3e713d/coremltools/converters/mil/mil/ops/defs/iOS15/tensor_transformation.py#L495 - if (input_params.create_mlprogram && input_params.coreml_version >= 7 && + if (input_params.create_mlprogram && input_params.coreml_version >= 6 && input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { } else #endif // nolint
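A note on the pattern these test commits converge on: keep the reference data in float, convert inputs and expected outputs with a GetTypedArray helper, and register each test once per element type through a gtest typed-test suite. The sketch below shows that shape in isolation. It assumes the onnxruntime test utilities referenced in the diffs (OpTester, MLFloat16, ConvertFloatToMLFloat16) are available with the signatures seen above; the suite name, the include path, and the choice of Transpose as the op under test are illustrative placeholders, not part of the patches.

// Minimal sketch of the typed-test + GetTypedArray pattern used in the commits above.
// Assumptions: the onnxruntime test headers provide OpTester, MLFloat16 and
// ConvertFloatToMLFloat16 as used in the diffs; the include path, suite name and
// op choice are placeholders for illustration only.
#include <type_traits>
#include <vector>

#include "gtest/gtest.h"
#include "test/providers/provider_test_utils.h"  // assumed header for OpTester and helpers

namespace onnxruntime {
namespace test {

// Convert float reference data to the element type under test; mirrors the
// GetTypedArray helpers added in conv_transpose/concat/slice/split/transpose tests.
template <typename T>
static std::vector<T> GetTypedArray(std::vector<float> inputs) {
  if constexpr (std::is_same<T, float>::value) {
    return inputs;  // float reference data is used as-is
  } else {
    std::vector<T> converted(inputs.size());
    ConvertFloatToMLFloat16(inputs.data(), converted.data(), inputs.size());
    return converted;
  }
}

template <typename T>
class ExampleTypedOpTest : public ::testing::Test {};

using ExampleTypedOpTestTypes = ::testing::Types<float, MLFloat16>;
TYPED_TEST_SUITE(ExampleTypedOpTest, ExampleTypedOpTestTypes);

// Runs once with T = float and once with T = MLFloat16, mirroring how the
// ConvTranspose/Concat/Slice/Split/Transpose tests were converted above.
TYPED_TEST(ExampleTypedOpTest, TransposeRunsForFloatAndFp16) {
  OpTester tester("Transpose", 13);  // default perm reverses the axes
  tester.AddInput<TypeParam>("X", {2, 2}, GetTypedArray<TypeParam>({1.f, 2.f, 3.f, 4.f}));
  tester.AddOutput<TypeParam>("Y", {2, 2}, GetTypedArray<TypeParam>({1.f, 3.f, 2.f, 4.f}));
  tester.Run();
}

}  // namespace test
}  // namespace onnxruntime

The same helper is what lets tests such as ConvTranspose_2D_C2 pass GetTypedArray(X) and GetTypedArray(expected_vals) directly, instead of duplicating the float/fp16 branch in every test body as the earlier commits in this series did.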