From 2cfe1f031def2156b66b8ba8392040f44f17ea23 Mon Sep 17 00:00:00 2001 From: wejoncy Date: Mon, 30 Sep 2024 17:56:47 +0800 Subject: [PATCH] [CoreML MLProgram] Support Float16 (1/N) (#22068) ### Description Support Float16 for CoreML MLProgram EP. Operations: "Add", "Mul", "Sub", "Div", "Pow", "Sqrt", "Reciprocal", "Sigmoid", "Tanh", "Relu", "LeakyRelu", "Concat", "GridSample", "GlobalAveragePool", "Clip", "DepthToSpace", "Resize", "Slice", "Conv", "ConvTranspose", "GlobalMaxPool", "Gemm", "MatMul", "AveragePool", "MaxPool", "Reshape", "Split", "Transpose" ### Motivation and Context --------- Co-authored-by: Scott McKay --- .../builders/impl/activation_op_builder.cc | 8 +- .../coreml/builders/impl/base_op_builder.cc | 38 +- .../coreml/builders/impl/base_op_builder.h | 6 +- .../coreml/builders/impl/binary_op_builder.cc | 21 +- .../coreml/builders/impl/builder_utils.cc | 32 + .../coreml/builders/impl/builder_utils.h | 3 + .../coreml/builders/impl/clip_op_builder.cc | 22 +- .../builders/impl/depthtospace_op_builder.cc | 4 +- .../coreml/builders/impl/gemm_op_builder.cc | 63 +- .../builders/impl/gridsample_op_builder.cc | 9 +- .../coreml/builders/impl/slice_op_builder.cc | 25 +- .../coreml/builders/impl/unary_op_builder.cc | 59 +- .../coreml/builders/model_builder.cc | 11 + .../providers/coreml/builders/model_builder.h | 3 +- .../core/providers/coreml/model/model.mm | 36 +- .../cpu/activation/activation_op_test.cc | 6 +- .../cpu/activation/activation_op_test.h | 2 - .../test/providers/cpu/math/clip_test.cc | 32 +- .../cpu/math/element_wise_ops_test.cc | 155 ++-- .../test/providers/cpu/math/gemm_test.cc | 150 +++- .../test/providers/cpu/math/matmul_test.cc | 2 +- .../test/providers/cpu/nn/conv_fp16_test.cc | 20 +- .../cpu/nn/conv_transpose_op_test.cc | 484 +++++++------ .../providers/cpu/nn/pool_fp16_op_test.cc | 4 +- .../test/providers/cpu/nn/pool_op_test.cc | 45 +- .../providers/cpu/tensor/concat_op_test.cc | 52 +- .../providers/cpu/tensor/grid_sample_test.cc | 679 +++++++++--------- .../cpu/tensor/grid_sample_test_gen.py | 60 +- .../providers/cpu/tensor/resize_op_test.cc | 25 +- .../providers/cpu/tensor/slice_op.test.cc | 71 +- .../cpu/tensor/space_depth_ops_test.cc | 79 +- .../providers/cpu/tensor/split_op_test.cc | 31 + .../providers/cpu/tensor/transpose_test.cc | 32 +- onnxruntime/test/util/test_utils.cc | 5 + .../apple/coreml_supported_mlprogram_ops.md | 2 + 35 files changed, 1427 insertions(+), 849 deletions(-) diff --git a/onnxruntime/core/providers/coreml/builders/impl/activation_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/activation_op_builder.cc index c8670cd546253..5389eb5ab7e95 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/activation_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/activation_op_builder.cc @@ -104,7 +104,13 @@ Status ActivationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, if (add_alpha) { NodeAttrHelper helper(node); const auto alpha = helper.Get("alpha", 0.01f); - AddOperationInput(*op, "alpha", model_builder.AddScalarConstant(op->type(), "alpha", alpha)); + + auto input_dtype = node.InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + AddOperationInput(*op, "alpha", model_builder.AddScalarConstant(op->type(), "alpha", alpha)); + } else { + AddOperationInput(*op, "alpha", model_builder.AddScalarConstant(op->type(), "alpha", MLFloat16(alpha))); + } } AddOperationOutput(*op, *node.OutputDefs()[0]); diff --git a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc index 2cae85a0a1c8d..f185a80de3cbf 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include #include "core/providers/common.h" #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" @@ -12,6 +13,15 @@ using namespace CoreML::Specification; namespace onnxruntime { namespace coreml { +// Once all ops are supportted FP16, we can remove it. Before that, we keep a set of ops to +// filter suppported ones. +static std::set Float16Ops = { + "Add", "Mul", "Sub", "Div", "Pow", "Sqrt", "Reciprocal", + "Sigmoid", "Tanh", "Relu", "LeakyRelu", "Concat", "GridSample", "GlobalAveragePool", + "Clip", "DepthToSpace", "Resize", "Slice", "Conv", + "ConvTranspose", "GlobalMaxPool", "Gemm", "MatMul", + "AveragePool", "MaxPool", "Reshape", "Split", "Transpose"}; + namespace { // TODO, move this to shared_library bool HasExternalInitializer(const InitializedTensorSet& initializers, const Node& node, @@ -83,8 +93,9 @@ bool BaseOpBuilder::HasSupportedInputs(const Node& node, const OpBuilderInputPar } /* static */ -bool BaseOpBuilder::IsInputFloat(const Node& node, size_t idx, const OpBuilderInputParams& /*input_params*/, - const logging::Logger& logger) { +bool BaseOpBuilder::IsInputDtypeSupport(const Node& node, size_t idx, + [[maybe_unused]] const OpBuilderInputParams& input_params, + const logging::Logger& logger) { if (idx >= node.InputDefs().size()) { LOGS(logger, VERBOSE) << "Input index [" << idx << "] is out of range"; return false; @@ -94,20 +105,33 @@ bool BaseOpBuilder::IsInputFloat(const Node& node, size_t idx, const OpBuilderIn int32_t input_type = ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED; - // currently only float is supported - if (!GetType(input, input_type, logger) || input_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { - LOGS(logger, VERBOSE) << "[" << node.OpType() << "] Input type: [" << input_type << "] is not currently supported"; + if (!GetType(input, input_type, logger)) { + LOGS(logger, VERBOSE) << "[" << node.OpType() << "] Get Input type failed"; return false; } - return true; + // float is supported + if (input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + return true; + } + +// only support MLProgram for FP16 +#if defined(COREML_ENABLE_MLPROGRAM) + if (input_params.create_mlprogram && input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16 && + Float16Ops.count(node.OpType())) { + return true; + } +#endif + + LOGS(logger, VERBOSE) << "[" << node.OpType() << "] Input type: [" << input_type << "] is not currently supported"; + return false; } bool BaseOpBuilder::HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { // We only check the type of input 0 by default // specific op builder can override this - return IsInputFloat(node, 0, input_params, logger); + return IsInputDtypeSupport(node, 0, input_params, logger); } bool BaseOpBuilder::HasSupportedOpSet(const Node& node, const logging::Logger& logger) const { diff --git a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h index 071008520fbdc..153ae841b238f 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h +++ b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h @@ -32,9 +32,9 @@ class BaseOpBuilder : public IOpBuilder { : allow_empty_tensor_as_input_(allow_empty_tensor_as_input) { } - // currently we only support float - static bool IsInputFloat(const Node& node, size_t idx, const OpBuilderInputParams& input_params, - const logging::Logger& logger); + // currently we support float/float16 + static bool IsInputDtypeSupport(const Node& node, size_t idx, const OpBuilderInputParams& input_params, + const logging::Logger& logger); private: virtual bool IsOpSupportedImpl(const Node& /*node*/, const OpBuilderInputParams& /*input_params*/, diff --git a/onnxruntime/core/providers/coreml/builders/impl/binary_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/binary_op_builder.cc index fb8e07633621f..8aa2dbae2531c 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/binary_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/binary_op_builder.cc @@ -73,7 +73,7 @@ Status BinaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const } else if (op_type == "Sub") { coreml_op_type = "sub"; } else if (op_type == "Div") { - // we only support fp32 currently. when we add support for integers we need to check the type and use + // we support fp32/fp16 currently. when we add support for integers we need to check the type and use // "floor_div" or "real_div" accordingly coreml_op_type = "real_div"; } else if (op_type == "Pow") { @@ -138,9 +138,22 @@ bool BinaryOpBuilder::HasSupportedInputsImpl(const Node& node, const OpBuilderIn const logging::Logger& logger) const { // Add/Sub/Mul/Div spec says inputs must be of the same type. // Pow spec says inputs can be different types. - // We only support float for all of these inputs. - if (!IsInputFloat(node, 0, input_params, logger) || - ((node.OpType() == "Pow") && !IsInputFloat(node, 1, input_params, logger))) { + // We support float/float16 for all of these inputs. + + if (node.OpType() == "Pow") { + const auto& input0 = *node.InputDefs()[0]; + const auto& input1 = *node.InputDefs()[1]; + int32_t input_type0 = ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED; + int32_t input_type1 = ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED; + if (!GetType(input0, input_type0, logger)) { + return false; + } + if (!GetType(input1, input_type1, logger) || input_type1 != input_type0) { + return false; + } + } + + if (!IsInputDtypeSupport(node, 0, input_params, logger)) { return false; } diff --git a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc index e02186d3aee89..6f9bb35c27d80 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc @@ -96,6 +96,9 @@ Status CreateCoreMLWeight(CoreML::Specification::WeightParams& weight, case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: CreateCoreMLWeight(weight, unpacked_tensor.DataAsSpan()); break; + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: + CreateCoreMLWeight(weight, unpacked_tensor.DataAsSpan()); + break; case ONNX_NAMESPACE::TensorProto_DataType_INT32: CreateCoreMLWeight(weight, unpacked_tensor.DataAsSpan()); break; @@ -114,6 +117,11 @@ void CreateCoreMLWeight(CoreML::Specification::WeightParams& weight, gsl::spanAssign(data.begin(), data.end()); } +void CreateCoreMLWeight(CoreML::Specification::WeightParams& weight, gsl::span data) { + const char* data_byte_ptr = reinterpret_cast(data.data()); + weight.mutable_float16value()->assign(data_byte_ptr, data_byte_ptr + data.size_bytes()); +} + namespace { template void CreateCoreMLWeightConvertingDataToFloats(CoreML::Specification::WeightParams& weight, gsl::span data) { @@ -123,6 +131,15 @@ void CreateCoreMLWeightConvertingDataToFloats(CoreML::Specification::WeightParam [](T v) { return narrow(v); }); *weight.mutable_floatvalue() = std::move(weight_floats); } + +template +void CreateCoreMLWeightConvertingDataToFloat16s(CoreML::Specification::WeightParams& weight, gsl::span data) { + std::vector weight_float16s{}; + weight_float16s.reserve(data.size()); + std::transform(data.begin(), data.end(), std::back_inserter(weight_float16s), + [](T v) { return MLFloat16(float(v)); }); + CreateCoreMLWeight(weight, weight_float16s); +} } // namespace void CreateCoreMLWeight(CoreML::Specification::WeightParams& weight, gsl::span data) { @@ -195,6 +212,13 @@ void CopyDataToTensorValue(MILSpec::TensorValue& tensor_value, gsl::span< tensor_value.mutable_floats()->mutable_values()->Add(data.begin(), data.end()); } +template <> +void CopyDataToTensorValue(MILSpec::TensorValue& tensor_value, gsl::span data) { + const char* begin = reinterpret_cast(data.data()); + const char* end = begin + (data.size() * sizeof(MLFloat16)); + tensor_value.mutable_bytes()->mutable_values()->assign(begin, end); +} + template <> void CopyDataToTensorValue(MILSpec::TensorValue& tensor_value, gsl::span data) { tensor_value.mutable_ints()->mutable_values()->Add(data.begin(), data.end()); @@ -290,6 +314,14 @@ MILSpec::Value CreateScalarTensorValue(const T& data) { // explicit specializations for types we handle so the implementation can be in the .cc file template MILSpec::Value CreateTensorValue(gsl::span data, std::optional> shape); +template MILSpec::Value CreateTensorValue(gsl::span data, + std::optional> shape); +template MILSpec::Value CreateTensorValue(gsl::span data, + std::optional> shape); +template MILSpec::Value CreateTensorValue(gsl::span data, + std::optional> shape); +template MILSpec::Value CreateTensorValue(gsl::span data, + std::optional> shape); template MILSpec::Value CreateScalarTensorValue(const float& data); template MILSpec::Value CreateScalarTensorValue(const int32_t& data); diff --git a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h index 475ce79b0a812..f38afc0ec181d 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h +++ b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h @@ -41,6 +41,9 @@ Status CreateCoreMLWeight(CoreML::Specification::WeightParams& weight, const ONN // Copy the float array to a coreml weight void CreateCoreMLWeight(CoreML::Specification::WeightParams& weight, gsl::span data); +// Copy the MLFloat16 array to a coreml weight +void CreateCoreMLWeight(CoreML::Specification::WeightParams& weight, gsl::span data); + // Copy the int32_t array to a coreml weight void CreateCoreMLWeight(CoreML::Specification::WeightParams& weight, gsl::span data); diff --git a/onnxruntime/core/providers/coreml/builders/impl/clip_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/clip_op_builder.cc index 41f4041ef1181..bc9e2f10296ed 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/clip_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/clip_op_builder.cc @@ -92,16 +92,30 @@ Status ClipOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, Operation& clip_op = *op; AddOperationInput(clip_op, "x", input_name); + // we already checked it and dtype must be existed. + auto input_dtype = node.InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); // if min and max were attributes we need to add initializers. otherwise we use the existing inputs const bool min_max_attribs = node.SinceVersion() < 11; - std::string_view min_name = min_max_attribs ? model_builder.AddScalarConstant(clip_op.type(), "min", min) - : node.InputDefs()[1]->Name(); + std::string_view min_name; + if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + min_name = min_max_attribs ? model_builder.AddScalarConstant(clip_op.type(), "min", min) + : node.InputDefs()[1]->Name(); + } else { + min_name = min_max_attribs ? model_builder.AddScalarConstant(clip_op.type(), "min", MLFloat16(min)) + : node.InputDefs()[1]->Name(); + } AddOperationInput(clip_op, "alpha", min_name); if (has_max) { - std::string_view max_name = min_max_attribs ? model_builder.AddScalarConstant(clip_op.type(), "max", max) - : node.InputDefs()[2]->Name(); + std::string_view max_name; + if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + max_name = min_max_attribs ? model_builder.AddScalarConstant(clip_op.type(), "max", max) + : node.InputDefs()[2]->Name(); + } else { + max_name = min_max_attribs ? model_builder.AddScalarConstant(clip_op.type(), "max", MLFloat16(max)) + : node.InputDefs()[2]->Name(); + } AddOperationInput(clip_op, "beta", max_name); } } diff --git a/onnxruntime/core/providers/coreml/builders/impl/depthtospace_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/depthtospace_op_builder.cc index bec2461ffbc52..ddaa19c7fab18 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/depthtospace_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/depthtospace_op_builder.cc @@ -67,7 +67,9 @@ Status DepthToSpaceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, // we checked shape was static in IsOpSupportedImpl so this should never fail std::vector input_shape; ORT_RETURN_IF_NOT(GetStaticShape(*input_defs[0], input_shape, logger), "Failed to get input shape"); - const int32_t elem_type = static_cast(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + auto input_dtype = input_defs[0]->TypeAsProto()->tensor_type().elem_type(); + + const int32_t elem_type = static_cast(input_dtype); // reshape to [b * c // (blocksize ** 2), blocksize, blocksize, h, w] auto reshape1 = model_builder.CreateOperation(node, "reshape", "pre"); diff --git a/onnxruntime/core/providers/coreml/builders/impl/gemm_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/gemm_op_builder.cc index 7338fc18fe779..e685c09ef43ca 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/gemm_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/gemm_op_builder.cc @@ -70,16 +70,17 @@ void GemmOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Nod } } -// This is an internal function, requires input tensor to be 2d float tensor -// TODO, add support of other data types -static Status GetTensorFloatDataTransposed(const ONNX_NAMESPACE::TensorProto& tensor, - std::vector& transposed_data) { +// This is an internal function, requires input tensor to be 2d float/float16 tensor +template +static Status GetTensorDataTransposed(const ONNX_NAMESPACE::TensorProto& tensor, + std::vector& transposed_data) { Initializer unpacked_tensor(tensor); - auto src_data = unpacked_tensor.DataAsSpan(); + const auto src_data = unpacked_tensor.DataAsSpan(); const auto& tensor_shape = tensor.dims(); auto x_t = SafeInt(tensor_shape[0]); auto y_t = SafeInt(tensor_shape[1]); transposed_data.resize(x_t * y_t); + for (size_t x = 0; x < x_t; x++) { for (size_t y = 0; y < y_t; y++) { transposed_data[y * x_t + x] = src_data[x * y_t + y]; @@ -121,8 +122,9 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N // B is {K, N} in ONNX spec by default, or {N, K} in Gemm if transB is true const auto K = transB ? b1 : b0; const auto N = transB ? b0 : b1; - + // we already checked it and dtype must be existed. #if defined(COREML_ENABLE_MLPROGRAM) + auto input_dtype = a.TypeAsProto()->tensor_type().elem_type(); if (model_builder.CreateMLProgram()) { using namespace CoreML::Specification::MILSpec; @@ -136,13 +138,19 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N if (transB) { AddOperationInput(*gemm_op, "weight", b.Name()); } else { - // transpose from {K, N} to {N, K} - std::vector weight_nk; std::vector weight_nk_shape = {N, K}; - ORT_RETURN_IF_ERROR(GetTensorFloatDataTransposed(*b_initializer, weight_nk)); - - AddOperationInput(*gemm_op, "weight", - model_builder.AddConstant(gemm_op->type(), b.Name() + "_t", weight_nk, weight_nk_shape)); + // transpose from {K, N} to {N, K} + if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + std::vector weight_nk; + ORT_RETURN_IF_ERROR(GetTensorDataTransposed(*b_initializer, weight_nk)); + AddOperationInput(*gemm_op, "weight", + model_builder.AddConstant(gemm_op->type(), b.Name() + "_t", weight_nk, weight_nk_shape)); + } else { // TensorProto_DataType_FLOAT16 + std::vector weight_nk; + ORT_RETURN_IF_ERROR(GetTensorDataTransposed(*b_initializer, weight_nk)); + AddOperationInput(*gemm_op, "weight", + model_builder.AddConstant(gemm_op->type(), b.Name() + "_t", weight_nk, weight_nk_shape)); + } } if (input_defs.size() == 3) { @@ -155,15 +163,28 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N AddOperationInput(*gemm_op, "bias", bias_arg.Name()); } else { Initializer unpacked_tensor(bias); - auto bias_data = unpacked_tensor.DataAsSpan(); std::string_view bias_data_name; - if (bias_data.size() == 1) { - // expand scalar to N - std::vector expanded_bias_data(N, bias_data[0]); - bias_data_name = model_builder.AddConstant(gemm_op->type(), "bias", expanded_bias_data); - } else { - // can use data as-is but need to adjust shape (inferred by AddConstant as {bias_data.size()}) - bias_data_name = model_builder.AddConstant(gemm_op->type(), "bias", bias_data); + + if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + auto bias_data = unpacked_tensor.DataAsSpan(); + if (bias_data.size() == 1) { + // expand scalar to N + std::vector expanded_bias_data(N, bias_data[0]); + bias_data_name = model_builder.AddConstant(gemm_op->type(), "bias", expanded_bias_data); + } else { + // can use data as-is but need to adjust shape (inferred by AddConstant as {bias_data.size()}) + bias_data_name = model_builder.AddConstant(gemm_op->type(), "bias", bias_data); + } + } else { // TensorProto_DataType_FLOAT16 + auto bias_data = unpacked_tensor.DataAsSpan(); + if (bias_data.size() == 1) { + // expand scalar to N + std::vector expanded_bias_data(N, bias_data[0]); + bias_data_name = model_builder.AddConstant(gemm_op->type(), "bias", expanded_bias_data); + } else { + // can use data as-is but need to adjust shape (inferred by AddConstant as {bias_data.size()}) + bias_data_name = model_builder.AddConstant(gemm_op->type(), "bias", bias_data); + } } AddOperationInput(*gemm_op, "bias", bias_data_name); @@ -202,7 +223,7 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_inner_product->mutable_weights(), *b_initializer)); } else { std::vector b_transposed; - ORT_RETURN_IF_ERROR(GetTensorFloatDataTransposed(*b_initializer, b_transposed)); + ORT_RETURN_IF_ERROR(GetTensorDataTransposed(*b_initializer, b_transposed)); CreateCoreMLWeight(*coreml_inner_product->mutable_weights(), b_transposed); } diff --git a/onnxruntime/core/providers/coreml/builders/impl/gridsample_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/gridsample_op_builder.cc index 9caec290ea5a2..6dcf14c16f111 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/gridsample_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/gridsample_op_builder.cc @@ -49,6 +49,9 @@ Status GridSampleOpBuilder::AddToModelBuilderImpl([[maybe_unused]] ModelBuilder& const auto input_defs = node.InputDefs(); const auto output_defs = node.OutputDefs(); + // we already checked it and dtype must be existed. + auto input_dtype = input_defs[0]->TypeAsProto()->tensor_type().elem_type(); + NodeAttrHelper helper(node); std::string mode{GetMode(helper)}; // need a std::string for use in AddScalarConstant std::string padding_mode = helper.Get("padding_mode", "zeros"); @@ -65,7 +68,11 @@ Status GridSampleOpBuilder::AddToModelBuilderImpl([[maybe_unused]] ModelBuilder& AddOperationInput(*op, "coordinates", input_defs[1]->Name()); AddOperationInput(*op, "sampling_mode", model_builder.AddScalarConstant(op->type(), "sampling_mode", mode)); AddOperationInput(*op, "padding_mode", model_builder.AddScalarConstant(op->type(), "padding_mode", padding_mode)); - AddOperationInput(*op, "padding_value", model_builder.AddScalarConstant(op->type(), "padding_value", 0.0f)); + if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + AddOperationInput(*op, "padding_value", model_builder.AddScalarConstant(op->type(), "padding_value", 0.0f)); + } else { + AddOperationInput(*op, "padding_value", model_builder.AddScalarConstant(op->type(), "padding_value", MLFloat16(0.0f))); + } AddOperationInput(*op, "coordinates_mode", model_builder.AddScalarConstant(op->type(), "coordinates_mode", coordinates_mode)); AddOperationInput(*op, "align_corners", model_builder.AddScalarConstant(op->type(), "align_corners", align_corners)); diff --git a/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc index 51fc3f2c11c73..6b3fe75fa592d 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc @@ -144,7 +144,7 @@ Status SliceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const } } - // Only int32 and float are supported by CoreML slice_by_index. + // Int32, float and float16 are supported by CoreML slice_by_index. // We convert any int64 model input to int32 when running the CoreML model for the partition. // Any other integer data created at runtime is the output from CoreML operations, and should int32 not int64. // Based on that, we assume that the actual input when running will be int32, so we override the output data @@ -214,18 +214,29 @@ Status SliceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const return Status::OK(); } -bool SliceOpBuilder::HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& /*input_params*/, +bool SliceOpBuilder::HasSupportedInputsImpl(const Node& node, + [[maybe_unused]] const OpBuilderInputParams& input_params, const logging::Logger& logger) const { int32_t input_type; if (!GetType(*node.InputDefs()[0], input_type, logger)) { return false; } - if (input_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT && - input_type != ONNX_NAMESPACE::TensorProto_DataType_INT64) { - LOGS(logger, VERBOSE) << "[" << node.OpType() << "] Input type: [" << input_type << "] is not supported"; - return false; - } +#ifdef COREML_ENABLE_MLPROGRAM + // The [Doc](https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.tensor_transformation.slice_by_index) + // says ML Program slice_by_index supports fp16 in CoreML 5 (iOS 15). + // It's incorrect and CoreML 6+ (iOS16, CoreML spec version >= 7) is required otherwise only float is supported. + // CoreML 5:https://github.com/apple/coremltools/blob/89d058ffdcb0b39a03031782d8a448b6889ac425/coremltools/converters/mil/mil/ops/defs/tensor_transformation.py#L515 + // CoreML 6:https://github.com/apple/coremltools/blob/c3ea4cf56fef1176417246c1b85363417f3e713d/coremltools/converters/mil/mil/ops/defs/iOS15/tensor_transformation.py#L495 + if (input_params.create_mlprogram && input_params.coreml_version >= 6 && + input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { + } else +#endif // nolint + if (input_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT && + input_type != ONNX_NAMESPACE::TensorProto_DataType_INT64) { + LOGS(logger, VERBOSE) << "[" << node.OpType() << "] Input type: [" << input_type << "] is not supported"; + return false; + } return true; } diff --git a/onnxruntime/core/providers/coreml/builders/impl/unary_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/unary_op_builder.cc index 3403378d59114..335ca737081b2 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/unary_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/unary_op_builder.cc @@ -3,6 +3,7 @@ #include "core/providers/common.h" +#include "core/providers/coreml/builders/impl/builder_utils.h" #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" #include "core/providers/coreml/builders/model_builder.h" @@ -14,6 +15,7 @@ namespace coreml { class UnaryOpBuilder : public BaseOpBuilder { Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; + bool SupportsMLProgram() const override { return true; } }; Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, @@ -21,21 +23,54 @@ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const const auto& op_type(node.OpType()); const auto& input_defs(node.InputDefs()); - std::unique_ptr layer = model_builder.CreateNNLayer(node); +#if defined(COREML_ENABLE_MLPROGRAM) + if (model_builder.CreateMLProgram()) { + using namespace CoreML::Specification::MILSpec; - if (op_type == "Sqrt") { - layer->mutable_unary()->set_type(COREML_SPEC::UnaryFunctionLayerParams::SQRT); - } else if (op_type == "Reciprocal") { - layer->mutable_unary()->set_type(COREML_SPEC::UnaryFunctionLayerParams::INVERSE); - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "UnaryOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type); - } + std::string_view coreml_op_type; + if (op_type == "Sqrt") { + coreml_op_type = "sqrt"; + } else if (op_type == "Reciprocal") { + coreml_op_type = "inverse"; + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "UnaryOpBuilder::AddToModelBuilderImpl, unexpected op: ", op_type); + } + + std::unique_ptr op = model_builder.CreateOperation(node, coreml_op_type); + AddOperationInput(*op, "x", input_defs[0]->Name()); + if (op_type == "Reciprocal") { + float epsilon = 1e-4; // epsilon: const T (Optional, default=1e-4) + auto dtype = node.InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + if (dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + AddOperationInput(*op, "epsilon", model_builder.AddScalarConstant(op->type(), "epsilon", epsilon)); + } else if (dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { + AddOperationInput(*op, "epsilon", model_builder.AddScalarConstant(op->type(), "epsilon", MLFloat16(epsilon))); + } + } + + AddOperationOutput(*op, *node.OutputDefs()[0]); - *layer->mutable_input()->Add() = input_defs[0]->Name(); - *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); + model_builder.AddOperation(std::move(op)); + } else // NOLINT +#endif // defined (COREML_ENABLE_MLPROGRAM) + { + std::unique_ptr layer = model_builder.CreateNNLayer(node); - model_builder.AddLayer(std::move(layer)); + if (op_type == "Sqrt") { + layer->mutable_unary()->set_type(COREML_SPEC::UnaryFunctionLayerParams::SQRT); + } else if (op_type == "Reciprocal") { + layer->mutable_unary()->set_type(COREML_SPEC::UnaryFunctionLayerParams::INVERSE); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "UnaryOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type); + } + + *layer->mutable_input()->Add() = input_defs[0]->Name(); + *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); + + model_builder.AddLayer(std::move(layer)); + } return Status::OK(); } diff --git a/onnxruntime/core/providers/coreml/builders/model_builder.cc b/onnxruntime/core/providers/coreml/builders/model_builder.cc index 9668bfcd09adf..50faebf06875d 100644 --- a/onnxruntime/core/providers/coreml/builders/model_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/model_builder.cc @@ -639,6 +639,14 @@ std::string_view ModelBuilder::AddConstantImpl(std::string_view op_type, std::st return AddTensorValueAsConstantOperation(op_type, value_type, std::move(input_value)); } +template <> +std::string_view ModelBuilder::AddConstantImpl(std::string_view op_type, std::string_view value_type, + gsl::span value, + std::optional> shape) { + auto input_value = CreateTensorValue(value, shape); + return AddTensorValueAsConstantOperation(op_type, value_type, std::move(input_value)); +} + template <> std::string_view ModelBuilder::AddConstantImpl(std::string_view op_type, std::string_view value_type, gsl::span value, @@ -811,6 +819,9 @@ Status ModelBuilder::RegisterModelInputOutput(const NodeArg& node_arg, bool is_i case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: multi_array->set_datatype(ArrayFeatureType::FLOAT32); break; + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: + multi_array->set_datatype(ArrayFeatureType::FLOAT16); + break; case ONNX_NAMESPACE::TensorProto_DataType_INT32: multi_array->set_datatype(ArrayFeatureType::INT32); break; diff --git a/onnxruntime/core/providers/coreml/builders/model_builder.h b/onnxruntime/core/providers/coreml/builders/model_builder.h index bb791fb902908..b3dfec29872a2 100644 --- a/onnxruntime/core/providers/coreml/builders/model_builder.h +++ b/onnxruntime/core/providers/coreml/builders/model_builder.h @@ -107,11 +107,12 @@ class ModelBuilder { std::string_view AddConstant(std::string_view op_type, std::string_view value_type, gsl::span value, std::optional> shape = std::nullopt) { static_assert(std::is_same_v || + std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v, // add specialization in AddConstantImpl for new types if needed - "AddConstant currently supports float, int64_t, std::string and bool."); + "AddConstant currently supports float, MLFloat16, int64_t, std::string and bool."); return AddConstantImpl(op_type, value_type, value, shape); } diff --git a/onnxruntime/core/providers/coreml/model/model.mm b/onnxruntime/core/providers/coreml/model/model.mm index 68460ff7c9b31..1401cbe95fd56 100644 --- a/onnxruntime/core/providers/coreml/model/model.mm +++ b/onnxruntime/core/providers/coreml/model/model.mm @@ -120,6 +120,10 @@ Status CreateInputFeatureProvider(const std::unordered_map +void StridedCopy(const T* src_buffer, T* dst_buffer, size_t block_size, + size_t num_blocks, size_t src_stride, size_t dst_stride) { + for (size_t idx = 0; idx < num_blocks; ++idx) { + std::copy_n(src_buffer, block_size, dst_buffer); + src_buffer += src_stride; + dst_buffer += dst_stride; + } +} + Status CopyMLMultiArrayBuffer(const void* mlmultiarray_buffer, void* tensor_buffer, const MLMultiArray* array, const int64_t num_blocks, const int64_t block_size, const int64_t stride, @@ -196,25 +210,21 @@ Status CopyMLMultiArrayBuffer(const void* mlmultiarray_buffer, void* tensor_buff case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: { const auto* src_buffer = static_cast(mlmultiarray_buffer); auto* dst_buffer = static_cast(tensor_buffer); - const auto block_byte_size = block_size * sizeof(float); + StridedCopy(src_buffer, dst_buffer, block_size, num_blocks, stride, block_size); + + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: { + const auto* src_buffer = static_cast(mlmultiarray_buffer); + auto* dst_buffer = static_cast(tensor_buffer); + StridedCopy(src_buffer, dst_buffer, block_size, num_blocks, stride, block_size); - for (int64_t idx = 0; idx < num_blocks; ++idx) { - memcpy(dst_buffer, src_buffer, block_byte_size); - src_buffer += stride; - dst_buffer += block_size; - } break; } case ONNX_NAMESPACE::TensorProto_DataType_INT32: { const auto* src_buffer = static_cast(mlmultiarray_buffer); auto* dst_buffer = static_cast(tensor_buffer); - const auto block_byte_size = block_size * sizeof(int32_t); - - for (int64_t idx = 0; idx < num_blocks; ++idx) { - memcpy(dst_buffer, src_buffer, block_byte_size); - src_buffer += stride; - dst_buffer += block_size; - } + StridedCopy(src_buffer, dst_buffer, block_size, num_blocks, stride, block_size); break; } diff --git a/onnxruntime/test/providers/cpu/activation/activation_op_test.cc b/onnxruntime/test/providers/cpu/activation/activation_op_test.cc index d2e883331acd4..724118d7419d2 100644 --- a/onnxruntime/test/providers/cpu/activation/activation_op_test.cc +++ b/onnxruntime/test/providers/cpu/activation/activation_op_test.cc @@ -125,7 +125,7 @@ TEST_F(ActivationOpTest, Relu) { {}, {}, /*is_tensorrt_supported=*/false, /*opset_version= */ 14); -#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || defined(COREML_ENABLE_MLPROGRAM) TestActivationOp( "Relu", input_values_fp16, @@ -139,7 +139,7 @@ TEST_F(ActivationOpTest, Relu) { #endif // MLAS_F16VEC_INTRINSICS_SUPPORTED } -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(COREML_ENABLE_MLPROGRAM) TEST_F(ActivationOpTest, Sigmoid_fp16) { #ifdef USE_CUDA int min_cuda_architecture = 530; @@ -413,7 +413,7 @@ TEST_F(ActivationOpTest, LeakyRelu) { {{"alpha", alpha}}, {}); } -#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || defined(COREML_ENABLE_MLPROGRAM) TEST_F(ActivationOpTest, LeakyRelu_fp16) { OpTester test("LeakyRelu", 11); float alpha = 0.01f; // oneDNN set alpha equal to 0.01 diff --git a/onnxruntime/test/providers/cpu/activation/activation_op_test.h b/onnxruntime/test/providers/cpu/activation/activation_op_test.h index 409409f56c51c..8ca0f6d845a09 100644 --- a/onnxruntime/test/providers/cpu/activation/activation_op_test.h +++ b/onnxruntime/test/providers/cpu/activation/activation_op_test.h @@ -90,7 +90,6 @@ class ActivationOpTest : public ::testing::Test { DBL_MAX, -DBL_MAX, std::numeric_limits::infinity()}}; // max, -max, inf std::vector> input_values_int8{{-1, -5, 0, 1, 5, 100, -100, // normal input values for activation std::numeric_limits::min(), std::numeric_limits::max()}}; // min, max -#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED std::vector> input_values_fp16{{MLFloat16(-1.0f), MLFloat16(-5.f), MLFloat16(), @@ -100,7 +99,6 @@ class ActivationOpTest : public ::testing::Test { MLFloat16(-100.f), MLFloat16(65504.f), MLFloat16(-65504.f)}}; -#endif // MLAS_F16VEC_INTRINSICS_SUPPORTED void SetUp() override { float low = -1.0f, high = 1.0f; diff --git a/onnxruntime/test/providers/cpu/math/clip_test.cc b/onnxruntime/test/providers/cpu/math/clip_test.cc index 7a31562a6883e..c1452ab686279 100644 --- a/onnxruntime/test/providers/cpu/math/clip_test.cc +++ b/onnxruntime/test/providers/cpu/math/clip_test.cc @@ -120,21 +120,25 @@ TEST(MathOpTest, Clip_Default_uint64) { } TEST(MathOpTest, Clip_MLFloat16) { - OpTester test("Clip", 12); - - std::vector dims{3, 3}; - test.AddInput("X", dims, - {MLFloat16(-1.0f), MLFloat16(-2.0f), MLFloat16(-3.0f), - MLFloat16(-4.0f), MLFloat16(0.0f), MLFloat16(2.0f), - MLFloat16(4.0f), MLFloat16(6.0f), MLFloat16(8.0f)}); - test.AddInput("min", {}, {MLFloat16(0.0f)}); - test.AddInput("max", {}, {MLFloat16(6.0f)}); - test.AddOutput("Y", dims, - {MLFloat16(0.0f), MLFloat16(0.0f), MLFloat16(0.0f), - MLFloat16(0.0f), MLFloat16(0.0f), MLFloat16(2.0f), - MLFloat16(4.0f), MLFloat16(6.0f), MLFloat16(6.0f)}); + auto run_test = [](bool min_max_are_initializer) { + OpTester test("Clip", 12); - test.Run(); + std::vector dims{3, 3}; + test.AddInput("X", dims, + {MLFloat16(-1.0f), MLFloat16(-2.0f), MLFloat16(-3.0f), + MLFloat16(-4.0f), MLFloat16(0.0f), MLFloat16(2.0f), + MLFloat16(4.0f), MLFloat16(6.0f), MLFloat16(8.0f)}); + test.AddInput("min", {}, {MLFloat16(0.0f)}, min_max_are_initializer); + test.AddInput("max", {}, {MLFloat16(6.0f)}, min_max_are_initializer); + test.AddOutput("Y", dims, + {MLFloat16(0.0f), MLFloat16(0.0f), MLFloat16(0.0f), + MLFloat16(0.0f), MLFloat16(0.0f), MLFloat16(2.0f), + MLFloat16(4.0f), MLFloat16(6.0f), MLFloat16(6.0f)}); + + test.Run(); + }; + run_test(true); // coreml requires constant max/min + run_test(false); } TEST(MathOpTest, Clip_MLFloat16_NoMin_NoMax) { diff --git a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc index 507ed8e91a728..b2e9034653746 100644 --- a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc +++ b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc @@ -22,40 +22,93 @@ std::vector MakeMLFloat16(const std::initializer_list& input) return output; } -#if defined(USE_CUDA) || defined(USE_ROCM) -void TestFloat16(const char* op_name, const std::vector& lhs_dim, - const std::initializer_list& lhs_values, const std::vector& rhs_dim, - const std::initializer_list& rhs_values, const std::vector& out_dim, - const std::initializer_list& out_values) { +void TestBinaryFloat16(const char* op_name, + const std::vector& lhs_dim, + const std::initializer_list& lhs_values, + const std::vector& rhs_dim, + const std::initializer_list& rhs_values, + const std::vector& out_dim, + const std::initializer_list& out_values, + bool enable_bf16 = true) { + { + std::vector> execution_providers; +#ifdef COREML_ENABLE_MLPROGRAM + execution_providers.push_back(DefaultCoreMLExecutionProvider(true)); +#elif USE_CUDA + execution_providers.push_back(DefaultCudaExecutionProvider()); +#elif USE_ROCM + execution_providers.push_back(DefaultRocmExecutionProvider()); +#endif + if (execution_providers.size() > 0) { + OpTester tester(op_name, 14); + tester.AddInput("A", lhs_dim, MakeMLFloat16(lhs_values)); + tester.AddInput("B", rhs_dim, MakeMLFloat16(rhs_values)); + tester.AddOutput("C", out_dim, MakeMLFloat16(out_values)); + + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } + } { - OpTester tester(op_name, 14); - tester.AddInput("A", lhs_dim, MakeMLFloat16(lhs_values)); - tester.AddInput("B", rhs_dim, MakeMLFloat16(rhs_values)); - tester.AddOutput("C", out_dim, MakeMLFloat16(out_values)); std::vector> execution_providers; #ifdef USE_CUDA execution_providers.push_back(DefaultCudaExecutionProvider()); #elif USE_ROCM execution_providers.push_back(DefaultRocmExecutionProvider()); #endif - tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + + if (enable_bf16 && execution_providers.size() > 0) { + OpTester tester(op_name, 14); + tester.AddInput("A", lhs_dim, MakeBFloat16(lhs_values)); + tester.AddInput("B", rhs_dim, MakeBFloat16(rhs_values)); + tester.AddOutput("C", out_dim, MakeBFloat16(out_values)); + + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } + } +} + +void TestUnaryFloat16(const char* op_name, + const std::vector& lhs_dim, + const std::initializer_list& lhs_values, + const std::vector& out_dim, + const std::initializer_list& out_values, + int opset = 14, + bool run_bf16 = true) { + { + std::vector> execution_providers; +#ifdef COREML_ENABLE_MLPROGRAM + execution_providers.push_back(DefaultCoreMLExecutionProvider(true)); +#elif USE_CUDA + execution_providers.push_back(DefaultCudaExecutionProvider()); +#elif USE_ROCM + execution_providers.push_back(DefaultRocmExecutionProvider()); +#endif + if (execution_providers.size() > 0) { + OpTester tester(op_name, opset); + tester.AddInput("A", lhs_dim, MakeMLFloat16(lhs_values)); + tester.AddOutput("C", out_dim, MakeMLFloat16(out_values)); + + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } } { - OpTester tester(op_name, 14); - tester.AddInput("A", lhs_dim, MakeBFloat16(lhs_values)); - tester.AddInput("B", rhs_dim, MakeBFloat16(rhs_values)); - tester.AddOutput("C", out_dim, MakeBFloat16(out_values)); std::vector> execution_providers; #ifdef USE_CUDA execution_providers.push_back(DefaultCudaExecutionProvider()); #elif USE_ROCM execution_providers.push_back(DefaultRocmExecutionProvider()); #endif - tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + + if (run_bf16 && execution_providers.size() > 0) { + OpTester tester(op_name, opset); + tester.AddInput("A", lhs_dim, MakeBFloat16(lhs_values)); + tester.AddOutput("C", out_dim, MakeBFloat16(out_values)); + + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } } } -#endif void TestBFloat16(const char* op_name, const std::vector& lhs_dim, const std::initializer_list& lhs_values, const std::vector& rhs_dim, @@ -163,9 +216,7 @@ TEST(MathOpTest, Add_float) { test.Run(); #endif -#if defined(USE_CUDA) || defined(USE_ROCM) - TestFloat16("Add", dims, lhs_values, dims, rhs_values, dims, out_values); -#endif + TestBinaryFloat16("Add", dims, lhs_values, dims, rhs_values, dims, out_values); #if defined(USE_DNNL) TestBFloat16("Add", dims, lhs_values, dims, rhs_values, dims, out_values); @@ -202,9 +253,7 @@ TEST(MathOpTest, Add_Broadcast_Axis) { test.AddOutput("C", dims, out_values); test.Run(OpTester::ExpectResult::kExpectSuccess, ""); -#if defined(USE_CUDA) || defined(USE_ROCM) - TestFloat16("Add", dims, lhs_values, {3, 1}, rhs_values, dims, out_values); -#endif + TestBinaryFloat16("Add", dims, lhs_values, {3, 1}, rhs_values, dims, out_values); #if defined(USE_DNNL) TestBFloat16("Add", dims, lhs_values, {3, 1}, rhs_values, dims, out_values); @@ -228,9 +277,7 @@ TEST(MathOpTest, Add_Broadcast_MultidirectionalAB) { {kTensorrtExecutionProvider}); // TensorRT: got C with shape [3, 1] #endif -#if defined(USE_CUDA) || defined(USE_ROCM) - TestFloat16("Add", {3, 1}, lhs_values, {3}, rhs_values, {3, 3}, out_values); -#endif + TestBinaryFloat16("Add", {3, 1}, lhs_values, {3}, rhs_values, {3, 3}, out_values); #if defined(USE_DNNL) TestBFloat16("Add", {3, 1}, lhs_values, {3}, rhs_values, {3, 3}, out_values); @@ -254,9 +301,7 @@ TEST(MathOpTest, Add_Broadcast_MultidirectionalBA) { {kTensorrtExecutionProvider}); // TensorRT: got C with shape [3, 1] #endif -#if defined(USE_CUDA) || defined(USE_ROCM) - TestFloat16("Add", {3}, lhs_values, {3, 1}, rhs_values, {3, 3}, out_values); -#endif + TestBinaryFloat16("Add", {3}, lhs_values, {3, 1}, rhs_values, {3, 3}, out_values); #if defined(USE_DNNL) TestBFloat16("Add", {3}, lhs_values, {3, 1}, rhs_values, {3, 3}, out_values); @@ -527,9 +572,7 @@ TEST(MathOpTest, Sub) { test.AddOutput("C", dims, out_values); test.Run(); -#if defined(USE_CUDA) || defined(USE_ROCM) - TestFloat16("Sub", dims, lhs_values, dims, rhs_values, dims, out_values); -#endif + TestBinaryFloat16("Sub", dims, lhs_values, dims, rhs_values, dims, out_values); #if defined(USE_DNNL) TestBFloat16("Sub", dims, lhs_values, dims, rhs_values, dims, out_values); @@ -584,9 +627,7 @@ TEST(MathOpTest, Mul) { test.Run(); -#if defined(USE_CUDA) || defined(USE_ROCM) - TestFloat16("Mul", dims, lhs_values, dims, rhs_values, dims, out_values); -#endif + TestBinaryFloat16("Mul", dims, lhs_values, dims, rhs_values, dims, out_values); #if defined(USE_DNNL) TestBFloat16("Mul", dims, lhs_values, dims, rhs_values, dims, out_values); @@ -622,9 +663,7 @@ TEST(MathOpTest, Div) { test.AddOutput("C", dims, out_values); test.Run(); -#if defined(USE_CUDA) || defined(USE_ROCM) - TestFloat16("Div", dims, lhs_values, dims, rhs_values, dims, out_values); -#endif + TestBinaryFloat16("Div", dims, lhs_values, dims, rhs_values, dims, out_values); #if defined(USE_DNNL) TestBFloat16("Div", dims, lhs_values, dims, rhs_values, dims, out_values); @@ -772,13 +811,12 @@ TEST(MathOpTest, Ceil_double) { TEST(MathOpTest, Reciprocal) { OpTester test("Reciprocal"); std::vector dims{2, 2}; - test.AddInput("X", dims, - {1.0f, 2.0f, - -1.0f, -2.0f}); - test.AddOutput("Y", dims, - {1.0f, 0.5f, - -1.0f, -0.5f}); + std::initializer_list inputs = {1.0f, 2.0f, -1.0f, -2.0f}; + std::initializer_list outputs = {1.0f, 0.5f, -1.0f, -0.5f}; + test.AddInput("X", dims, inputs); + test.AddOutput("Y", dims, outputs); test.Run(); + TestUnaryFloat16("Reciprocal", dims, inputs, dims, outputs, 12, false); } TEST(MathOpTest, Reciprocal_double) { @@ -795,14 +833,13 @@ TEST(MathOpTest, Reciprocal_double) { TEST(MathOpTest, Sqrt_Float) { OpTester test("Sqrt"); + std::initializer_list inputs = {1.0f, 4.0f, 0.0f, 9.0f}; + std::initializer_list outputs = {1.0f, 2.0f, 0.0f, 3.0f}; std::vector dims{2, 2}; - test.AddInput("X", dims, - {1.0f, 4.0f, - 0.0f, 9.0f}); - test.AddOutput("Y", dims, - {1.0f, 2.0f, - 0.0f, 3.0f}); + test.AddInput("X", dims, inputs); + test.AddOutput("Y", dims, outputs); test.Run(); + TestUnaryFloat16("Sqrt", dims, inputs, dims, outputs); } #if defined(USE_DNNL) || defined(USE_CUDA) @@ -1056,24 +1093,13 @@ TEST(MathOpTest, Pow_double_int64) { test.Run(); } -#if defined(USE_CUDA) || defined(USE_ROCM) TEST(MathOpTest, Pow_float16_float16) { - OpTester test("Pow", 12); std::vector dims{4}; - - test.AddInput("X", dims, MakeMLFloat16({2.0f, 2.0f, std::sqrt(2.0f), 1.0f})); - test.AddInput("Y", dims, MakeMLFloat16({0.0f, 8.0f, 2.0f, 9.0f})); - test.AddOutput("Z", dims, MakeMLFloat16({1.0f, 256.0f, 2.0f, 1.0f})); - - std::vector> execution_providers; -#ifdef USE_CUDA - execution_providers.push_back(DefaultCudaExecutionProvider()); -#elif USE_ROCM - execution_providers.push_back(DefaultRocmExecutionProvider()); -#endif - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + TestBinaryFloat16("Pow", dims, {2.0f, 2.0f, std::sqrt(2.0f), 1.0f}, dims, {0.0f, 8.0f, 2.0f, 9.0f}, + dims, {1.0f, 256.0f, 2.0f, 1.0f}, false); } +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(COREML_ENABLE_MLPROGRAM) TEST(MathOpTest, Pow_float_float16) { OpTester test("Pow", 12); std::vector dims{4}; @@ -1087,6 +1113,8 @@ TEST(MathOpTest, Pow_float_float16) { execution_providers.push_back(DefaultCudaExecutionProvider()); #elif USE_ROCM execution_providers.push_back(DefaultRocmExecutionProvider()); +#elif COREML_ENABLE_MLPROGRAM + execution_providers.push_back(DefaultCoreMLExecutionProvider(true)); #endif test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } @@ -3797,5 +3825,6 @@ TEST(MathOpTest, BitwiseNot_uint8) { test.AddOutput("Y", dims, {254, 251, 250, 252}); test.Run(); } + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/math/gemm_test.cc b/onnxruntime/test/providers/cpu/math/gemm_test.cc index 625ff29d4ccf9..66408e6adfbc5 100644 --- a/onnxruntime/test/providers/cpu/math/gemm_test.cc +++ b/onnxruntime/test/providers/cpu/math/gemm_test.cc @@ -25,7 +25,7 @@ const constexpr auto run_with_tunable_op = &run_options; } // namespace -// Only CUDA and ROCM kernel has float 16 support +// Only CUDA, ROCM and CoreML kernels have float 16 support TEST(GemmOpTest, GemmNoTrans_f16) { #ifdef USE_CUDA int min_cuda_architecture = 530; @@ -34,36 +34,142 @@ TEST(GemmOpTest, GemmNoTrans_f16) { return; } #endif - OpTester test("Gemm", 13); - test.AddAttribute("transA", (int64_t)0); - test.AddAttribute("transB", (int64_t)0); - test.AddAttribute("alpha", 1.0f); - test.AddAttribute("beta", 1.0f); + std::vector A{1.0f, 2.0f, 3.0f, 4.0f, + -1.0f, -2.0f, -3.0f, -4.0f}; + std::vector B = {0.5f, 2.1f, 1.2f, -0.3f, + -1.2f, 0.2f, 1.0f, -2.1f, + 1.3f, 4.1f, 1.3f, -8.1f}; + std::vector C = {0.5f, 2.1f, 1.2f, + -0.3f, -1.2f, 0.2f}; + + std::vector f_A(8); + std::vector f_B(12); + ConvertFloatToMLFloat16(A.data(), f_A.data(), 8); + ConvertFloatToMLFloat16(B.data(), f_B.data(), 12); + + { + // bias has same shape as output + std::vector f_Y(6); + std::vector Y{19.8f, 0.7f, -25.7f, + -19.6f, 0.2f, 27.1f}; + ConvertFloatToMLFloat16(Y.data(), f_Y.data(), 6); + + std::vector f_C(6); + ConvertFloatToMLFloat16(C.data(), f_C.data(), 6); + + OpTester test("Gemm", 13); + + test.AddAttribute("transA", (int64_t)0); + test.AddAttribute("transB", (int64_t)0); + test.AddAttribute("alpha", 1.0f); + test.AddAttribute("beta", 1.0f); + test.AddInput("A", {2, 4}, f_A); + test.AddInput("B", {4, 3}, f_B); + test.AddInput("C", {2, 3}, f_C); + test.AddOutput("Y", {2, 3}, f_Y); + // we used float data with decimal instead of only integer, increase Tolerance to make test pass + test.SetOutputTolerance(0.005f); + test.ConfigExcludeEps({kTensorrtExecutionProvider}) // TensorRT: fp16 is not supported + .Config(run_with_tunable_op) + .RunWithConfig(); + } + { + // bias has shape {1, output_features} + std::vector f_Y(6); + std::vector Y{19.8f, 0.7f, -25.7f, + -18.8f, 3.5f, 28.1f}; + ConvertFloatToMLFloat16(Y.data(), f_Y.data(), 6); + + std::vector f_C(3); + ConvertFloatToMLFloat16(C.data(), f_C.data(), 3); + // CoreML program require B/C are constant + OpTester test("Gemm", 13); + + test.AddAttribute("transA", (int64_t)0); + test.AddAttribute("transB", (int64_t)0); + test.AddAttribute("alpha", 1.0f); + test.AddAttribute("beta", 1.0f); + test.AddInput("A", {2, 4}, f_A); + test.AddInput("B", {4, 3}, f_B, true); + test.AddInput("C", {3}, f_C, true); + test.AddOutput("Y", {2, 3}, f_Y); + test.SetOutputTolerance(0.005f); + test.ConfigExcludeEps({kTensorrtExecutionProvider}) // TensorRT: fp16 is not supported + .Config(run_with_tunable_op) + .RunWithConfig(); + } + { + // bias is a scalar + std::vector f_Y(6); + std::vector Y{19.8f, -0.9f, -26.4f, + -18.8f, 1.9f, 27.4f}; + ConvertFloatToMLFloat16(Y.data(), f_Y.data(), 6); + + std::vector f_C(1); + ConvertFloatToMLFloat16(C.data(), f_C.data(), 1); + OpTester test("Gemm", 13); + + test.AddAttribute("transA", (int64_t)0); + test.AddAttribute("transB", (int64_t)0); + test.AddAttribute("alpha", 1.0f); + test.AddAttribute("beta", 1.0f); + test.AddInput("A", {2, 4}, f_A); + test.AddInput("B", {4, 3}, f_B, true); + test.AddInput("C", {1}, f_C, true); + test.AddOutput("Y", {2, 3}, f_Y); + test.SetOutputTolerance(0.005f); + test.ConfigExcludeEps({kTensorrtExecutionProvider}) // TensorRT: fp16 is not supported + .Config(run_with_tunable_op) + .RunWithConfig(); + } +} + +// Only CUDA, ROCM and CoreML kernels have float 16 support +TEST(GemmOpTest, GemmTransB_f16) { +#ifdef USE_CUDA + int min_cuda_architecture = 530; + if (!HasCudaEnvironment(min_cuda_architecture)) { + LOGS_DEFAULT(WARNING) << "Hardware NOT support FP16"; + return; + } +#endif std::vector A{1.0f, 2.0f, 3.0f, 4.0f, -1.0f, -2.0f, -3.0f, -4.0f}; - std::vector B(12, 1.0f); - std::vector C(6, 1.0f); - std::vector Y{11.0f, 11.0f, 11.0f, - -9.0f, -9.0f, -9.0f}; + std::vector B = {0.5f, 2.1f, 1.2f, -0.3f, + -1.2f, 0.2f, 1.0f, -2.1f, + 1.3f, 4.1f, 1.3f, -8.1f}; + std::vector C = {0.5f, 2.1f, 1.2f, + -0.3f, -1.2f, 0.2f}; std::vector f_A(8); std::vector f_B(12); - std::vector f_C(6); - std::vector f_Y(6); ConvertFloatToMLFloat16(A.data(), f_A.data(), 8); ConvertFloatToMLFloat16(B.data(), f_B.data(), 12); - ConvertFloatToMLFloat16(C.data(), f_C.data(), 6); - ConvertFloatToMLFloat16(Y.data(), f_Y.data(), 6); - - test.AddInput("A", {2, 4}, f_A); - test.AddInput("B", {4, 3}, f_B); - test.AddInput("C", {2, 3}, f_C); - test.AddOutput("Y", {2, 3}, f_Y); - test.ConfigExcludeEps({kTensorrtExecutionProvider}) // TensorRT: fp16 is not supported - .Config(run_with_tunable_op) - .RunWithConfig(); + { + // bias is a scalar and transB is True + std::vector f_Y(6); + std::vector Y{7.6f, -5.7f, -18.5f, -6.6f, 6.7f, 19.5f}; + ConvertFloatToMLFloat16(Y.data(), f_Y.data(), 6); + + std::vector f_C(1); + ConvertFloatToMLFloat16(C.data(), f_C.data(), 1); + OpTester test("Gemm", 13); + + test.AddAttribute("transA", (int64_t)0); + test.AddAttribute("transB", (int64_t)1); + test.AddAttribute("alpha", 1.0f); + test.AddAttribute("beta", 1.0f); + test.AddInput("A", {2, 4}, f_A); + test.AddInput("B", {3, 4}, f_B, true); + test.AddInput("C", {1}, f_C, true); + test.AddOutput("Y", {2, 3}, f_Y); + test.SetOutputTolerance(0.005f); + test.ConfigExcludeEps({kTensorrtExecutionProvider}) // TensorRT: fp16 is not supported + .Config(run_with_tunable_op) + .RunWithConfig(); + } } #if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DNNL) diff --git a/onnxruntime/test/providers/cpu/math/matmul_test.cc b/onnxruntime/test/providers/cpu/math/matmul_test.cc index 90370560859aa..a7d2281ac19f8 100644 --- a/onnxruntime/test/providers/cpu/math/matmul_test.cc +++ b/onnxruntime/test/providers/cpu/math/matmul_test.cc @@ -246,7 +246,7 @@ TEST(MathOpTest, MatMulZeroKInt32Type) { RunMatMulZeroKTest(); } -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(COREML_ENABLE_MLPROGRAM) TEST(MathOpTest, MatMul_Float16) { #ifdef USE_CUDA int min_cuda_architecture = 530; diff --git a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc index 95b274966fbbb..ce1ac7591ec34 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc @@ -3,7 +3,7 @@ #include "core/mlas/inc/mlas.h" -#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || defined(COREML_ENABLE_MLPROGRAM) #include "gtest/gtest.h" #include "test/providers/provider_test_utils.h" @@ -28,6 +28,15 @@ struct ConvOpAndTestAttributes { vector activation_parameters = {}; }; +/* +Please notice that, we have predefined macros in the head of the file +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || defined(COREML_ENABLE_MLPROGRAM) +When we have these two macro defines, this UT will turn into green light and work. + +If attributes.activation is set the NhwcFusedConv contrib op is used. +If you are adding support for a new EP to the test and the EP does not support NhwcFusedConv +please add the EP to the excluded_providers list. +*/ void TestConvFp16Op(const ConvOpAndTestAttributes& attributes, const vector>& inputs, const vector>& input_shapes, @@ -81,11 +90,13 @@ void TestConvFp16Op(const ConvOpAndTestAttributes& attributes, std::unordered_set excluded_providers(attributes.excluded_providers); // Disable TensorRT because weight as input is not supported excluded_providers.insert(kTensorrtExecutionProvider); - // QNN have issue with dynamic weight, auto pad with SAME_UPPER, SAME_LOWER + // QNN has issue with dynamic weight, auto pad with SAME_UPPER, SAME_LOWER if (!weight_is_initializer || attributes.auto_pad == "SAME_UPPER" || attributes.auto_pad == "SAME_LOWER") { excluded_providers.insert(kQnnExecutionProvider); } - + if (!weight_is_initializer || !attributes.activation.empty()) { + excluded_providers.insert(kCoreMLExecutionProvider); + } tester->Run(expect_result, err_str, excluded_providers); } @@ -1147,6 +1158,7 @@ TEST(ConvFp16Test, Pointwise_Relu) { MLFloat16(17.5f), MLFloat16(9.5f)}; TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); + TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, true); } TEST(ConvFp16Test, Conv2D_HardSigmoid) { @@ -1176,6 +1188,7 @@ TEST(ConvFp16Test, Conv2D_HardSigmoid) { MLFloat16(1.0f), MLFloat16(0.0f), MLFloat16(1.0f), MLFloat16(0.0f)}; TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); + TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, true); } TEST(ConvFp16Test, Conv2D_Bias_Z_Relu) { @@ -1205,6 +1218,7 @@ TEST(ConvFp16Test, Conv2D_Bias_Z_Relu) { vector Z_shape = {1, 2, 2, 2}; auto expected_vals = {MLFloat16(12.0f), MLFloat16(11.0f), MLFloat16(17.0f), MLFloat16(15.0f), MLFloat16(25.0f), MLFloat16(23.0f), MLFloat16(29.0f), MLFloat16(28.0f)}; TestConvFp16Op(attrs, {X, W, B, Z}, {X_shape, W_shape, B_shape, Z_shape}, expected_vals, Y_shape); + TestConvFp16Op(attrs, {X, W, B, Z}, {X_shape, W_shape, B_shape, Z_shape}, expected_vals, Y_shape, true); } #endif // CONTRIB_OPS diff --git a/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc b/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc index 2bf53ce5b5986..29525f89ef544 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc @@ -22,10 +22,11 @@ struct ConvTransposeOpAttributes { string auto_pad; }; +template void TestConvTransposeOpInitializer(const ConvTransposeOpAttributes& attributes, - const vector>& inputs, + const vector>& inputs, const vector>& input_shapes, - const std::initializer_list& expected_output, + const std::vector& expected_output, const vector& expected_output_shape, bool is_weight_and_bias_initializer = false, OpTester::ExpectResult expect_result = OpTester::ExpectResult::kExpectSuccess, @@ -61,17 +62,18 @@ void TestConvTransposeOpInitializer(const ConvTransposeOpAttributes& attributes, const char* input_names[] = {"X", "W", "B"}; bool is_initializers[] = {false, is_weight_and_bias_initializer, is_weight_and_bias_initializer}; for (size_t i = 0; i < inputs.size(); i++) { - test.AddInput(input_names[i], input_shapes[i], inputs[i], is_initializers[i]); + test.AddInput(input_names[i], input_shapes[i], inputs[i], is_initializers[i]); } - test.AddOutput("Y", expected_output_shape, expected_output); + test.AddOutput("Y", expected_output_shape, expected_output); test.Run(expect_result, err_str, excluded_provider_types); // Disable TensorRT because weight as input is not supported } +template void TestConvTransposeOp(const ConvTransposeOpAttributes& attributes, - const vector>& inputs, + const vector>& inputs, const vector>& input_shapes, - const std::initializer_list& expected_output, + const std::vector& expected_output, const vector& expected_output_shape, OpTester::ExpectResult expect_result = OpTester::ExpectResult::kExpectSuccess, const std::string& err_str = "", @@ -87,6 +89,13 @@ void TestConvTransposeOp(const ConvTransposeOpAttributes& attributes, } // namespace +template +class ConvTransposeTest : public ::testing::Test { +}; + +using ConvTransposeTestTypes = ::testing::Types; +TYPED_TEST_SUITE(ConvTransposeTest, ConvTransposeTestTypes); + TEST(ConvTransposeTest, ConvTranspose_1D) { ConvTransposeOpAttributes attrs{ vector{3}, // kernel_shape @@ -108,13 +117,24 @@ TEST(ConvTransposeTest, ConvTranspose_1D) { 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f}; vector Y_shape = {1, 2, 5}; - auto expected_vals = {18.1f, 40.2f, 66.3f, 48.f, 26.f, - 9.4f, 22.5f, 39.6f, 30.f, 17.f}; + vector expected_vals = {18.1f, 40.2f, 66.3f, 48.f, 26.f, + 9.4f, 22.5f, 39.6f, 30.f, 17.f}; TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); } -TEST(ConvTransposeTest, ConvTranspose_2D_outputpadding_strides2) { +template +static std::vector GetTypedArray(std::vector inputs, [[maybe_unused]] T v = T(0.f)) { + if constexpr (std::is_same::value) { + return inputs; + } else { + std::vector inputs_fp16(inputs.size()); + ConvertFloatToMLFloat16(inputs.data(), inputs_fp16.data(), inputs.size()); + return inputs_fp16; + } +} + +TYPED_TEST(ConvTransposeTest, ConvTranspose_2D_outputpadding_strides2) { ConvTransposeOpAttributes attrs = { vector{3, 3}, // kernel_shape vector{1, 1}, // output_padding @@ -137,17 +157,27 @@ TEST(ConvTransposeTest, ConvTranspose_2D_outputpadding_strides2) { 0.04118127f, -0.44696793f, 0.06373066f}; vector Y_shape = {1, 1, 6, 6}; - auto expected_vals = {0.07368518f, -0.08925839f, -0.06627201f, 0.06301362f, 0.03732984f, -0.01919658f, - -0.00628807f, -0.02817563f, -0.01472169f, 0.04392925f, -0.00689478f, -0.01549204f, - 0.07957941f, -0.11459791f, -0.09505399f, 0.07681622f, 0.03604182f, -0.01853423f, - -0.0270785f, -0.00680824f, -0.06650258f, 0.08004665f, 0.07918708f, -0.0724144f, - 0.06256775f, -0.17838378f, -0.18863615f, 0.20064656f, 0.133717f, -0.06876295f, - -0.06398046f, -0.00864975f, 0.19289537f, -0.01490572f, -0.13673618f, 0.01949645f}; - TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); + vector expected_vals = {0.07368518f, -0.08925839f, -0.06627201f, 0.06301362f, 0.03732984f, -0.01919658f, + -0.00628807f, -0.02817563f, -0.01472169f, 0.04392925f, -0.00689478f, -0.01549204f, + 0.07957941f, -0.11459791f, -0.09505399f, 0.07681622f, 0.03604182f, -0.01853423f, + -0.0270785f, -0.00680824f, -0.06650258f, 0.08004665f, 0.07918708f, -0.0724144f, + 0.06256775f, -0.17838378f, -0.18863615f, 0.20064656f, 0.133717f, -0.06876295f, + -0.06398046f, -0.00864975f, 0.19289537f, -0.01490572f, -0.13673618f, 0.01949645f}; + if constexpr (std::is_same::value) { + TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); + } else { + vector X_fp16(X.size()); + ConvertFloatToMLFloat16(X.data(), X_fp16.data(), X.size()); + vector W_fp16(W.size()); + ConvertFloatToMLFloat16(W.data(), W_fp16.data(), W.size()); + std::vector expected_vals_fp16(expected_vals.size()); + ConvertFloatToMLFloat16(expected_vals.data(), expected_vals_fp16.data(), expected_vals.size()); + TestConvTransposeOp(attrs, {X_fp16, W_fp16}, {X_shape, W_shape}, expected_vals_fp16, Y_shape); + } } // 2D input with C > 1 -TEST(ConvTransposeTest, ConvTranspose_2D_C2) { +TYPED_TEST(ConvTransposeTest, ConvTranspose_2D_C2) { ConvTransposeOpAttributes attrs = { vector{2, 2}, // kernel_shape {}, // output_padding @@ -176,16 +206,17 @@ TEST(ConvTransposeTest, ConvTranspose_2D_C2) { 0.44524362f, 0.6056068f}; vector Y_shape = {1, 1, 4, 4}; - auto expected_vals = { + vector expected_vals = { 0.50678771f, 1.10413539f, 0.74340409f, 0.14989006f, 0.34063845f, 1.19294512f, 1.85030293f, 0.63518577f, 0.58575004f, 1.25774109f, 1.23472511f, 0.77670550f, 0.25844323f, 0.88953220f, 0.77098041f, 0.27468451f}; - TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); + TestConvTransposeOp(attrs, {GetTypedArray(X), GetTypedArray(W)}, + {X_shape, W_shape}, GetTypedArray(expected_vals), Y_shape); } -TEST(ConvTransposeTest, ConvTranspose_2D_Bias_1) { +TYPED_TEST(ConvTransposeTest, ConvTranspose_2D_Bias_1) { ConvTransposeOpAttributes attrs = { vector{3, 3}, // kernel_shape vector{0, 0}, // output_padding @@ -209,12 +240,13 @@ TEST(ConvTransposeTest, ConvTranspose_2D_Bias_1) { vector B = {0.04676145f}; vector B_shape = {1}; vector Y_shape = {1, 1, 5, 5}; - auto expected_vals = {-0.03781903f, -0.09041066f, 0.14239404f, 0.09704495f, -0.03399426f, - 0.08749044f, 0.35613984f, 0.07240347f, -0.27841991f, -0.00337578f, - 0.07770107f, -0.09561026f, 0.13388641f, 0.30945939f, 0.14015588f, - 0.13079405f, -0.00488365f, -0.06758944f, 0.45621645f, 0.01566098f, - 0.00703105f, 0.12956856f, 0.0103332f, 0.04221053f, -0.21318194f}; - TestConvTransposeOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape); + vector expected_vals = {-0.03781903f, -0.09041066f, 0.14239404f, 0.09704495f, -0.03399426f, + 0.08749044f, 0.35613984f, 0.07240347f, -0.27841991f, -0.00337578f, + 0.07770107f, -0.09561026f, 0.13388641f, 0.30945939f, 0.14015588f, + 0.13079405f, -0.00488365f, -0.06758944f, 0.45621645f, 0.01566098f, + 0.00703105f, 0.12956856f, 0.0103332f, 0.04221053f, -0.21318194f}; + TestConvTransposeOp(attrs, {GetTypedArray(X), GetTypedArray(W), GetTypedArray(B)}, + {X_shape, W_shape, B_shape}, GetTypedArray(expected_vals), Y_shape); } TEST(ConvTransposeTest, ConvTranspose_2D_Bias_2) { @@ -247,22 +279,22 @@ TEST(ConvTransposeTest, ConvTranspose_2D_Bias_2) { vector B = {0.17402864f}; vector B_shape = {1}; vector Y_shape = {1, 1, 8, 8}; - auto expected_vals = {0.1695925f, 0.14171794f, 0.31368554f, 0.16113512f, - 0.15653302f, 0.033998f, 0.38345876f, 0.12173492f, - 0.05362644f, 0.35481372f, 0.09013268f, -0.06378071f, - 0.24394518f, 0.00222442f, 0.50842237f, -0.07341707f, - 0.17984779f, 0.35392997f, 0.03631867f, 0.16350585f, - 0.30338728f, 0.2088346f, 0.47435546f, 0.0147884f, - 0.20821247f, 0.08664516f, 0.03569011f, 0.16659322f, - 0.47522858f, 0.19675478f, -0.10781619f, 0.02401161f, - 0.0965334f, 0.1788421f, 0.36887163f, 0.2512877f, - 0.00254938f, 0.04799958f, 0.11982619f, 0.31525785f, - 0.12701407f, 0.19566584f, 0.31214368f, -0.10558143f, - 0.18591091f, 0.46830338f, 0.05418756f, 0.20530567f, - 0.07357728f, 0.39731777f, 0.1872202f, 0.08253923f, - 0.11266428f, 0.17892915f, 0.32709083f, 0.1860041f, - 0.16902491f, 0.3129794f, -0.01718347f, 0.28917417f, - 0.07588299f, 0.32025051f, 0.39891475f, -0.04581133f}; + vector expected_vals = {0.1695925f, 0.14171794f, 0.31368554f, 0.16113512f, + 0.15653302f, 0.033998f, 0.38345876f, 0.12173492f, + 0.05362644f, 0.35481372f, 0.09013268f, -0.06378071f, + 0.24394518f, 0.00222442f, 0.50842237f, -0.07341707f, + 0.17984779f, 0.35392997f, 0.03631867f, 0.16350585f, + 0.30338728f, 0.2088346f, 0.47435546f, 0.0147884f, + 0.20821247f, 0.08664516f, 0.03569011f, 0.16659322f, + 0.47522858f, 0.19675478f, -0.10781619f, 0.02401161f, + 0.0965334f, 0.1788421f, 0.36887163f, 0.2512877f, + 0.00254938f, 0.04799958f, 0.11982619f, 0.31525785f, + 0.12701407f, 0.19566584f, 0.31214368f, -0.10558143f, + 0.18591091f, 0.46830338f, 0.05418756f, 0.20530567f, + 0.07357728f, 0.39731777f, 0.1872202f, 0.08253923f, + 0.11266428f, 0.17892915f, 0.32709083f, 0.1860041f, + 0.16902491f, 0.3129794f, -0.01718347f, 0.28917417f, + 0.07588299f, 0.32025051f, 0.39891475f, -0.04581133f}; TestConvTransposeOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape); } @@ -292,18 +324,18 @@ TEST(ConvTransposeTest, ConvTranspose_2D_OutputShape_1) { vector W_shape = {3, 3, 3, 3}; vector Y_shape = {1, 3, 4, 4}; - auto expected_vals = {12.0f, 18.0f, 18.0f, 12.0f, - 18.0f, 27.0f, 27.0f, 18.0f, - 18.0f, 27.0f, 27.0f, 18.0f, - 12.0f, 18.0f, 18.0f, 12.0f, - 12.0f, 18.0f, 18.0f, 12.0f, - 18.0f, 27.0f, 27.0f, 18.0f, - 18.0f, 27.0f, 27.0f, 18.0f, - 12.0f, 18.0f, 18.0f, 12.0f, - 12.0f, 18.0f, 18.0f, 12.0f, - 18.0f, 27.0f, 27.0f, 18.0f, - 18.0f, 27.0f, 27.0f, 18.0f, - 12.0f, 18.0f, 18.0f, 12.0f}; + vector expected_vals = {12.0f, 18.0f, 18.0f, 12.0f, + 18.0f, 27.0f, 27.0f, 18.0f, + 18.0f, 27.0f, 27.0f, 18.0f, + 12.0f, 18.0f, 18.0f, 12.0f, + 12.0f, 18.0f, 18.0f, 12.0f, + 18.0f, 27.0f, 27.0f, 18.0f, + 18.0f, 27.0f, 27.0f, 18.0f, + 12.0f, 18.0f, 18.0f, 12.0f, + 12.0f, 18.0f, 18.0f, 12.0f, + 18.0f, 27.0f, 27.0f, 18.0f, + 18.0f, 27.0f, 27.0f, 18.0f, + 12.0f, 18.0f, 18.0f, 12.0f}; TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kQnnExecutionProvider}); @@ -338,12 +370,12 @@ TEST(ConvTransposeTest, ConvTranspose_1D_OutputShape_1_group_2_for_transpose_pat vector W_shape = {6, 3, 3}; vector Y_shape = {1, 6, 4}; - auto expected_vals = {6.0f, 9.0f, 9.0f, 6.0f, - 6.0f, 9.0f, 9.0f, 6.0f, - 6.0f, 9.0f, 9.0f, 6.0f, - 6.0f, 9.0f, 9.0f, 6.0f, - 6.0f, 9.0f, 9.0f, 6.0f, - 6.0f, 9.0f, 9.0f, 6.0f}; + vector expected_vals = {6.0f, 9.0f, 9.0f, 6.0f, + 6.0f, 9.0f, 9.0f, 6.0f, + 6.0f, 9.0f, 9.0f, 6.0f, + 6.0f, 9.0f, 9.0f, 6.0f, + 6.0f, 9.0f, 9.0f, 6.0f, + 6.0f, 9.0f, 9.0f, 6.0f}; TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, OpTester::ExpectResult::kExpectSuccess, "", @@ -376,30 +408,30 @@ TEST(ConvTransposeTest, ConvTranspose_2D_OutputShape_1_group_2_for_transpose_pat vector W_shape = {6, 3, 3, 3}; vector Y_shape = {1, 6, 4, 4}; - auto expected_vals = {12.0f, 18.0f, 18.0f, 12.0f, - 18.0f, 27.0f, 27.0f, 18.0f, - 18.0f, 27.0f, 27.0f, 18.0f, - 12.0f, 18.0f, 18.0f, 12.0f, - 12.0f, 18.0f, 18.0f, 12.0f, - 18.0f, 27.0f, 27.0f, 18.0f, - 18.0f, 27.0f, 27.0f, 18.0f, - 12.0f, 18.0f, 18.0f, 12.0f, - 12.0f, 18.0f, 18.0f, 12.0f, - 18.0f, 27.0f, 27.0f, 18.0f, - 18.0f, 27.0f, 27.0f, 18.0f, - 12.0f, 18.0f, 18.0f, 12.0f, // duplicate below - 12.0f, 18.0f, 18.0f, 12.0f, - 18.0f, 27.0f, 27.0f, 18.0f, - 18.0f, 27.0f, 27.0f, 18.0f, - 12.0f, 18.0f, 18.0f, 12.0f, - 12.0f, 18.0f, 18.0f, 12.0f, - 18.0f, 27.0f, 27.0f, 18.0f, - 18.0f, 27.0f, 27.0f, 18.0f, - 12.0f, 18.0f, 18.0f, 12.0f, - 12.0f, 18.0f, 18.0f, 12.0f, - 18.0f, 27.0f, 27.0f, 18.0f, - 18.0f, 27.0f, 27.0f, 18.0f, - 12.0f, 18.0f, 18.0f, 12.0f}; + vector expected_vals = {12.0f, 18.0f, 18.0f, 12.0f, + 18.0f, 27.0f, 27.0f, 18.0f, + 18.0f, 27.0f, 27.0f, 18.0f, + 12.0f, 18.0f, 18.0f, 12.0f, + 12.0f, 18.0f, 18.0f, 12.0f, + 18.0f, 27.0f, 27.0f, 18.0f, + 18.0f, 27.0f, 27.0f, 18.0f, + 12.0f, 18.0f, 18.0f, 12.0f, + 12.0f, 18.0f, 18.0f, 12.0f, + 18.0f, 27.0f, 27.0f, 18.0f, + 18.0f, 27.0f, 27.0f, 18.0f, + 12.0f, 18.0f, 18.0f, 12.0f, // duplicate below + 12.0f, 18.0f, 18.0f, 12.0f, + 18.0f, 27.0f, 27.0f, 18.0f, + 18.0f, 27.0f, 27.0f, 18.0f, + 12.0f, 18.0f, 18.0f, 12.0f, + 12.0f, 18.0f, 18.0f, 12.0f, + 18.0f, 27.0f, 27.0f, 18.0f, + 18.0f, 27.0f, 27.0f, 18.0f, + 12.0f, 18.0f, 18.0f, 12.0f, + 12.0f, 18.0f, 18.0f, 12.0f, + 18.0f, 27.0f, 27.0f, 18.0f, + 18.0f, 27.0f, 27.0f, 18.0f, + 12.0f, 18.0f, 18.0f, 12.0f}; TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, OpTester::ExpectResult::kExpectSuccess, "", @@ -424,7 +456,7 @@ TEST(ConvTransposeTest, ConvTranspose_2D_OutputShape_2) { vector B = {1.0f}; vector B_shape = {1}; vector Y_shape = {1, 1, 1, 14}; - auto expected_vals = {1.0f, 2.0f, 5.0f, 11.0f, 19.0f, 28.0f, 37.0f, 46.0f, 55.0f, 64.0f, 63.0f, 51.0f, 27.0f, 10.0f}; + vector expected_vals = {1.0f, 2.0f, 5.0f, 11.0f, 19.0f, 28.0f, 37.0f, 46.0f, 55.0f, 64.0f, 63.0f, 51.0f, 27.0f, 10.0f}; TestConvTransposeOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider, kCudaNHWCExecutionProvider, kQnnExecutionProvider}); @@ -449,8 +481,8 @@ TEST(ConvTransposeTest, ConvTranspose_2D_OutputShapeWithBatchSize) { vector B = {1.0f}; vector B_shape = {1}; vector Y_shape = {2, 1, 1, 14}; - auto expected_vals = {1.0f, 2.0f, 5.0f, 11.0f, 19.0f, 28.0f, 37.0f, 46.0f, 55.0f, 64.0f, 63.0f, 51.0f, 27.0f, 10.0f, - 11.0f, 32.0f, 65.0f, 91.0f, 109.0f, 118.0f, 127.0f, 136.0f, 145.0f, 154.0f, 143.0f, 111.0f, 57.0f, 20.0f}; + vector expected_vals = {1.0f, 2.0f, 5.0f, 11.0f, 19.0f, 28.0f, 37.0f, 46.0f, 55.0f, 64.0f, 63.0f, 51.0f, 27.0f, 10.0f, + 11.0f, 32.0f, 65.0f, 91.0f, 109.0f, 118.0f, 127.0f, 136.0f, 145.0f, 154.0f, 143.0f, 111.0f, 57.0f, 20.0f}; TestConvTransposeOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider, kCudaNHWCExecutionProvider, kQnnExecutionProvider}); @@ -475,8 +507,8 @@ TEST(ConvTransposeTest, ConvTranspose_InvalidKernelShape) { vector B = {1.0f}; vector B_shape = {1}; vector Y_shape = {2, 1, 1, 14}; - auto expected_vals = {1.0f, 2.0f, 5.0f, 11.0f, 19.0f, 28.0f, 37.0f, 46.0f, 55.0f, 64.0f, 63.0f, 51.0f, 27.0f, 10.0f, - 11.0f, 32.0f, 65.0f, 91.0f, 109.0f, 118.0f, 127.0f, 136.0f, 145.0f, 154.0f, 143.0f, 111.0f, 57.0f, 20.0f}; + vector expected_vals = {1.0f, 2.0f, 5.0f, 11.0f, 19.0f, 28.0f, 37.0f, 46.0f, 55.0f, 64.0f, 63.0f, 51.0f, 27.0f, 10.0f, + 11.0f, 32.0f, 65.0f, 91.0f, 109.0f, 118.0f, 127.0f, 136.0f, 145.0f, 154.0f, 143.0f, 111.0f, 57.0f, 20.0f}; TestConvTransposeOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, OpTester::ExpectResult::kExpectFailure, // error message will end in "W: {1,1,1,5}" or "W: {1,1,5,1} depending on whether NCHW or NHWC, @@ -502,7 +534,7 @@ TEST(ConvTransposeTest, ConvTranspose_onnx) { vector W = {0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17.}; vector W_shape = {1, 2, 3, 3}; vector Y_shape = {1, 2, 5, 5}; - auto expected_vals = { + vector expected_vals = { 0.f, 0.f, 1.f, 4.f, 4.f, 0.f, 6.f, 20.f, 26.f, 20.f, 9.f, 36.f, 84.f, 84.f, 57.f, @@ -533,7 +565,7 @@ TEST(ConvTransposeTest, ConvTranspose_onnx2) { vector W = {0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23.}; vector W_shape = {2, 3, 2, 2}; // this requires weight transpose vector Y_shape = {1, 3, 4, 4}; - auto expected_vals = { + vector expected_vals = { 108.f, 237.f, 263.f, 145.f, 270.f, 592.f, 652.f, 358.f, 354.f, 772.f, 832.f, 454.f, @@ -566,7 +598,7 @@ TEST(ConvTransposeTest, ConvTranspose_onnx_group) { vector W = {0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 31.0f}; vector W_shape = {16, 2, 1, 1}; vector Y_shape = {1, 8, 1, 1}; - auto expected_vals = {28.f, 34.f, 252.f, 274.f, 732.f, 770.f, 1468.f, 1522.f}; + vector expected_vals = {28.f, 34.f, 252.f, 274.f, 732.f, 770.f, 1468.f, 1522.f}; TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); } @@ -586,10 +618,10 @@ TEST(ConvTransposeTest, ConvTranspose_2D_Dilation_1) { vector W = {1.0f, 1.0f, 1.0f, 1.0f}; vector W_shape = {1, 1, 2, 2}; vector Y_shape = {1, 1, 4, 4}; - auto expected_vals = {11.0f, 12.0f, 11.0f, 12.0f, - 21.0f, 22.0f, 21.0f, 22.0f, - 11.0f, 12.0f, 11.0f, 12.0f, - 21.0f, 22.0f, 21.0f, 22.0f}; + vector expected_vals = {11.0f, 12.0f, 11.0f, 12.0f, + 21.0f, 22.0f, 21.0f, 22.0f, + 11.0f, 12.0f, 11.0f, 12.0f, + 21.0f, 22.0f, 21.0f, 22.0f}; TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); } @@ -609,11 +641,11 @@ TEST(ConvTransposeTest, ConvTranspose_2D_Dilation_2) { vector W = {1.0f, 1.0f, 1.0f, 1.0f}; vector W_shape = {1, 1, 2, 2}; vector Y_shape = {1, 1, 5, 5}; - auto expected_vals = {11.0f, 12.0f, 0.0f, 11.0f, 12.0f, - 21.0f, 22.0f, 0.0f, 21.0f, 22.0f, - 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, - 11.0f, 12.0f, 0.0f, 11.0f, 12.0f, - 21.0f, 22.0f, 0.0f, 21.0f, 22.0f}; + vector expected_vals = {11.0f, 12.0f, 0.0f, 11.0f, 12.0f, + 21.0f, 22.0f, 0.0f, 21.0f, 22.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 11.0f, 12.0f, 0.0f, 11.0f, 12.0f, + 21.0f, 22.0f, 0.0f, 21.0f, 22.0f}; TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); } @@ -633,11 +665,11 @@ TEST(ConvTransposeTest, ConvTranspose_2D_Dilation_3) { vector W = {7.0f, 2.0f, 1.0f, 9.0f}; vector W_shape = {1, 1, 2, 2}; vector Y_shape = {1, 1, 5, 5}; - auto expected_vals = {21.0f, 56.0f, 13.0f, 16.0f, 2.0f, - 63.0f, 35.0f, 67.0f, 10.0f, 14.0f, - 24.0f, 22.0f, 76.0f, 76.0f, 21.0f, - 9.0f, 5.0f, 88.0f, 45.0f, 63.0f, - 3.0f, 2.0f, 33.0f, 18.0f, 54.0f}; + vector expected_vals = {21.0f, 56.0f, 13.0f, 16.0f, 2.0f, + 63.0f, 35.0f, 67.0f, 10.0f, 14.0f, + 24.0f, 22.0f, 76.0f, 76.0f, 21.0f, + 9.0f, 5.0f, 88.0f, 45.0f, 63.0f, + 3.0f, 2.0f, 33.0f, 18.0f, 54.0f}; TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); } @@ -658,12 +690,12 @@ TEST(ConvTransposeTest, ConvTranspose_2D_Dilation_4) { vector W = {7.0f, 2.0f, 1.0f, 9.0f}; vector W_shape = {1, 1, 2, 2}; vector Y_shape = {1, 1, 6, 6}; - auto expected_vals = {21.0f, 56.0f, 7.0f, 6.0f, 16.0f, 2.0f, - 63.0f, 35.0f, 49.0f, 18.0f, 10.0f, 14.0f, - 21.0f, 14.0f, 42.0f, 6.0f, 4.0f, 12.0f, - 3.0f, 8.0f, 1.0f, 27.0f, 72.0f, 9.0f, - 9.0f, 5.0f, 7.0f, 81.0f, 45.0f, 63.0f, - 3.0f, 2.0f, 6.0f, 27.0f, 18.0f, 54.0f}; + vector expected_vals = {21.0f, 56.0f, 7.0f, 6.0f, 16.0f, 2.0f, + 63.0f, 35.0f, 49.0f, 18.0f, 10.0f, 14.0f, + 21.0f, 14.0f, 42.0f, 6.0f, 4.0f, 12.0f, + 3.0f, 8.0f, 1.0f, 27.0f, 72.0f, 9.0f, + 9.0f, 5.0f, 7.0f, 81.0f, 45.0f, 63.0f, + 3.0f, 2.0f, 6.0f, 27.0f, 18.0f, 54.0f}; TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); } @@ -684,9 +716,9 @@ TEST(ConvTransposeTest, ConvTranspose_2D_Dilation_AsymmetricPads_1) { vector W = {7.0f, 2.0f, 1.0f, 9.0f}; vector W_shape = {1, 1, 2, 2}; vector Y_shape = {1, 1, 3, 3}; - auto expected_vals = {42.0f, 6.0f, 4.0f, - 1.0f, 27.0f, 72.0f, - 7.0f, 81.0f, 45.0f}; + vector expected_vals = {42.0f, 6.0f, 4.0f, + 1.0f, 27.0f, 72.0f, + 7.0f, 81.0f, 45.0f}; TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); } @@ -707,9 +739,9 @@ TEST(ConvTransposeTest, ConvTranspose_2D_Dilation_AsymmetricPads_2) { vector W = {7.0f, 2.0f, 1.0f, 9.0f}; vector W_shape = {1, 1, 2, 2}; vector Y_shape = {1, 1, 3, 3}; - auto expected_vals = {35.0f, 49.0f, 18.0f, - 14.0f, 42.0f, 6.0f, - 8.0f, 1.0f, 27.0f}; + vector expected_vals = {35.0f, 49.0f, 18.0f, + 14.0f, 42.0f, 6.0f, + 8.0f, 1.0f, 27.0f}; TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); } @@ -730,10 +762,10 @@ TEST(ConvTransposeTest, ConvTranspose_2D_Dilation_AsymmetricPads_3) { vector W = {7.0f, 2.0f, 1.0f, 9.0f}; vector W_shape = {1, 1, 2, 2}; vector Y_shape = {1, 1, 4, 4}; - auto expected_vals = {42.0f, 6.0f, 4.0f, 12.0f, - 1.0f, 27.0f, 72.0f, 9.0f, - 7.0f, 81.0f, 45.0f, 63.0f, - 6.0f, 27.0f, 18.0f, 54.0f}; + vector expected_vals = {42.0f, 6.0f, 4.0f, 12.0f, + 1.0f, 27.0f, 72.0f, 9.0f, + 7.0f, 81.0f, 45.0f, 63.0f, + 6.0f, 27.0f, 18.0f, 54.0f}; TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); } @@ -754,10 +786,10 @@ TEST(ConvTransposeTest, ConvTranspose_2D_Dilation_AsymmetricPads_4) { vector W = {7.0f, 2.0f, 1.0f, 9.0f}; vector W_shape = {1, 1, 2, 2}; vector Y_shape = {1, 1, 4, 4}; - auto expected_vals = {21.0f, 56.0f, 7.0f, 6.0f, - 63.0f, 35.0f, 49.0f, 18.0f, - 21.0f, 14.0f, 42.0f, 6.0f, - 3.0f, 8.0f, 1.0f, 27.0f}; + vector expected_vals = {21.0f, 56.0f, 7.0f, 6.0f, + 63.0f, 35.0f, 49.0f, 18.0f, + 21.0f, 14.0f, 42.0f, 6.0f, + 3.0f, 8.0f, 1.0f, 27.0f}; TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); } @@ -778,16 +810,16 @@ TEST(ConvTransposeTest, ConvTranspose_2D_Dilation_Group_1) { vector W = {9.0f, 3.0f, 1.0f, 2.0f, 3.0f, 7.0f, 0.0f, 8.0f}; vector W_shape = {2, 1, 2, 2}; vector Y_shape = {1, 2, 5, 5}; - auto expected_vals = {27.0f, 72.0f, 18.0f, 24.0f, 3.0f, - 81.0f, 45.0f, 90.0f, 15.0f, 21.0f, - 30.0f, 26.0f, 43.0f, 22.0f, 11.0f, - 9.0f, 5.0f, 25.0f, 10.0f, 14.0f, - 3.0f, 2.0f, 9.0f, 4.0f, 6.0f, - 21.0f, 27.0f, 52.0f, 63.0f, 7.0f, - 15.0f, 6.0f, 44.0f, 14.0f, 21.0f, - 27.0f, 0.0f, 125.0f, 72.0f, 22.0f, - 0.0f, 0.0f, 40.0f, 16.0f, 24.0f, - 0.0f, 0.0f, 72.0f, 0.0f, 16.0f}; + vector expected_vals = {27.0f, 72.0f, 18.0f, 24.0f, 3.0f, + 81.0f, 45.0f, 90.0f, 15.0f, 21.0f, + 30.0f, 26.0f, 43.0f, 22.0f, 11.0f, + 9.0f, 5.0f, 25.0f, 10.0f, 14.0f, + 3.0f, 2.0f, 9.0f, 4.0f, 6.0f, + 21.0f, 27.0f, 52.0f, 63.0f, 7.0f, + 15.0f, 6.0f, 44.0f, 14.0f, 21.0f, + 27.0f, 0.0f, 125.0f, 72.0f, 22.0f, + 0.0f, 0.0f, 40.0f, 16.0f, 24.0f, + 0.0f, 0.0f, 72.0f, 0.0f, 16.0f}; TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); } @@ -808,7 +840,7 @@ TEST(ConvTransposeTest, ConvTranspose_DefaultStridesAndDilations) { vector W = {0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23.}; vector W_shape = {2, 3, 2, 2}; // this requires weight transpose vector Y_shape = {1, 3, 4, 4}; - auto expected_vals = { + vector expected_vals = { 108.f, 237.f, 263.f, 145.f, 270.f, 592.f, 652.f, 358.f, 354.f, 772.f, 832.f, 454.f, @@ -841,7 +873,7 @@ TEST(ConvTransposeTest, ConvTranspose_2D_NonDefaultStridesAndDilations) { vector W = {1., 1., 1., 1.}; vector W_shape = {1, 1, 1, 4}; vector Y_shape = {1, 1, 1, 12}; - auto expected_vals = {1.f, 0.f, 2.f, 1.f, 0.f, 2.f, 1.f, 0.f, 2.f, 1.f, 0.f, 2.f}; + vector expected_vals = {1.f, 0.f, 2.f, 1.f, 0.f, 2.f, 1.f, 0.f, 2.f, 1.f, 0.f, 2.f}; TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); } @@ -862,7 +894,7 @@ TEST(ConvTransposeTest, ConvTranspose_2D_NonDefaultStridesAndDilations_T) { vector W = {1., 1., 1., 1.}; vector W_shape = {1, 1, 4, 1}; vector Y_shape = {1, 1, 12, 1}; - auto expected_vals = {1.f, 0.f, 2.f, 1.f, 0.f, 2.f, 1.f, 0.f, 2.f, 1.f, 0.f, 2.f}; + vector expected_vals = {1.f, 0.f, 2.f, 1.f, 0.f, 2.f, 1.f, 0.f, 2.f, 1.f, 0.f, 2.f}; TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); } @@ -885,7 +917,7 @@ TEST(ConvTransposeTest, DimWithZero) { 0.04118127f, -0.44696793f, 0.06373066f}; vector W_shape = {1, 1, 3, 3}; vector Y_shape = {0, 1, 6, 6}; - initializer_list expected_vals = {}; + vector expected_vals = {}; TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, OpTester::ExpectResult::kExpectSuccess, "", @@ -952,75 +984,75 @@ TEST(ConvTransposeTest, ConvTranspose_3D) { vector B = {-0.11784090101718903f, -0.060990236699581146f}; vector Y_shape = {1, 2, 5, 6, 7}; - auto expected_vals = {-0.08241813f, -0.06676699f, -0.13411677f, -0.15724352f, -0.18772511f, -0.11080553f, -0.114930674f, - -0.0953398f, -0.111061305f, -0.0413035f, -0.10902196f, -0.071916685f, -0.102583766f, -0.13639182f, - -0.21214074f, -0.18799849f, -0.15122052f, 0.00434383f, -0.011207409f, -0.11604968f, -0.08378546f, - -0.1722928f, -0.044016793f, -0.1914465f, -0.16952308f, -0.39505655f, 0.080385f, -0.15767722f, - -0.060116887f, -0.16235165f, -0.075614765f, -0.14631891f, 0.05837299f, -0.31712085f, -0.13272354f, - -0.08320008f, -0.1967324f, -0.033198006f, -0.06718128f, -0.2568521f, 0.0314174f, -0.15864298f, - - -0.13070306f, -0.09003539f, -0.29147533f, -0.024966106f, 0.079442084f, -0.096389435f, -0.09941827f, - -0.3365072f, -0.4451772f, -0.13154466f, -0.08992967f, -0.16572365f, 0.06494926f, -0.21230686f, - -0.11307171f, -0.056943115f, -0.35291147f, -0.317253f, -0.070464894f, -0.6300395f, -0.031246513f, - 0.19395588f, 0.011135533f, 0.096916616f, -0.3942836f, -0.29872403f, 0.16881491f, -0.24881886f, - -0.038873613f, -0.032735735f, -0.21593677f, 0.088557824f, 0.13849314f, -0.30753696f, -0.07219358f, - -0.15177673f, -0.09156879f, -0.2286228f, 0.080623806f, -0.39201033f, 0.07819712f, -0.19924995f, - - -0.3376814f, -0.033524483f, 0.230105f, -0.0377952f, -0.12315659f, -0.28858358f, -0.13848148f, - -0.16134796f, 0.012239918f, 0.27276647f, 0.020731017f, -0.4651906f, -0.14341736f, -0.07956973f, - 0.1342433f, -0.16956037f, 0.310399f, 0.34338957f, -0.040192716f, 0.12504166f, -0.21490449f, - -0.15410437f, -0.1338158f, -0.39244395f, 0.29117042f, -0.26415867f, -0.4450379f, 0.0699404f, - 0.042872816f, -0.14961651f, -0.17582522f, -0.6919577f, -0.13723494f, -0.0681901f, -0.16183335f, - -0.0021959245f, -0.0418434f, -0.32134426f, 0.16967098f, -0.08680786f, -0.32077473f, 0.0066963434f, - - -0.114091426f, -0.041066267f, -0.080250874f, -0.72594404f, -0.30254412f, -0.03862554f, -0.27475363f, - 0.15282185f, -0.22887689f, -0.72043663f, -0.47111863f, -0.3755179f, -0.20074406f, 0.16101281f, - -0.20939936f, -0.21245953f, 0.11726546f, -0.8030824f, -0.5866715f, 0.20001571f, -0.26259118f, - 0.17054747f, 0.061063558f, -0.6348493f, 0.2620284f, -0.782919f, -0.31278569f, 0.2926497f, - -0.08745579f, 0.20646049f, -0.050303012f, -0.13460274f, 0.060659587f, -0.037006564f, -0.1292249f, - -0.11211421f, -0.038967483f, -0.21644044f, -0.24912538f, 0.08591288f, -0.40798867f, 0.006527111f, - - -0.049734667f, -0.3685795f, -0.11538547f, 0.27292788f, 0.025990233f, 0.119311824f, 0.0700129f, - -0.156443f, -0.13340846f, 0.10764159f, -0.014803357f, 0.046525866f, 0.015691683f, -0.1869241f, - 0.1004442f, -0.4885978f, -0.7585998f, -0.047841772f, -0.07570776f, 0.0471261f, 0.24483289f, - -0.16554686f, -0.1250152f, -0.15132052f, -0.08515984f, 0.14412321f, -0.1030291f, -0.2780918f, - 0.05803944f, -0.10257156f, -0.4341917f, -0.13150966f, -0.53996617f, -0.15628646f, 0.059058204f, - -0.11976162f, -0.022163756f, -0.13519828f, -0.20148787f, 0.16934697f, -0.14327072f, -0.2129095f, - - -0.107836396f, -0.0819309f, -0.06148723f, -0.0063935146f, -0.02425649f, -0.056219954f, -0.06095987f, - -0.14403576f, -0.025357183f, -0.15828207f, 0.012748428f, -0.16061643f, -0.03419252f, -0.05130991f, - -0.109983265f, -0.08312916f, -0.07035978f, -0.008285124f, -0.10610263f, -0.01489019f, -0.106886685f, - -0.007659614f, -0.2947925f, -0.09132287f, -0.040577132f, 0.089866154f, -0.24528673f, -0.055424154f, - 0.13783869f, 0.023674607f, -0.10545369f, -0.20873478f, -0.4685722f, 0.09418375f, -0.06684458f, - 0.0410614f, 0.04018917f, -0.15845582f, 0.06580096f, 0.070554025f, -0.19462511f, -0.03526502f, - - -0.02956047f, -0.16035908f, -0.0638171f, -0.261022f, -0.022948403f, 0.08353848f, -0.041173913f, - 0.04770004f, 0.091520615f, 0.006987013f, -0.39962748f, 0.23266485f, -0.32719564f, -0.12885109f, - -0.29559937f, -0.08031146f, 0.76168066f, 0.0009028502f, -0.4091536f, -0.14801738f, -0.17058557f, - -0.05754847f, 0.2955231f, -0.089874476f, 0.17254886f, -0.13203058f, -0.007648442f, 0.010943003f, - 0.04123217f, 0.26074114f, -0.24313056f, 0.1008903f, -0.26472318f, 0.01998391f, -0.03422378f, - -0.024659738f, 0.033793047f, -0.1998924f, -0.110185415f, 0.10620246f, -0.3435271f, 0.019390412f, - - 0.21691665f, -0.26076952f, -0.5040901f, 0.28383943f, -0.34750903f, -0.32484284f, -0.01734912f, - -0.08909689f, -0.0466362f, 0.21648785f, 0.06733417f, 0.009496197f, 0.18728223f, -0.35110205f, - -0.04908372f, -0.36729553f, -0.346236f, -0.13589534f, -0.16435221f, -0.16853788f, 0.12264759f, - -0.019215636f, -0.38316554f, 0.35669535f, -0.56980205f, -0.059346225f, 0.15008381f, -0.1751053f, - 0.059508912f, 0.116622455f, -0.32607535f, -0.22282779f, -0.29149055f, -0.3829086f, 0.15905643f, - -0.077926554f, 0.06549884f, -0.09004557f, -0.15897253f, 0.26810864f, -0.08931713f, -0.047756508f, - - -0.14657992f, 0.43070868f, -0.021787114f, -0.4532621f, 0.092385404f, -0.30126676f, -0.24893704f, - -0.10896815f, -0.14514503f, -0.21353528f, 0.018723361f, 0.037694372f, 0.11514955f, 0.13013864f, - -0.25713888f, -0.056000195f, -0.3505367f, 0.0836427f, -0.032017898f, -0.26742116f, -0.14740711f, - -0.13330215f, -0.18958306f, -0.08968873f, 0.014723815f, -0.20343366f, 0.3098968f, 0.114284225f, - -0.026738256f, -0.14110464f, -0.054464605f, -0.17529932f, -0.0030034669f, -0.050670102f, -0.04016705f, - -0.062238634f, -0.04886609f, -0.042247344f, -0.12185234f, 0.0357792f, -0.10265522f, -0.116296895f, - - -0.1035416f, -0.09126053f, 0.20045105f, 0.12366664f, 0.05460281f, 0.09944453f, -0.055443168f, - -0.09767935f, -0.040166672f, -0.01716708f, 0.020299219f, 0.02864775f, -0.07159522f, -0.04354491f, - -0.1390779f, -0.13270372f, 0.02992779f, -0.025869183f, 0.12530258f, 0.05101595f, -0.07891131f, - -0.1051311f, -0.093200594f, -0.10368025f, 0.047598884f, -0.12069465f, -0.098738566f, -0.042393237f, - -0.08531736f, -0.051284637f, -0.04354899f, -0.06810297f, -0.083224006f, -0.11702064f, -0.08514082f, - -0.06071842f, -0.07496775f, -0.03626109f, -0.07785503f, -0.07243007f, -0.041736744f, -0.052593358f}; + vector expected_vals = {-0.08241813f, -0.06676699f, -0.13411677f, -0.15724352f, -0.18772511f, -0.11080553f, -0.114930674f, + -0.0953398f, -0.111061305f, -0.0413035f, -0.10902196f, -0.071916685f, -0.102583766f, -0.13639182f, + -0.21214074f, -0.18799849f, -0.15122052f, 0.00434383f, -0.011207409f, -0.11604968f, -0.08378546f, + -0.1722928f, -0.044016793f, -0.1914465f, -0.16952308f, -0.39505655f, 0.080385f, -0.15767722f, + -0.060116887f, -0.16235165f, -0.075614765f, -0.14631891f, 0.05837299f, -0.31712085f, -0.13272354f, + -0.08320008f, -0.1967324f, -0.033198006f, -0.06718128f, -0.2568521f, 0.0314174f, -0.15864298f, + + -0.13070306f, -0.09003539f, -0.29147533f, -0.024966106f, 0.079442084f, -0.096389435f, -0.09941827f, + -0.3365072f, -0.4451772f, -0.13154466f, -0.08992967f, -0.16572365f, 0.06494926f, -0.21230686f, + -0.11307171f, -0.056943115f, -0.35291147f, -0.317253f, -0.070464894f, -0.6300395f, -0.031246513f, + 0.19395588f, 0.011135533f, 0.096916616f, -0.3942836f, -0.29872403f, 0.16881491f, -0.24881886f, + -0.038873613f, -0.032735735f, -0.21593677f, 0.088557824f, 0.13849314f, -0.30753696f, -0.07219358f, + -0.15177673f, -0.09156879f, -0.2286228f, 0.080623806f, -0.39201033f, 0.07819712f, -0.19924995f, + + -0.3376814f, -0.033524483f, 0.230105f, -0.0377952f, -0.12315659f, -0.28858358f, -0.13848148f, + -0.16134796f, 0.012239918f, 0.27276647f, 0.020731017f, -0.4651906f, -0.14341736f, -0.07956973f, + 0.1342433f, -0.16956037f, 0.310399f, 0.34338957f, -0.040192716f, 0.12504166f, -0.21490449f, + -0.15410437f, -0.1338158f, -0.39244395f, 0.29117042f, -0.26415867f, -0.4450379f, 0.0699404f, + 0.042872816f, -0.14961651f, -0.17582522f, -0.6919577f, -0.13723494f, -0.0681901f, -0.16183335f, + -0.0021959245f, -0.0418434f, -0.32134426f, 0.16967098f, -0.08680786f, -0.32077473f, 0.0066963434f, + + -0.114091426f, -0.041066267f, -0.080250874f, -0.72594404f, -0.30254412f, -0.03862554f, -0.27475363f, + 0.15282185f, -0.22887689f, -0.72043663f, -0.47111863f, -0.3755179f, -0.20074406f, 0.16101281f, + -0.20939936f, -0.21245953f, 0.11726546f, -0.8030824f, -0.5866715f, 0.20001571f, -0.26259118f, + 0.17054747f, 0.061063558f, -0.6348493f, 0.2620284f, -0.782919f, -0.31278569f, 0.2926497f, + -0.08745579f, 0.20646049f, -0.050303012f, -0.13460274f, 0.060659587f, -0.037006564f, -0.1292249f, + -0.11211421f, -0.038967483f, -0.21644044f, -0.24912538f, 0.08591288f, -0.40798867f, 0.006527111f, + + -0.049734667f, -0.3685795f, -0.11538547f, 0.27292788f, 0.025990233f, 0.119311824f, 0.0700129f, + -0.156443f, -0.13340846f, 0.10764159f, -0.014803357f, 0.046525866f, 0.015691683f, -0.1869241f, + 0.1004442f, -0.4885978f, -0.7585998f, -0.047841772f, -0.07570776f, 0.0471261f, 0.24483289f, + -0.16554686f, -0.1250152f, -0.15132052f, -0.08515984f, 0.14412321f, -0.1030291f, -0.2780918f, + 0.05803944f, -0.10257156f, -0.4341917f, -0.13150966f, -0.53996617f, -0.15628646f, 0.059058204f, + -0.11976162f, -0.022163756f, -0.13519828f, -0.20148787f, 0.16934697f, -0.14327072f, -0.2129095f, + + -0.107836396f, -0.0819309f, -0.06148723f, -0.0063935146f, -0.02425649f, -0.056219954f, -0.06095987f, + -0.14403576f, -0.025357183f, -0.15828207f, 0.012748428f, -0.16061643f, -0.03419252f, -0.05130991f, + -0.109983265f, -0.08312916f, -0.07035978f, -0.008285124f, -0.10610263f, -0.01489019f, -0.106886685f, + -0.007659614f, -0.2947925f, -0.09132287f, -0.040577132f, 0.089866154f, -0.24528673f, -0.055424154f, + 0.13783869f, 0.023674607f, -0.10545369f, -0.20873478f, -0.4685722f, 0.09418375f, -0.06684458f, + 0.0410614f, 0.04018917f, -0.15845582f, 0.06580096f, 0.070554025f, -0.19462511f, -0.03526502f, + + -0.02956047f, -0.16035908f, -0.0638171f, -0.261022f, -0.022948403f, 0.08353848f, -0.041173913f, + 0.04770004f, 0.091520615f, 0.006987013f, -0.39962748f, 0.23266485f, -0.32719564f, -0.12885109f, + -0.29559937f, -0.08031146f, 0.76168066f, 0.0009028502f, -0.4091536f, -0.14801738f, -0.17058557f, + -0.05754847f, 0.2955231f, -0.089874476f, 0.17254886f, -0.13203058f, -0.007648442f, 0.010943003f, + 0.04123217f, 0.26074114f, -0.24313056f, 0.1008903f, -0.26472318f, 0.01998391f, -0.03422378f, + -0.024659738f, 0.033793047f, -0.1998924f, -0.110185415f, 0.10620246f, -0.3435271f, 0.019390412f, + + 0.21691665f, -0.26076952f, -0.5040901f, 0.28383943f, -0.34750903f, -0.32484284f, -0.01734912f, + -0.08909689f, -0.0466362f, 0.21648785f, 0.06733417f, 0.009496197f, 0.18728223f, -0.35110205f, + -0.04908372f, -0.36729553f, -0.346236f, -0.13589534f, -0.16435221f, -0.16853788f, 0.12264759f, + -0.019215636f, -0.38316554f, 0.35669535f, -0.56980205f, -0.059346225f, 0.15008381f, -0.1751053f, + 0.059508912f, 0.116622455f, -0.32607535f, -0.22282779f, -0.29149055f, -0.3829086f, 0.15905643f, + -0.077926554f, 0.06549884f, -0.09004557f, -0.15897253f, 0.26810864f, -0.08931713f, -0.047756508f, + + -0.14657992f, 0.43070868f, -0.021787114f, -0.4532621f, 0.092385404f, -0.30126676f, -0.24893704f, + -0.10896815f, -0.14514503f, -0.21353528f, 0.018723361f, 0.037694372f, 0.11514955f, 0.13013864f, + -0.25713888f, -0.056000195f, -0.3505367f, 0.0836427f, -0.032017898f, -0.26742116f, -0.14740711f, + -0.13330215f, -0.18958306f, -0.08968873f, 0.014723815f, -0.20343366f, 0.3098968f, 0.114284225f, + -0.026738256f, -0.14110464f, -0.054464605f, -0.17529932f, -0.0030034669f, -0.050670102f, -0.04016705f, + -0.062238634f, -0.04886609f, -0.042247344f, -0.12185234f, 0.0357792f, -0.10265522f, -0.116296895f, + + -0.1035416f, -0.09126053f, 0.20045105f, 0.12366664f, 0.05460281f, 0.09944453f, -0.055443168f, + -0.09767935f, -0.040166672f, -0.01716708f, 0.020299219f, 0.02864775f, -0.07159522f, -0.04354491f, + -0.1390779f, -0.13270372f, 0.02992779f, -0.025869183f, 0.12530258f, 0.05101595f, -0.07891131f, + -0.1051311f, -0.093200594f, -0.10368025f, 0.047598884f, -0.12069465f, -0.098738566f, -0.042393237f, + -0.08531736f, -0.051284637f, -0.04354899f, -0.06810297f, -0.083224006f, -0.11702064f, -0.08514082f, + -0.06071842f, -0.07496775f, -0.03626109f, -0.07785503f, -0.07243007f, -0.041736744f, -0.052593358f}; TestConvTransposeOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, OpTester::ExpectResult::kExpectSuccess, "", @@ -1045,7 +1077,7 @@ TEST(ConvTransposeTest, ConvTranspose_1D_AsymmetricPads) { vector W = {1.0f, 1.0f, 1.0f, 1.0f}; vector W_shape = {1, 2, 2}; vector Y_shape = {1, 2, 4}; - auto expected_vals = {3.0f, 5.0f, 7.0f, 4.0f, 3.0f, 5.0f, 7.0f, 4.0f}; + vector expected_vals = {3.0f, 5.0f, 7.0f, 4.0f, 3.0f, 5.0f, 7.0f, 4.0f}; TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kQnnExecutionProvider}); @@ -1068,7 +1100,7 @@ TEST(ConvTransposeTest, ConvTranspose_1D_AutoPad_SameUpper) { vector W = {1.0f, 1.0f, 1.0f, 1.0f}; vector W_shape = {1, 2, 2}; vector Y_shape = {1, 2, 4}; - auto expected_vals = {1.0f, 3.0f, 5.0f, 7.0f, 1.0f, 3.0f, 5.0f, 7.0f}; + vector expected_vals = {1.0f, 3.0f, 5.0f, 7.0f, 1.0f, 3.0f, 5.0f, 7.0f}; TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, OpTester::ExpectResult::kExpectSuccess, "", @@ -1092,7 +1124,7 @@ TEST(ConvTransposeTest, ConvTranspose_1D_AutoPad_SameLower) { vector W = {1.0f, 1.0f, 1.0f, 1.0f}; vector W_shape = {1, 2, 2}; vector Y_shape = {1, 2, 4}; - auto expected_vals = {3.0f, 5.0f, 7.0f, 4.0f, 3.0f, 5.0f, 7.0f, 4.0f}; + vector expected_vals = {3.0f, 5.0f, 7.0f, 4.0f, 3.0f, 5.0f, 7.0f, 4.0f}; TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, OpTester::ExpectResult::kExpectSuccess, "", @@ -1125,19 +1157,19 @@ TEST(ConvTransposeTest, ConvTranspose_AutoPad_with_non_default_strides) { 1.0f, 1.0f, 1.0f}; vector W_shape = {1, 2, 3, 3}; - auto expected_vals = {0.0f, 0.0f, 1.0f, 1.0f, 3.0f, 2.0f, - 0.0f, 0.0f, 1.0f, 1.0f, 3.0f, 2.0f, - 3.0f, 3.0f, 8.0f, 5.0f, 12.0f, 7.0f, - 3.0f, 3.0f, 7.0f, 4.0f, 9.0f, 5.0f, - 9.0f, 9.0f, 20.0f, 11.0f, 24.0f, 13.0f, - 6.0f, 6.0f, 13.0f, 7.0f, 15.0f, 8.0f, - - 0.0f, 0.0f, 1.0f, 1.0f, 3.0f, 2.0f, - 0.0f, 0.0f, 1.0f, 1.0f, 3.0f, 2.0f, - 3.0f, 3.0f, 8.0f, 5.0f, 12.0f, 7.0f, - 3.0f, 3.0f, 7.0f, 4.0f, 9.0f, 5.0f, - 9.0f, 9.0f, 20.0f, 11.0f, 24.0f, 13.0f, - 6.0f, 6.0f, 13.0f, 7.0f, 15.0f, 8.0f}; + vector expected_vals = {0.0f, 0.0f, 1.0f, 1.0f, 3.0f, 2.0f, + 0.0f, 0.0f, 1.0f, 1.0f, 3.0f, 2.0f, + 3.0f, 3.0f, 8.0f, 5.0f, 12.0f, 7.0f, + 3.0f, 3.0f, 7.0f, 4.0f, 9.0f, 5.0f, + 9.0f, 9.0f, 20.0f, 11.0f, 24.0f, 13.0f, + 6.0f, 6.0f, 13.0f, 7.0f, 15.0f, 8.0f, + + 0.0f, 0.0f, 1.0f, 1.0f, 3.0f, 2.0f, + 0.0f, 0.0f, 1.0f, 1.0f, 3.0f, 2.0f, + 3.0f, 3.0f, 8.0f, 5.0f, 12.0f, 7.0f, + 3.0f, 3.0f, 7.0f, 4.0f, 9.0f, 5.0f, + 9.0f, 9.0f, 20.0f, 11.0f, 24.0f, 13.0f, + 6.0f, 6.0f, 13.0f, 7.0f, 15.0f, 8.0f}; vector Y_shape = {1, 2, 6, 6}; TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, @@ -1170,7 +1202,7 @@ TEST(ConvTransposeTest, SharedPrepackedWeights) { W.push_back(1.0f); test.AddInput("W", {6, 3, 3, 3}, W, true); // Trigger pre-packing - auto expected_vals = { + vector expected_vals = { 12.0f, 18.0f, 18.0f, diff --git a/onnxruntime/test/providers/cpu/nn/pool_fp16_op_test.cc b/onnxruntime/test/providers/cpu/nn/pool_fp16_op_test.cc index b033ddbca23d6..7d736d41e804b 100644 --- a/onnxruntime/test/providers/cpu/nn/pool_fp16_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/pool_fp16_op_test.cc @@ -3,7 +3,7 @@ #include "core/mlas/inc/mlas.h" -#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || defined(COREML_ENABLE_MLPROGRAM) #include "core/providers/cpu/nn/pool.h" #include "gtest/gtest.h" @@ -567,4 +567,4 @@ TEST(PoolFp16Test, GlobalAveragePool) { } // namespace test } // namespace onnxruntime -#endif \ No newline at end of file +#endif diff --git a/onnxruntime/test/providers/cpu/nn/pool_op_test.cc b/onnxruntime/test/providers/cpu/nn/pool_op_test.cc index 885fb11c6e999..a340f975ec91a 100644 --- a/onnxruntime/test/providers/cpu/nn/pool_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/pool_op_test.cc @@ -9,6 +9,13 @@ using namespace std; namespace onnxruntime { namespace test { +template +class PoolTest : public ::testing::Test { +}; + +using PoolTestTypes = ::testing::Types; +TYPED_TEST_SUITE(PoolTest, PoolTestTypes); + // Disable TensorRT on some of the tests because "pads" attribute is not supported TEST(PoolTest, MaxPool) { @@ -63,13 +70,15 @@ TEST(PoolTest, MaxPool) { // Only CUDA kernel has float 16 support // Disable for now, still investigating the issue with cudnn lib -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(COREML_ENABLE_MLPROGRAM) TEST(PoolTest, MaxPool_F16) { +#if defined(USE_CUDA) int min_cuda_architecture = 530; if (!HasCudaEnvironment(min_cuda_architecture)) { LOGS_DEFAULT(WARNING) << "Hardware NOT support FP16"; return; } +#endif OpTester test("MaxPool"); test.AddAttribute("auto_pad", ""); @@ -672,7 +681,7 @@ TEST(PoolTest, MaxPool_10_DilationPadding_3d) { {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider}); } -TEST(PoolTest, GlobalMaxPool) { +TYPED_TEST(PoolTest, GlobalMaxPool) { OpTester test("GlobalMaxPool"); std::vector x_vals = {0.19151945412158966, 0.6221087574958801, 0.43772774934768677, @@ -743,12 +752,23 @@ TEST(PoolTest, GlobalMaxPool) { std::vector expected_dims = {1, 3, 1, 1}; std::vector expected_vals = {0.9920814633369446, 0.9820047616958618, 0.9946538209915161}; - test.AddInput("X", x_dims, x_vals); - test.AddOutput("Y", expected_dims, expected_vals); + if constexpr (std::is_same::value) { + test.AddInput("X", x_dims, x_vals); + test.AddOutput("Y", expected_dims, expected_vals); + } else { + std::vector x_vals_fp16(x_vals.size()); + std::vector expected_vals_fp16(expected_vals.size()); + + ConvertFloatToMLFloat16(x_vals.data(), x_vals_fp16.data(), x_vals.size()); + ConvertFloatToMLFloat16(expected_vals.data(), expected_vals_fp16.data(), expected_vals.size()); + test.AddInput("X", x_dims, x_vals_fp16); + test.AddOutput("Y", expected_dims, expected_vals_fp16); + } + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}); } -TEST(PoolTest, GlobalMaxPool3D) { +TYPED_TEST(PoolTest, GlobalMaxPool3D) { OpTester test("GlobalMaxPool"); std::vector x_vals = {0.19151945412158966, 0.6221087574958801, 0.43772774934768677, @@ -819,8 +839,19 @@ TEST(PoolTest, GlobalMaxPool3D) { std::vector expected_dims = {1, 3, 1, 1, 1}; std::vector expected_vals = {0.9920814633369446, 0.9820047616958618, 0.9946538209915161}; - test.AddInput("X", x_dims, x_vals); - test.AddOutput("Y", expected_dims, expected_vals); + if constexpr (std::is_same::value) { + test.AddInput("X", x_dims, x_vals); + test.AddOutput("Y", expected_dims, expected_vals); + } else { + std::vector x_vals_fp16(x_vals.size()); + std::vector expected_vals_fp16(expected_vals.size()); + + ConvertFloatToMLFloat16(x_vals.data(), x_vals_fp16.data(), x_vals.size()); + ConvertFloatToMLFloat16(expected_vals.data(), expected_vals_fp16.data(), expected_vals.size()); + test.AddInput("X", x_dims, x_vals_fp16); + test.AddOutput("Y", expected_dims, expected_vals_fp16); + } + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } diff --git a/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc b/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc index 98f57f4573540..4a1888a5ca7d6 100644 --- a/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc @@ -7,6 +7,13 @@ namespace onnxruntime { namespace test { +template +class ConcatOpTest : public ::testing::Test { +}; + +using ConcatOpTestTypes = ::testing::Types; +TYPED_TEST_SUITE(ConcatOpTest, ConcatOpTestTypes); + // Some of the tests can't run on TensorrtExecutionProvider because of unsupported data types or limits // in its parser: axis >=0 && axis < nbDims. Those Tests will fallback to other EPs @@ -68,34 +75,45 @@ TEST(ConcatOpTest, Concat1D_2) { kQnnExecutionProvider}); // QNN: not support dynamic shape tensor } -TEST(ConcatOpTest, Concat2D_1) { +template +static std::vector GetTypedArray(std::vector inputs, [[maybe_unused]] T v = T(0.f)) { + if constexpr (std::is_same::value) { + return inputs; + } else { + std::vector inputs_fp16(inputs.size()); + ConvertFloatToMLFloat16(inputs.data(), inputs_fp16.data(), inputs.size()); + return inputs_fp16; + } +} + +TYPED_TEST(ConcatOpTest, Concat2D_1) { OpTester test("Concat"); test.AddAttribute("axis", int64_t{0}); std::vector dims{1, 4}; - test.AddInput("input1", dims, {11.0f, 12.0f, 13.0f, 14.0f}); - test.AddInput("input2", dims, {21.0f, 22.0f, 23.0f, 24.0f}); - test.AddInput("input3", dims, {31.0f, 32.0f, 33.0f, 34.0f}); - test.AddOutput("concat_result", {3, 4}, - {11.0f, 12.0f, 13.0f, 14.0f, - 21.0f, 22.0f, 23.0f, 24.0f, - 31.0f, 32.0f, 33.0f, 34.0f}); + test.AddInput("input1", dims, GetTypedArray({11.0f, 12.0f, 13.0f, 14.0f})); + test.AddInput("input2", dims, GetTypedArray({21.0f, 22.0f, 23.0f, 24.0f})); + test.AddInput("input3", dims, GetTypedArray({31.0f, 32.0f, 33.0f, 34.0f})); + test.AddOutput("concat_result", {3, 4}, + GetTypedArray({11.0f, 12.0f, 13.0f, 14.0f, + 21.0f, 22.0f, 23.0f, 24.0f, + 31.0f, 32.0f, 33.0f, 34.0f})); test.Run(); } -TEST(ConcatOpTest, Concat2D_2) { +TYPED_TEST(ConcatOpTest, Concat2D_2) { OpTester test("Concat"); test.AddAttribute("axis", int64_t{1}); std::vector dims{4, 1}; - test.AddInput("input1", dims, {11.0f, 21.0f, 31.0f, 41.0f}); - test.AddInput("input2", {4, 2}, {12.0f, 13.0f, 22.0f, 23.0f, 32.0f, 33.0f, 42.0f, 43.0f}); - test.AddInput("input3", dims, {14.0f, 24.0f, 34.0f, 44.0f}); - test.AddOutput("concat_result", {4, 4}, - {11.0f, 12.0f, 13.0f, 14.0f, - 21.0f, 22.0f, 23.0f, 24.0f, - 31.0f, 32.0f, 33.0f, 34.0f, - 41.0f, 42.0f, 43.0f, 44.0f}); + test.AddInput("input1", dims, GetTypedArray({11.0f, 21.0f, 31.0f, 41.0f})); + test.AddInput("input2", {4, 2}, GetTypedArray({12.0f, 13.0f, 22.0f, 23.0f, 32.0f, 33.0f, 42.0f, 43.0f})); + test.AddInput("input3", dims, GetTypedArray({14.0f, 24.0f, 34.0f, 44.0f})); + test.AddOutput("concat_result", {4, 4}, + GetTypedArray({11.0f, 12.0f, 13.0f, 14.0f, + 21.0f, 22.0f, 23.0f, 24.0f, + 31.0f, 32.0f, 33.0f, 34.0f, + 41.0f, 42.0f, 43.0f, 44.0f})); test.Run(); } diff --git a/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc b/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc index 540dc6dee68fb..05cfb5c13d689 100644 --- a/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc @@ -40,963 +40,970 @@ void RunTests(T& test, std::vector>&& execut // DO NOT edit following tests. They are generated by: // onnxruntime/test/providers/cpu/tensor/grid_sample_test_gen.py -TEST(GridSampleTest, test_grid_sample_16_4D_nearest_zeros_align_corners) { +template +class GridSampleTest : public ::testing::Test { +}; + +using GridSampleTestTypes = ::testing::Types; +TYPED_TEST_SUITE(GridSampleTest, GridSampleTestTypes); + +TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_nearest_zeros_align_corners) { OpTester test("GridSample", 16); std::string mode = "nearest"; std::string padding_mode = "zeros"; int64_t align_corners = 1; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{-1.125840f, -1.152360f, -0.250579f, -0.433879f, 0.848710f, 0.692009f, -0.316013f, -2.115219f, 0.468096f, -0.157712f, 1.443660f, 0.266049f, 0.166455f, 0.874382f, -0.143474f, -0.111609f, 0.931827f, 1.259009f, 2.004981f, 0.053737f, 0.618057f, -0.412802f, -0.841065f, -2.316042f}; + std::initializer_list X_data{TypeParam(-1.125840f), TypeParam(-1.152360f), TypeParam(-0.250579f), TypeParam(-0.433879f), TypeParam(0.848710f), TypeParam(0.692009f), TypeParam(-0.316013f), TypeParam(-2.115219f), TypeParam(0.468096f), TypeParam(-0.157712f), TypeParam(1.443660f), TypeParam(0.266049f), TypeParam(0.166455f), TypeParam(0.874382f), TypeParam(-0.143474f), TypeParam(-0.111609f), TypeParam(0.931827f), TypeParam(1.259009f), TypeParam(2.004981f), TypeParam(0.053737f), TypeParam(0.618057f), TypeParam(-0.412802f), TypeParam(-0.841065f), TypeParam(-2.316042f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{0.063110f, -0.615220f, 0.203022f, -1.120434f, -0.867079f, -0.618636f, 0.757125f, 0.703586f, -0.532194f, -0.043299f, 0.767473f, 1.192960f, 0.476259f, 0.162111f, 0.804584f, -0.706563f, 0.223613f, -0.930367f, -0.831703f, -0.619900f, 0.542968f, 0.482592f, -0.710823f, 0.362529f}; + std::initializer_list Grid_data{TypeParam(0.063110f), TypeParam(-0.615220f), TypeParam(0.203022f), TypeParam(-1.120434f), TypeParam(-0.867079f), TypeParam(-0.618636f), TypeParam(0.757125f), TypeParam(0.703586f), TypeParam(-0.532194f), TypeParam(-0.043299f), TypeParam(0.767473f), TypeParam(1.192960f), TypeParam(0.476259f), TypeParam(0.162111f), TypeParam(0.804584f), TypeParam(-0.706563f), TypeParam(0.223613f), TypeParam(-0.930367f), TypeParam(-0.831703f), TypeParam(-0.619900f), TypeParam(0.542968f), TypeParam(0.482592f), TypeParam(-0.710823f), TypeParam(0.362529f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{-1.152360f, -1.152360f, -1.125840f, 0.692009f, -0.250579f, 0.692009f, -2.115219f, -2.115219f, -0.316013f, 0.266049f, 0.468096f, 0.266049f, -0.111609f, 0.874382f, 0.874382f, 0.166455f, -0.111609f, -0.143474f, -0.412802f, 0.053737f, 0.053737f, 2.004981f, -0.412802f, 0.618057f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-1.152360f), TypeParam(-1.152360f), TypeParam(-1.125840f), TypeParam(0.692009f), TypeParam(-0.250579f), TypeParam(0.692009f), TypeParam(-2.115219f), TypeParam(-2.115219f), TypeParam(-0.316013f), TypeParam(0.266049f), TypeParam(0.468096f), TypeParam(0.266049f), TypeParam(-0.111609f), TypeParam(0.874382f), TypeParam(0.874382f), TypeParam(0.166455f), TypeParam(-0.111609f), TypeParam(-0.143474f), TypeParam(-0.412802f), TypeParam(0.053737f), TypeParam(0.053737f), TypeParam(2.004981f), TypeParam(-0.412802f), TypeParam(0.618057f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(16)); } -TEST(GridSampleTest, test_grid_sample_16_4D_nearest_zeros_no_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_nearest_zeros_no_align_corners) { OpTester test("GridSample", 16); std::string mode = "nearest"; std::string padding_mode = "zeros"; int64_t align_corners = 0; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{-0.569248f, 0.919971f, 1.110816f, 1.289874f, -1.478174f, 2.567233f, -0.473120f, 0.335551f, -0.003304f, -0.534441f, 1.168688f, 0.394503f, 1.941462f, 0.791498f, -0.020252f, -0.437170f, -1.535287f, -0.412679f, 0.966303f, 1.624783f, -0.365619f, -1.302440f, 0.099403f, 0.441822f}; + std::initializer_list X_data{TypeParam(-0.569248f), TypeParam(0.919971f), TypeParam(1.110816f), TypeParam(1.289874f), TypeParam(-1.478174f), TypeParam(2.567233f), TypeParam(-0.473120f), TypeParam(0.335551f), TypeParam(-0.003304f), TypeParam(-0.534441f), TypeParam(1.168688f), TypeParam(0.394503f), TypeParam(1.941462f), TypeParam(0.791498f), TypeParam(-0.020252f), TypeParam(-0.437170f), TypeParam(-1.535287f), TypeParam(-0.412679f), TypeParam(0.966303f), TypeParam(1.624783f), TypeParam(-0.365619f), TypeParam(-1.302440f), TypeParam(0.099403f), TypeParam(0.441822f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{-1.143118f, -0.021569f, -0.903671f, -0.925628f, -0.066120f, 0.180174f, -0.491436f, 0.712053f, -0.730247f, 1.088844f, 0.822360f, -1.011940f, -0.298661f, 0.054147f, 0.175081f, 0.284609f, 0.470914f, 0.071880f, -0.585515f, 0.567827f, -1.151099f, -0.711248f, -0.300396f, -0.584536f}; + std::initializer_list Grid_data{TypeParam(-1.143118f), TypeParam(-0.021569f), TypeParam(-0.903671f), TypeParam(-0.925628f), TypeParam(-0.066120f), TypeParam(0.180174f), TypeParam(-0.491436f), TypeParam(0.712053f), TypeParam(-0.730247f), TypeParam(1.088844f), TypeParam(0.822360f), TypeParam(-1.011940f), TypeParam(-0.298661f), TypeParam(0.054147f), TypeParam(0.175081f), TypeParam(0.284609f), TypeParam(0.470914f), TypeParam(0.071880f), TypeParam(-0.585515f), TypeParam(0.567827f), TypeParam(-1.151099f), TypeParam(-0.711248f), TypeParam(-0.300396f), TypeParam(-0.584536f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{0.000000f, -0.569248f, 1.110816f, -1.478174f, 0.000000f, 0.000000f, 0.000000f, -0.473120f, -0.003304f, 1.168688f, 0.000000f, 0.000000f, -0.020252f, -0.437170f, -0.437170f, -1.535287f, 0.000000f, 1.941462f, -0.365619f, -1.302440f, -1.302440f, 0.099403f, 0.000000f, 0.966303f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(0.000000f), TypeParam(-0.569248f), TypeParam(1.110816f), TypeParam(-1.478174f), TypeParam(0.000000f), TypeParam(0.000000f), TypeParam(0.000000f), TypeParam(-0.473120f), TypeParam(-0.003304f), TypeParam(1.168688f), TypeParam(0.000000f), TypeParam(0.000000f), TypeParam(-0.020252f), TypeParam(-0.437170f), TypeParam(-0.437170f), TypeParam(-1.535287f), TypeParam(0.000000f), TypeParam(1.941462f), TypeParam(-0.365619f), TypeParam(-1.302440f), TypeParam(-1.302440f), TypeParam(0.099403f), TypeParam(0.000000f), TypeParam(0.966303f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(16)); } -TEST(GridSampleTest, test_grid_sample_16_4D_nearest_border_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_nearest_border_align_corners) { OpTester test("GridSample", 16); std::string mode = "nearest"; std::string padding_mode = "border"; int64_t align_corners = 1; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{-0.883376f, -0.418913f, -0.804826f, 0.565610f, 0.610365f, 0.466884f, 1.950657f, -1.063099f, -0.829367f, -1.407257f, 1.626847f, 0.172273f, -1.611502f, -0.479448f, -0.143351f, -0.317295f, 0.573655f, 0.997931f, 0.543609f, 0.078804f, 0.862860f, -0.019490f, 0.991047f, -0.777735f}; + std::initializer_list X_data{TypeParam(-0.883376f), TypeParam(-0.418913f), TypeParam(-0.804826f), TypeParam(0.565610f), TypeParam(0.610365f), TypeParam(0.466884f), TypeParam(1.950657f), TypeParam(-1.063099f), TypeParam(-0.829367f), TypeParam(-1.407257f), TypeParam(1.626847f), TypeParam(0.172273f), TypeParam(-1.611502f), TypeParam(-0.479448f), TypeParam(-0.143351f), TypeParam(-0.317295f), TypeParam(0.573655f), TypeParam(0.997931f), TypeParam(0.543609f), TypeParam(0.078804f), TypeParam(0.862860f), TypeParam(-0.019490f), TypeParam(0.991047f), TypeParam(-0.777735f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{-1.080070f, -0.080985f, 1.055303f, -0.489470f, 1.083604f, 0.434584f, -1.082953f, 0.759237f, -0.138473f, -0.535688f, 0.959584f, -0.969714f, 0.128766f, -0.251242f, 0.856935f, 0.334973f, 0.576606f, 0.423791f, -0.288570f, -0.252367f, -0.988898f, 0.650213f, 0.952774f, 0.821070f}; + std::initializer_list Grid_data{TypeParam(-1.080070f), TypeParam(-0.080985f), TypeParam(1.055303f), TypeParam(-0.489470f), TypeParam(1.083604f), TypeParam(0.434584f), TypeParam(-1.082953f), TypeParam(0.759237f), TypeParam(-0.138473f), TypeParam(-0.535688f), TypeParam(0.959584f), TypeParam(-0.969714f), TypeParam(0.128766f), TypeParam(-0.251242f), TypeParam(0.856935f), TypeParam(0.334973f), TypeParam(0.576606f), TypeParam(0.423791f), TypeParam(-0.288570f), TypeParam(-0.252367f), TypeParam(-0.988898f), TypeParam(0.650213f), TypeParam(0.952774f), TypeParam(0.821070f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{-0.804826f, 0.565610f, 0.565610f, 0.610365f, -0.883376f, -0.418913f, -0.829367f, -1.407257f, -1.407257f, 1.626847f, 1.950657f, -1.063099f, -0.317295f, -0.317295f, -0.317295f, -0.143351f, 0.573655f, 0.997931f, -0.019490f, -0.019490f, -0.019490f, 0.862860f, 0.991047f, -0.777735f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-0.804826f), TypeParam(0.565610f), TypeParam(0.565610f), TypeParam(0.610365f), TypeParam(-0.883376f), TypeParam(-0.418913f), TypeParam(-0.829367f), TypeParam(-1.407257f), TypeParam(-1.407257f), TypeParam(1.626847f), TypeParam(1.950657f), TypeParam(-1.063099f), TypeParam(-0.317295f), TypeParam(-0.317295f), TypeParam(-0.317295f), TypeParam(-0.143351f), TypeParam(0.573655f), TypeParam(0.997931f), TypeParam(-0.019490f), TypeParam(-0.019490f), TypeParam(-0.019490f), TypeParam(0.862860f), TypeParam(0.991047f), TypeParam(-0.777735f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(16)); } -TEST(GridSampleTest, test_grid_sample_16_4D_nearest_border_no_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_nearest_border_no_align_corners) { OpTester test("GridSample", 16); std::string mode = "nearest"; std::string padding_mode = "border"; int64_t align_corners = 0; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{-0.559630f, 0.533472f, 0.406887f, 0.394587f, 0.171511f, 0.876045f, -0.287087f, 1.021640f, 0.438649f, -0.010704f, 1.338354f, -0.279405f, -0.551834f, -2.889061f, -1.509981f, 1.024115f, 0.195393f, -0.737109f, 1.700101f, 0.346216f, 0.971125f, 1.450250f, -0.051909f, -0.628431f}; + std::initializer_list X_data{TypeParam(-0.559630f), TypeParam(0.533472f), TypeParam(0.406887f), TypeParam(0.394587f), TypeParam(0.171511f), TypeParam(0.876045f), TypeParam(-0.287087f), TypeParam(1.021640f), TypeParam(0.438649f), TypeParam(-0.010704f), TypeParam(1.338354f), TypeParam(-0.279405f), TypeParam(-0.551834f), TypeParam(-2.889061f), TypeParam(-1.509981f), TypeParam(1.024115f), TypeParam(0.195393f), TypeParam(-0.737109f), TypeParam(1.700101f), TypeParam(0.346216f), TypeParam(0.971125f), TypeParam(1.450250f), TypeParam(-0.051909f), TypeParam(-0.628431f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{0.149807f, 1.074831f, 0.734055f, -0.758657f, 0.538205f, -0.848275f, -0.508590f, 0.352947f, 0.396231f, 0.900274f, -0.386299f, 0.001921f, 0.617788f, -1.160511f, 0.867577f, -0.992307f, 0.016539f, -0.204020f, -0.632008f, 0.158605f, 0.992302f, -0.350783f, -0.712433f, -0.443807f}; + std::initializer_list Grid_data{TypeParam(0.149807f), TypeParam(1.074831f), TypeParam(0.734055f), TypeParam(-0.758657f), TypeParam(0.538205f), TypeParam(-0.848275f), TypeParam(-0.508590f), TypeParam(0.352947f), TypeParam(0.396231f), TypeParam(0.900274f), TypeParam(-0.386299f), TypeParam(0.001921f), TypeParam(0.617788f), TypeParam(-1.160511f), TypeParam(0.867577f), TypeParam(-0.992307f), TypeParam(0.016539f), TypeParam(-0.204020f), TypeParam(-0.632008f), TypeParam(0.158605f), TypeParam(0.992302f), TypeParam(-0.350783f), TypeParam(-0.712433f), TypeParam(-0.443807f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{0.876045f, 0.533472f, 0.533472f, 0.171511f, 0.876045f, 0.406887f, -0.279405f, 1.021640f, 1.021640f, 1.338354f, -0.279405f, 0.438649f, -2.889061f, -2.889061f, 1.024115f, -1.509981f, -2.889061f, -0.551834f, 0.346216f, 0.346216f, 1.450250f, 0.971125f, 0.346216f, 1.700101f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(0.876045f), TypeParam(0.533472f), TypeParam(0.533472f), TypeParam(0.171511f), TypeParam(0.876045f), TypeParam(0.406887f), TypeParam(-0.279405f), TypeParam(1.021640f), TypeParam(1.021640f), TypeParam(1.338354f), TypeParam(-0.279405f), TypeParam(0.438649f), TypeParam(-2.889061f), TypeParam(-2.889061f), TypeParam(1.024115f), TypeParam(-1.509981f), TypeParam(-2.889061f), TypeParam(-0.551834f), TypeParam(0.346216f), TypeParam(0.346216f), TypeParam(1.450250f), TypeParam(0.971125f), TypeParam(0.346216f), TypeParam(1.700101f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(16)); } -TEST(GridSampleTest, test_grid_sample_16_4D_nearest_reflection_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_nearest_reflection_align_corners) { OpTester test("GridSample", 16); std::string mode = "nearest"; std::string padding_mode = "reflection"; int64_t align_corners = 1; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{-0.039373f, -0.801472f, -0.495544f, -0.361514f, 0.585113f, -1.156007f, -0.143365f, -0.194741f, -0.906885f, -0.591838f, 0.150785f, -1.041149f, -0.720534f, -2.214754f, -0.683730f, 0.516358f, 0.792848f, 0.083228f, 0.422800f, -1.868747f, -1.105713f, 0.143731f, 0.583597f, 1.348155f}; + std::initializer_list X_data{TypeParam(-0.039373f), TypeParam(-0.801472f), TypeParam(-0.495544f), TypeParam(-0.361514f), TypeParam(0.585113f), TypeParam(-1.156007f), TypeParam(-0.143365f), TypeParam(-0.194741f), TypeParam(-0.906885f), TypeParam(-0.591838f), TypeParam(0.150785f), TypeParam(-1.041149f), TypeParam(-0.720534f), TypeParam(-2.214754f), TypeParam(-0.683730f), TypeParam(0.516358f), TypeParam(0.792848f), TypeParam(0.083228f), TypeParam(0.422800f), TypeParam(-1.868747f), TypeParam(-1.105713f), TypeParam(0.143731f), TypeParam(0.583597f), TypeParam(1.348155f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{0.829854f, -0.893309f, 0.491599f, -0.403504f, -0.578962f, 0.215574f, -0.623348f, 0.276486f, 0.235657f, -0.890987f, 0.199798f, 0.511115f, 0.474997f, -0.151054f, -0.983745f, -0.184985f, 0.416769f, -0.437853f, 0.455497f, 0.799155f, -0.626582f, 0.011834f, 0.496199f, 0.094053f}; + std::initializer_list Grid_data{TypeParam(0.829854f), TypeParam(-0.893309f), TypeParam(0.491599f), TypeParam(-0.403504f), TypeParam(-0.578962f), TypeParam(0.215574f), TypeParam(-0.623348f), TypeParam(0.276486f), TypeParam(0.235657f), TypeParam(-0.890987f), TypeParam(0.199798f), TypeParam(0.511115f), TypeParam(0.474997f), TypeParam(-0.151054f), TypeParam(-0.983745f), TypeParam(-0.184985f), TypeParam(0.416769f), TypeParam(-0.437853f), TypeParam(0.455497f), TypeParam(0.799155f), TypeParam(-0.626582f), TypeParam(0.011834f), TypeParam(0.496199f), TypeParam(0.094053f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{-0.801472f, -0.361514f, -0.495544f, -0.495544f, -0.801472f, -1.156007f, -0.194741f, -0.591838f, -0.906885f, -0.906885f, -0.194741f, -1.041149f, 0.516358f, -0.683730f, 0.516358f, 0.083228f, -0.683730f, 0.516358f, 0.143731f, -1.105713f, 0.143731f, 1.348155f, -1.105713f, 0.143731f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-0.801472f), TypeParam(-0.361514f), TypeParam(-0.495544f), TypeParam(-0.495544f), TypeParam(-0.801472f), TypeParam(-1.156007f), TypeParam(-0.194741f), TypeParam(-0.591838f), TypeParam(-0.906885f), TypeParam(-0.906885f), TypeParam(-0.194741f), TypeParam(-1.041149f), TypeParam(0.516358f), TypeParam(-0.683730f), TypeParam(0.516358f), TypeParam(0.083228f), TypeParam(-0.683730f), TypeParam(0.516358f), TypeParam(0.143731f), TypeParam(-1.105713f), TypeParam(0.143731f), TypeParam(1.348155f), TypeParam(-1.105713f), TypeParam(0.143731f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(16)); } -TEST(GridSampleTest, test_grid_sample_16_4D_nearest_reflection_no_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_nearest_reflection_no_align_corners) { OpTester test("GridSample", 16); std::string mode = "nearest"; std::string padding_mode = "reflection"; int64_t align_corners = 0; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{-0.129230f, -0.054595f, 0.408347f, 1.126366f, 1.935057f, 1.007685f, 1.004642f, -0.433520f, -0.562711f, -0.832754f, -1.395545f, -0.399295f, -0.309940f, -0.056062f, 0.517413f, -1.596237f, 0.356960f, -2.297482f, -0.871083f, -1.674028f, 0.563055f, -1.435067f, 0.719400f, -1.370747f}; + std::initializer_list X_data{TypeParam(-0.129230f), TypeParam(-0.054595f), TypeParam(0.408347f), TypeParam(1.126366f), TypeParam(1.935057f), TypeParam(1.007685f), TypeParam(1.004642f), TypeParam(-0.433520f), TypeParam(-0.562711f), TypeParam(-0.832754f), TypeParam(-1.395545f), TypeParam(-0.399295f), TypeParam(-0.309940f), TypeParam(-0.056062f), TypeParam(0.517413f), TypeParam(-1.596237f), TypeParam(0.356960f), TypeParam(-2.297482f), TypeParam(-0.871083f), TypeParam(-1.674028f), TypeParam(0.563055f), TypeParam(-1.435067f), TypeParam(0.719400f), TypeParam(-1.370747f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{-0.811910f, -1.183845f, -0.963667f, 0.947364f, 0.649243f, 1.125859f, 0.961345f, -1.071655f, -0.818917f, -0.193899f, -0.779319f, 0.833276f, -0.907209f, -0.585482f, -1.159310f, -0.681295f, 0.986973f, 0.982512f, 0.859005f, 0.926553f, 1.067024f, -0.307276f, 0.528003f, 1.069117f}; + std::initializer_list Grid_data{TypeParam(-0.811910f), TypeParam(-1.183845f), TypeParam(-0.963667f), TypeParam(0.947364f), TypeParam(0.649243f), TypeParam(1.125859f), TypeParam(0.961345f), TypeParam(-1.071655f), TypeParam(-0.818917f), TypeParam(-0.193899f), TypeParam(-0.779319f), TypeParam(0.833276f), TypeParam(-0.907209f), TypeParam(-0.585482f), TypeParam(-1.159310f), TypeParam(-0.681295f), TypeParam(0.986973f), TypeParam(0.982512f), TypeParam(0.859005f), TypeParam(0.926553f), TypeParam(1.067024f), TypeParam(-0.307276f), TypeParam(0.528003f), TypeParam(1.069117f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{-0.129230f, 1.935057f, 1.007685f, -0.054595f, 0.408347f, 1.935057f, 1.004642f, -1.395545f, -0.399295f, -0.433520f, -0.562711f, -1.395545f, -0.309940f, -0.309940f, -2.297482f, -2.297482f, -1.596237f, -2.297482f, -0.871083f, -0.871083f, -1.370747f, -1.370747f, -1.435067f, -1.370747f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-0.129230f), TypeParam(1.935057f), TypeParam(1.007685f), TypeParam(-0.054595f), TypeParam(0.408347f), TypeParam(1.935057f), TypeParam(1.004642f), TypeParam(-1.395545f), TypeParam(-0.399295f), TypeParam(-0.433520f), TypeParam(-0.562711f), TypeParam(-1.395545f), TypeParam(-0.309940f), TypeParam(-0.309940f), TypeParam(-2.297482f), TypeParam(-2.297482f), TypeParam(-1.596237f), TypeParam(-2.297482f), TypeParam(-0.871083f), TypeParam(-0.871083f), TypeParam(-1.370747f), TypeParam(-1.370747f), TypeParam(-1.435067f), TypeParam(-1.370747f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(16)); } -TEST(GridSampleTest, test_grid_sample_16_4D_bilinear_zeros_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bilinear_zeros_align_corners) { OpTester test("GridSample", 16); std::string mode = "bilinear"; std::string padding_mode = "zeros"; int64_t align_corners = 1; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{0.294201f, 0.797322f, 1.264215f, 0.935492f, 0.545464f, -1.537389f, 0.312439f, 0.740060f, -0.575326f, -1.432532f, -0.666175f, 1.017438f, -2.241368f, 0.437349f, -0.555362f, -0.057943f, 0.658583f, 0.992938f, -0.206548f, -0.244841f, -0.380599f, 1.131112f, -0.090205f, -0.897900f}; + std::initializer_list X_data{TypeParam(0.294201f), TypeParam(0.797322f), TypeParam(1.264215f), TypeParam(0.935492f), TypeParam(0.545464f), TypeParam(-1.537389f), TypeParam(0.312439f), TypeParam(0.740060f), TypeParam(-0.575326f), TypeParam(-1.432532f), TypeParam(-0.666175f), TypeParam(1.017438f), TypeParam(-2.241368f), TypeParam(0.437349f), TypeParam(-0.555362f), TypeParam(-0.057943f), TypeParam(0.658583f), TypeParam(0.992938f), TypeParam(-0.206548f), TypeParam(-0.244841f), TypeParam(-0.380599f), TypeParam(1.131112f), TypeParam(-0.090205f), TypeParam(-0.897900f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{0.595248f, -1.096726f, -0.214731f, -0.891773f, -0.512023f, 0.432352f, -0.852156f, 0.446072f, 1.018534f, 0.078706f, -0.799785f, -0.429942f, 0.262037f, -0.914782f, 0.596172f, -1.089444f, -1.153552f, -1.165993f, -0.243436f, 0.806920f, -1.135775f, 0.997425f, -0.480027f, 0.351461f}; + std::initializer_list Grid_data{TypeParam(0.595248f), TypeParam(-1.096726f), TypeParam(-0.214731f), TypeParam(-0.891773f), TypeParam(-0.512023f), TypeParam(0.432352f), TypeParam(-0.852156f), TypeParam(0.446072f), TypeParam(1.018534f), TypeParam(0.078706f), TypeParam(-0.799785f), TypeParam(-0.429942f), TypeParam(0.262037f), TypeParam(-0.914782f), TypeParam(0.596172f), TypeParam(-1.089444f), TypeParam(-1.153552f), TypeParam(-1.165993f), TypeParam(-0.243436f), TypeParam(0.806920f), TypeParam(-1.135775f), TypeParam(0.997425f), TypeParam(-0.480027f), TypeParam(0.351461f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{0.628229f, 0.561377f, 0.688215f, 0.861459f, 0.733996f, 0.850061f, 0.590307f, 0.329661f, -0.555725f, -0.595435f, -1.228216f, -0.224152f, -0.524667f, -0.094262f, -1.725798f, 0.562584f, 0.610959f, -0.014286f, -0.162194f, -0.215901f, -0.159037f, -0.282404f, -0.084779f, -0.097448f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(0.628229f), TypeParam(0.561377f), TypeParam(0.688215f), TypeParam(0.861459f), TypeParam(0.733996f), TypeParam(0.850061f), TypeParam(0.590307f), TypeParam(0.329661f), TypeParam(-0.555725f), TypeParam(-0.595435f), TypeParam(-1.228216f), TypeParam(-0.224152f), TypeParam(-0.524667f), TypeParam(-0.094262f), TypeParam(-1.725798f), TypeParam(0.562584f), TypeParam(0.610959f), TypeParam(-0.014286f), TypeParam(-0.162194f), TypeParam(-0.215901f), TypeParam(-0.159037f), TypeParam(-0.282404f), TypeParam(-0.084779f), TypeParam(-0.097448f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(16)); } -TEST(GridSampleTest, test_grid_sample_16_4D_bilinear_zeros_no_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bilinear_zeros_no_align_corners) { OpTester test("GridSample", 16); std::string mode = "bilinear"; std::string padding_mode = "zeros"; int64_t align_corners = 0; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{-1.199109f, -0.025686f, 1.802375f, -1.059653f, 3.402826f, -0.568670f, -0.475489f, 1.743163f, 1.060884f, -0.015953f, 1.275653f, 0.009457f, -0.369450f, 1.218198f, 0.255044f, 0.273993f, 1.404381f, 1.082878f, 0.788966f, -0.137615f, 0.122478f, -1.076701f, -0.650897f, -1.619658f}; + std::initializer_list X_data{TypeParam(-1.199109f), TypeParam(-0.025686f), TypeParam(1.802375f), TypeParam(-1.059653f), TypeParam(3.402826f), TypeParam(-0.568670f), TypeParam(-0.475489f), TypeParam(1.743163f), TypeParam(1.060884f), TypeParam(-0.015953f), TypeParam(1.275653f), TypeParam(0.009457f), TypeParam(-0.369450f), TypeParam(1.218198f), TypeParam(0.255044f), TypeParam(0.273993f), TypeParam(1.404381f), TypeParam(1.082878f), TypeParam(0.788966f), TypeParam(-0.137615f), TypeParam(0.122478f), TypeParam(-1.076701f), TypeParam(-0.650897f), TypeParam(-1.619658f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{0.038587f, -0.371014f, -0.260918f, 0.159481f, 0.594851f, -0.840708f, 1.007133f, -0.130476f, -1.005535f, -0.649269f, 1.061781f, 1.097433f, -1.111536f, 0.846358f, 0.601391f, 0.710302f, 1.015835f, -0.646740f, 0.378931f, 0.491080f, -0.354592f, 0.401584f, -0.345256f, 0.741914f}; + std::initializer_list Grid_data{TypeParam(0.038587f), TypeParam(-0.371014f), TypeParam(-0.260918f), TypeParam(0.159481f), TypeParam(0.594851f), TypeParam(-0.840708f), TypeParam(1.007133f), TypeParam(-0.130476f), TypeParam(-1.005535f), TypeParam(-0.649269f), TypeParam(1.061781f), TypeParam(1.097433f), TypeParam(-1.111536f), TypeParam(0.846358f), TypeParam(0.601391f), TypeParam(0.710302f), TypeParam(1.015835f), TypeParam(-0.646740f), TypeParam(0.378931f), TypeParam(0.491080f), TypeParam(-0.354592f), TypeParam(0.401584f), TypeParam(-0.345256f), TypeParam(0.741914f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{-0.199899f, 1.437523f, -0.017180f, -0.422530f, -0.554188f, -0.088180f, 0.613663f, 0.843979f, 1.165913f, 0.161823f, -0.215288f, 0.001466f, 0.398506f, 0.909392f, 0.576145f, 0.897902f, 0.920312f, 1.201733f, -0.184698f, -1.360176f, -0.080218f, -1.352020f, -0.497572f, -0.710420f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-0.199899f), TypeParam(1.437523f), TypeParam(-0.017180f), TypeParam(-0.422530f), TypeParam(-0.554188f), TypeParam(-0.088180f), TypeParam(0.613663f), TypeParam(0.843979f), TypeParam(1.165913f), TypeParam(0.161823f), TypeParam(-0.215288f), TypeParam(0.001466f), TypeParam(0.398506f), TypeParam(0.909392f), TypeParam(0.576145f), TypeParam(0.897902f), TypeParam(0.920312f), TypeParam(1.201733f), TypeParam(-0.184698f), TypeParam(-1.360176f), TypeParam(-0.080218f), TypeParam(-1.352020f), TypeParam(-0.497572f), TypeParam(-0.710420f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(16)); } -TEST(GridSampleTest, test_grid_sample_16_4D_bilinear_border_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bilinear_border_align_corners) { OpTester test("GridSample", 16); std::string mode = "bilinear"; std::string padding_mode = "border"; int64_t align_corners = 1; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{-0.546073f, -0.630178f, -0.634650f, 0.974665f, 0.209843f, 0.029890f, 1.709235f, -0.725759f, -0.876951f, 0.522287f, 0.462005f, -1.329269f, -0.295974f, 1.371414f, 0.973846f, 0.765543f, -0.403897f, -0.326279f, 0.748218f, -0.195299f, 0.676756f, -0.080633f, 0.158123f, 0.099984f}; + std::initializer_list X_data{TypeParam(-0.546073f), TypeParam(-0.630178f), TypeParam(-0.634650f), TypeParam(0.974665f), TypeParam(0.209843f), TypeParam(0.029890f), TypeParam(1.709235f), TypeParam(-0.725759f), TypeParam(-0.876951f), TypeParam(0.522287f), TypeParam(0.462005f), TypeParam(-1.329269f), TypeParam(-0.295974f), TypeParam(1.371414f), TypeParam(0.973846f), TypeParam(0.765543f), TypeParam(-0.403897f), TypeParam(-0.326279f), TypeParam(0.748218f), TypeParam(-0.195299f), TypeParam(0.676756f), TypeParam(-0.080633f), TypeParam(0.158123f), TypeParam(0.099984f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{1.182462f, -0.759228f, 0.230068f, -0.103567f, -0.252788f, -0.268017f, 0.762529f, 0.057356f, -1.168338f, -0.708432f, -0.409080f, 0.603860f, -0.776560f, 1.131504f, -0.267275f, -0.215474f, 0.940270f, 0.603129f, 1.017745f, 0.694133f, -0.364025f, -0.796167f, -0.089284f, 0.993165f}; + std::initializer_list Grid_data{TypeParam(1.182462f), TypeParam(-0.759228f), TypeParam(0.230068f), TypeParam(-0.103567f), TypeParam(-0.252788f), TypeParam(-0.268017f), TypeParam(0.762529f), TypeParam(0.057356f), TypeParam(-1.168338f), TypeParam(-0.708432f), TypeParam(-0.409080f), TypeParam(0.603860f), TypeParam(-0.776560f), TypeParam(1.131504f), TypeParam(-0.267275f), TypeParam(-0.215474f), TypeParam(0.940270f), TypeParam(0.603129f), TypeParam(1.017745f), TypeParam(0.694133f), TypeParam(-0.364025f), TypeParam(-0.796167f), TypeParam(-0.089284f), TypeParam(0.993165f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{-0.243777f, 0.256440f, -0.179228f, 0.741578f, -0.571899f, 0.031558f, -0.425264f, 0.007242f, -0.044977f, 0.271677f, 0.955187f, -0.224230f, -0.395226f, 0.771988f, 0.108104f, 0.007673f, 0.371491f, -0.360026f, 0.151628f, 0.399982f, 0.038327f, 0.044739f, 0.445689f, 0.133017f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-0.243777f), TypeParam(0.256440f), TypeParam(-0.179228f), TypeParam(0.741578f), TypeParam(-0.571899f), TypeParam(0.031558f), TypeParam(-0.425264f), TypeParam(0.007242f), TypeParam(-0.044977f), TypeParam(0.271677f), TypeParam(0.955187f), TypeParam(-0.224230f), TypeParam(-0.395226f), TypeParam(0.771988f), TypeParam(0.108104f), TypeParam(0.007673f), TypeParam(0.371491f), TypeParam(-0.360026f), TypeParam(0.151628f), TypeParam(0.399982f), TypeParam(0.038327f), TypeParam(0.044739f), TypeParam(0.445689f), TypeParam(0.133017f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(16)); } -TEST(GridSampleTest, test_grid_sample_16_4D_bilinear_border_no_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bilinear_border_no_align_corners) { OpTester test("GridSample", 16); std::string mode = "bilinear"; std::string padding_mode = "border"; int64_t align_corners = 0; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{-0.873307f, 0.004261f, -1.257887f, -1.084466f, 0.752979f, 0.323648f, -0.275010f, 1.305612f, -0.009480f, -0.831312f, -0.556290f, 2.070567f, 0.710039f, -0.146461f, -0.746745f, 0.725842f, 0.403461f, 0.234374f, 0.173281f, 1.724145f, -0.408946f, 0.782749f, -1.520847f, -0.314686f}; + std::initializer_list X_data{TypeParam(-0.873307f), TypeParam(0.004261f), TypeParam(-1.257887f), TypeParam(-1.084466f), TypeParam(0.752979f), TypeParam(0.323648f), TypeParam(-0.275010f), TypeParam(1.305612f), TypeParam(-0.009480f), TypeParam(-0.831312f), TypeParam(-0.556290f), TypeParam(2.070567f), TypeParam(0.710039f), TypeParam(-0.146461f), TypeParam(-0.746745f), TypeParam(0.725842f), TypeParam(0.403461f), TypeParam(0.234374f), TypeParam(0.173281f), TypeParam(1.724145f), TypeParam(-0.408946f), TypeParam(0.782749f), TypeParam(-1.520847f), TypeParam(-0.314686f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{0.605180f, 0.169896f, 1.021029f, 0.161312f, -0.555188f, 1.135200f, 0.284017f, -1.170817f, -0.341630f, -0.817401f, 1.052104f, -0.198175f, -1.093830f, -0.075436f, 0.753615f, 0.311761f, 0.379445f, 0.111448f, 0.447382f, -0.292382f, -0.477360f, -1.121650f, -0.904004f, 0.520083f}; + std::initializer_list Grid_data{TypeParam(0.605180f), TypeParam(0.169896f), TypeParam(1.021029f), TypeParam(0.161312f), TypeParam(-0.555188f), TypeParam(1.135200f), TypeParam(0.284017f), TypeParam(-1.170817f), TypeParam(-0.341630f), TypeParam(-0.817401f), TypeParam(1.052104f), TypeParam(-0.198175f), TypeParam(-1.093830f), TypeParam(-0.075436f), TypeParam(0.753615f), TypeParam(0.311761f), TypeParam(0.379445f), TypeParam(0.111448f), TypeParam(0.447382f), TypeParam(-0.292382f), TypeParam(-0.477360f), TypeParam(-1.121650f), TypeParam(-0.904004f), TypeParam(0.520083f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{-0.725617f, -0.743749f, 0.752979f, -0.185279f, -0.734326f, -0.760828f, -0.091786f, -0.129152f, -0.556290f, 0.964224f, -0.024687f, -0.196084f, -0.581904f, 0.496011f, 0.499240f, 0.319537f, 0.690648f, 0.150559f, -0.343065f, 0.269544f, 0.455333f, 1.124628f, 0.208392f, -1.276367f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-0.725617f), TypeParam(-0.743749f), TypeParam(0.752979f), TypeParam(-0.185279f), TypeParam(-0.734326f), TypeParam(-0.760828f), TypeParam(-0.091786f), TypeParam(-0.129152f), TypeParam(-0.556290f), TypeParam(0.964224f), TypeParam(-0.024687f), TypeParam(-0.196084f), TypeParam(-0.581904f), TypeParam(0.496011f), TypeParam(0.499240f), TypeParam(0.319537f), TypeParam(0.690648f), TypeParam(0.150559f), TypeParam(-0.343065f), TypeParam(0.269544f), TypeParam(0.455333f), TypeParam(1.124628f), TypeParam(0.208392f), TypeParam(-1.276367f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(16)); } -TEST(GridSampleTest, test_grid_sample_16_4D_bilinear_reflection_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bilinear_reflection_align_corners) { OpTester test("GridSample", 16); std::string mode = "bilinear"; std::string padding_mode = "reflection"; int64_t align_corners = 1; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{0.540757f, -0.947807f, 0.202144f, -0.350748f, 0.545005f, 1.541211f, 0.600239f, -0.338015f, -1.080823f, -1.391537f, -0.352570f, 1.560770f, -0.822488f, -2.140920f, 0.099553f, -0.697505f, 0.665352f, -2.256198f, -1.002236f, -1.395144f, 0.415783f, 0.268104f, -0.151752f, 0.794042f}; + std::initializer_list X_data{TypeParam(0.540757f), TypeParam(-0.947807f), TypeParam(0.202144f), TypeParam(-0.350748f), TypeParam(0.545005f), TypeParam(1.541211f), TypeParam(0.600239f), TypeParam(-0.338015f), TypeParam(-1.080823f), TypeParam(-1.391537f), TypeParam(-0.352570f), TypeParam(1.560770f), TypeParam(-0.822488f), TypeParam(-2.140920f), TypeParam(0.099553f), TypeParam(-0.697505f), TypeParam(0.665352f), TypeParam(-2.256198f), TypeParam(-1.002236f), TypeParam(-1.395144f), TypeParam(0.415783f), TypeParam(0.268104f), TypeParam(-0.151752f), TypeParam(0.794042f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{1.051960f, -0.798975f, -0.129852f, -0.064453f, 0.535452f, 0.820411f, -0.190205f, -0.994177f, 0.594591f, 0.358958f, 0.482039f, -0.740241f, 0.772315f, 1.136586f, 0.104126f, -1.120858f, 0.842388f, -0.889742f, 0.275846f, 0.174381f, -0.561644f, 0.417835f, -1.073319f, 0.273311f}; + std::initializer_list Grid_data{TypeParam(1.051960f), TypeParam(-0.798975f), TypeParam(-0.129852f), TypeParam(-0.064453f), TypeParam(0.535452f), TypeParam(0.820411f), TypeParam(-0.190205f), TypeParam(-0.994177f), TypeParam(0.594591f), TypeParam(0.358958f), TypeParam(0.482039f), TypeParam(-0.740241f), TypeParam(0.772315f), TypeParam(1.136586f), TypeParam(0.104126f), TypeParam(-1.120858f), TypeParam(0.842388f), TypeParam(-0.889742f), TypeParam(0.275846f), TypeParam(0.174381f), TypeParam(-0.561644f), TypeParam(0.417835f), TypeParam(-1.073319f), TypeParam(0.273311f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{-0.793997f, -0.042818f, 1.034663f, -0.061725f, 0.327743f, -0.470152f, -0.528701f, -1.125254f, 0.678924f, 0.212033f, -0.430627f, -0.410903f, -1.743740f, -1.404122f, -1.882401f, -0.546577f, -0.033295f, 0.203686f, 0.631537f, -1.031405f, -1.182924f, 0.344248f, 0.246420f, 0.266212f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-0.793997f), TypeParam(-0.042818f), TypeParam(1.034663f), TypeParam(-0.061725f), TypeParam(0.327743f), TypeParam(-0.470152f), TypeParam(-0.528701f), TypeParam(-1.125254f), TypeParam(0.678924f), TypeParam(0.212033f), TypeParam(-0.430627f), TypeParam(-0.410903f), TypeParam(-1.743740f), TypeParam(-1.404122f), TypeParam(-1.882401f), TypeParam(-0.546577f), TypeParam(-0.033295f), TypeParam(0.203686f), TypeParam(0.631537f), TypeParam(-1.031405f), TypeParam(-1.182924f), TypeParam(0.344248f), TypeParam(0.246420f), TypeParam(0.266212f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(16)); } -TEST(GridSampleTest, test_grid_sample_16_4D_bilinear_reflection_no_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bilinear_reflection_no_align_corners) { OpTester test("GridSample", 16); std::string mode = "bilinear"; std::string padding_mode = "reflection"; int64_t align_corners = 0; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{0.584178f, 1.050431f, 1.285579f, -1.616520f, -0.768962f, -1.220462f, 0.573128f, 0.699197f, -1.654887f, 0.493267f, -0.615042f, 1.311865f, 0.788249f, -1.232951f, 0.454381f, -1.436621f, 0.711631f, 0.554599f, -0.807529f, 1.680131f, 0.597634f, -0.238890f, -0.345997f, 1.770104f}; + std::initializer_list X_data{TypeParam(0.584178f), TypeParam(1.050431f), TypeParam(1.285579f), TypeParam(-1.616520f), TypeParam(-0.768962f), TypeParam(-1.220462f), TypeParam(0.573128f), TypeParam(0.699197f), TypeParam(-1.654887f), TypeParam(0.493267f), TypeParam(-0.615042f), TypeParam(1.311865f), TypeParam(0.788249f), TypeParam(-1.232951f), TypeParam(0.454381f), TypeParam(-1.436621f), TypeParam(0.711631f), TypeParam(0.554599f), TypeParam(-0.807529f), TypeParam(1.680131f), TypeParam(0.597634f), TypeParam(-0.238890f), TypeParam(-0.345997f), TypeParam(1.770104f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{0.564800f, 1.031186f, 0.795913f, -0.629473f, -0.131544f, -0.377622f, -0.964948f, 0.000496f, 0.902922f, 1.011019f, 0.111961f, 0.272548f, -0.519506f, 0.905811f, -0.499330f, -0.833583f, 0.184792f, 0.719262f, -1.081910f, 1.084761f, 0.431677f, -0.840735f, -0.258489f, 1.041096f}; + std::initializer_list Grid_data{TypeParam(0.564800f), TypeParam(1.031186f), TypeParam(0.795913f), TypeParam(-0.629473f), TypeParam(-0.131544f), TypeParam(-0.377622f), TypeParam(-0.964948f), TypeParam(0.000496f), TypeParam(0.902922f), TypeParam(1.011019f), TypeParam(0.111961f), TypeParam(0.272548f), TypeParam(-0.519506f), TypeParam(0.905811f), TypeParam(-0.499330f), TypeParam(-0.833583f), TypeParam(0.184792f), TypeParam(0.719262f), TypeParam(-1.081910f), TypeParam(1.084761f), TypeParam(0.431677f), TypeParam(-0.840735f), TypeParam(-0.258489f), TypeParam(1.041096f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{-1.220462f, 0.901641f, 0.521980f, 1.284051f, -1.220462f, -0.717235f, 1.311865f, 0.687708f, -0.023386f, -1.654114f, 1.311865f, 0.029458f, 0.711631f, 0.786895f, 0.604097f, 0.711631f, -1.094857f, 0.673706f, -0.345997f, -0.805863f, 1.103092f, -0.345997f, 1.510167f, 0.165064f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-1.220462f), TypeParam(0.901641f), TypeParam(0.521980f), TypeParam(1.284051f), TypeParam(-1.220462f), TypeParam(-0.717235f), TypeParam(1.311865f), TypeParam(0.687708f), TypeParam(-0.023386f), TypeParam(-1.654114f), TypeParam(1.311865f), TypeParam(0.029458f), TypeParam(0.711631f), TypeParam(0.786895f), TypeParam(0.604097f), TypeParam(0.711631f), TypeParam(-1.094857f), TypeParam(0.673706f), TypeParam(-0.345997f), TypeParam(-0.805863f), TypeParam(1.103092f), TypeParam(-0.345997f), TypeParam(1.510167f), TypeParam(0.165064f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(16)); } -TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_zeros_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_zeros_align_corners) { OpTester test("GridSample", 16); std::string mode = "bicubic"; std::string padding_mode = "zeros"; int64_t align_corners = 1; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{0.497417f, 0.268522f, 1.476879f, 0.354795f, 1.624709f, 0.593423f, -1.725412f, -0.622016f, -0.466707f, -0.319962f, 0.701868f, 0.494252f, -0.630165f, 0.548236f, 1.042740f, 0.253800f, -2.667303f, 1.379165f, -0.519418f, 0.672783f, -0.005627f, -0.180192f, -0.018395f, 0.998084f}; + std::initializer_list X_data{TypeParam(0.497417f), TypeParam(0.268522f), TypeParam(1.476879f), TypeParam(0.354795f), TypeParam(1.624709f), TypeParam(0.593423f), TypeParam(-1.725412f), TypeParam(-0.622016f), TypeParam(-0.466707f), TypeParam(-0.319962f), TypeParam(0.701868f), TypeParam(0.494252f), TypeParam(-0.630165f), TypeParam(0.548236f), TypeParam(1.042740f), TypeParam(0.253800f), TypeParam(-2.667303f), TypeParam(1.379165f), TypeParam(-0.519418f), TypeParam(0.672783f), TypeParam(-0.005627f), TypeParam(-0.180192f), TypeParam(-0.018395f), TypeParam(0.998084f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{0.213755f, 0.141747f, -0.562622f, -0.414594f, 0.325025f, -0.834438f, 0.197995f, 0.519270f, -0.472884f, 0.996769f, -0.078973f, 0.544455f, 1.188368f, -0.366802f, 0.652090f, -0.343235f, -0.175288f, -0.203365f, -0.007455f, -0.453322f, 0.281264f, 0.045216f, 0.760668f, -0.242886f}; + std::initializer_list Grid_data{TypeParam(0.213755f), TypeParam(0.141747f), TypeParam(-0.562622f), TypeParam(-0.414594f), TypeParam(0.325025f), TypeParam(-0.834438f), TypeParam(0.197995f), TypeParam(0.519270f), TypeParam(-0.472884f), TypeParam(0.996769f), TypeParam(-0.078973f), TypeParam(0.544455f), TypeParam(1.188368f), TypeParam(-0.366802f), TypeParam(0.652090f), TypeParam(-0.343235f), TypeParam(-0.175288f), TypeParam(-0.203365f), TypeParam(-0.007455f), TypeParam(-0.453322f), TypeParam(0.281264f), TypeParam(0.045216f), TypeParam(0.760668f), TypeParam(-0.242886f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{1.007407f, 1.068583f, 0.492134f, 1.222040f, 1.576835f, 1.464183f, -0.238652f, -1.242164f, -1.156880f, 0.279082f, 0.744912f, 0.338287f, 0.215322f, 0.388598f, 0.866571f, 0.556826f, 0.608617f, 0.326312f, 0.044527f, -0.028766f, -0.136528f, -0.084880f, -0.121429f, -0.105516f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(1.007407f), TypeParam(1.068583f), TypeParam(0.492134f), TypeParam(1.222040f), TypeParam(1.576835f), TypeParam(1.464183f), TypeParam(-0.238652f), TypeParam(-1.242164f), TypeParam(-1.156880f), TypeParam(0.279082f), TypeParam(0.744912f), TypeParam(0.338287f), TypeParam(0.215322f), TypeParam(0.388598f), TypeParam(0.866571f), TypeParam(0.556826f), TypeParam(0.608617f), TypeParam(0.326312f), TypeParam(0.044527f), TypeParam(-0.028766f), TypeParam(-0.136528f), TypeParam(-0.084880f), TypeParam(-0.121429f), TypeParam(-0.105516f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(16)); } -TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_zeros_no_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_zeros_no_align_corners) { OpTester test("GridSample", 16); std::string mode = "bicubic"; std::string padding_mode = "zeros"; int64_t align_corners = 0; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{-1.065470f, 0.402578f, -0.405242f, -0.583366f, -0.258523f, -0.605559f, -0.188242f, 0.959607f, 1.189619f, -0.179522f, -1.823240f, -0.051351f, -1.636092f, -2.510569f, -1.238273f, -0.929619f, -0.058536f, 0.772879f, 0.468944f, 0.259886f, 0.757624f, -2.041813f, -0.552378f, 0.626977f}; + std::initializer_list X_data{TypeParam(-1.065470f), TypeParam(0.402578f), TypeParam(-0.405242f), TypeParam(-0.583366f), TypeParam(-0.258523f), TypeParam(-0.605559f), TypeParam(-0.188242f), TypeParam(0.959607f), TypeParam(1.189619f), TypeParam(-0.179522f), TypeParam(-1.823240f), TypeParam(-0.051351f), TypeParam(-1.636092f), TypeParam(-2.510569f), TypeParam(-1.238273f), TypeParam(-0.929619f), TypeParam(-0.058536f), TypeParam(0.772879f), TypeParam(0.468944f), TypeParam(0.259886f), TypeParam(0.757624f), TypeParam(-2.041813f), TypeParam(-0.552378f), TypeParam(0.626977f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{-1.199809f, 0.061445f, -0.035546f, 0.180524f, 0.919500f, 1.166411f, -0.711939f, -0.074825f, -0.480808f, -1.105975f, -0.873191f, 1.126273f, 0.699673f, 0.644581f, 0.666892f, -0.953375f, 0.126023f, 1.116858f, -0.669703f, 1.067513f, 0.315406f, 0.844252f, -0.514065f, 0.553221f}; + std::initializer_list Grid_data{TypeParam(-1.199809f), TypeParam(0.061445f), TypeParam(-0.035546f), TypeParam(0.180524f), TypeParam(0.919500f), TypeParam(1.166411f), TypeParam(-0.711939f), TypeParam(-0.074825f), TypeParam(-0.480808f), TypeParam(-1.105975f), TypeParam(-0.873191f), TypeParam(1.126273f), TypeParam(0.699673f), TypeParam(0.644581f), TypeParam(0.666892f), TypeParam(-0.953375f), TypeParam(0.126023f), TypeParam(1.116858f), TypeParam(-0.669703f), TypeParam(1.067513f), TypeParam(0.315406f), TypeParam(0.844252f), TypeParam(-0.514065f), TypeParam(0.553221f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{-0.086429f, -0.590424f, -0.090572f, -0.393926f, -0.379182f, -0.031455f, 0.347836f, 0.182097f, 0.050161f, 1.154870f, -0.134312f, -0.509844f, 0.697346f, -1.440179f, 0.264668f, 0.021389f, 0.729883f, -0.236038f, 0.576661f, 0.348301f, 0.149351f, -0.327477f, 0.607344f, -0.405680f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-0.086429f), TypeParam(-0.590424f), TypeParam(-0.090572f), TypeParam(-0.393926f), TypeParam(-0.379182f), TypeParam(-0.031455f), TypeParam(0.347836f), TypeParam(0.182097f), TypeParam(0.050161f), TypeParam(1.154870f), TypeParam(-0.134312f), TypeParam(-0.509844f), TypeParam(0.697346f), TypeParam(-1.440179f), TypeParam(0.264668f), TypeParam(0.021389f), TypeParam(0.729883f), TypeParam(-0.236038f), TypeParam(0.576661f), TypeParam(0.348301f), TypeParam(0.149351f), TypeParam(-0.327477f), TypeParam(0.607344f), TypeParam(-0.405680f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(16)); } -TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_border_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_border_align_corners) { OpTester test("GridSample", 16); std::string mode = "bicubic"; std::string padding_mode = "border"; int64_t align_corners = 1; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{0.203585f, -1.032829f, 1.130481f, -0.570301f, -2.100938f, 0.389922f, 0.087343f, -0.857360f, 1.193520f, -0.019760f, 0.280285f, 1.811013f, 1.838673f, 0.164184f, 1.436009f, 0.167011f, -1.139939f, -0.029833f, -0.009878f, 0.079750f, 0.216590f, -0.265852f, -0.528116f, -0.451915f}; + std::initializer_list X_data{TypeParam(0.203585f), TypeParam(-1.032829f), TypeParam(1.130481f), TypeParam(-0.570301f), TypeParam(-2.100938f), TypeParam(0.389922f), TypeParam(0.087343f), TypeParam(-0.857360f), TypeParam(1.193520f), TypeParam(-0.019760f), TypeParam(0.280285f), TypeParam(1.811013f), TypeParam(1.838673f), TypeParam(0.164184f), TypeParam(1.436009f), TypeParam(0.167011f), TypeParam(-1.139939f), TypeParam(-0.029833f), TypeParam(-0.009878f), TypeParam(0.079750f), TypeParam(0.216590f), TypeParam(-0.265852f), TypeParam(-0.528116f), TypeParam(-0.451915f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{0.797796f, -1.010726f, 0.868577f, -1.132977f, 0.268082f, -0.786042f, -0.476635f, 0.212483f, -0.471816f, -0.189867f, -1.137389f, -1.131448f, 0.464836f, -0.507934f, -0.730068f, -0.473499f, -0.981082f, -0.959280f, 0.718047f, 0.609891f, 0.159844f, -0.655512f, 0.399241f, 0.053910f}; + std::initializer_list Grid_data{TypeParam(0.797796f), TypeParam(-1.010726f), TypeParam(0.868577f), TypeParam(-1.132977f), TypeParam(0.268082f), TypeParam(-0.786042f), TypeParam(-0.476635f), TypeParam(0.212483f), TypeParam(-0.471816f), TypeParam(-0.189867f), TypeParam(-1.137389f), TypeParam(-1.131448f), TypeParam(0.464836f), TypeParam(-0.507934f), TypeParam(-0.730068f), TypeParam(-0.473499f), TypeParam(-0.981082f), TypeParam(-0.959280f), TypeParam(0.718047f), TypeParam(0.609891f), TypeParam(0.159844f), TypeParam(-0.655512f), TypeParam(0.399241f), TypeParam(0.053910f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{-0.934180f, -1.004565f, -0.467118f, 0.384839f, 0.792549f, 0.188357f, -0.785741f, -0.871727f, -0.372851f, 0.958270f, 0.751528f, 0.046397f, 0.598629f, 1.686400f, 1.817043f, 0.015806f, 0.866266f, 0.480930f, -0.013358f, 0.152904f, -0.001292f, -0.385043f, 0.030959f, -0.152332f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-0.934180f), TypeParam(-1.004565f), TypeParam(-0.467118f), TypeParam(0.384839f), TypeParam(0.792549f), TypeParam(0.188357f), TypeParam(-0.785741f), TypeParam(-0.871727f), TypeParam(-0.372851f), TypeParam(0.958270f), TypeParam(0.751528f), TypeParam(0.046397f), TypeParam(0.598629f), TypeParam(1.686400f), TypeParam(1.817043f), TypeParam(0.015806f), TypeParam(0.866266f), TypeParam(0.480930f), TypeParam(-0.013358f), TypeParam(0.152904f), TypeParam(-0.001292f), TypeParam(-0.385043f), TypeParam(0.030959f), TypeParam(-0.152332f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(16)); } -TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_border_no_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_border_no_align_corners) { OpTester test("GridSample", 16); std::string mode = "bicubic"; std::string padding_mode = "border"; int64_t align_corners = 0; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{-0.427361f, 0.814325f, -1.412076f, -0.099774f, 0.074936f, 0.590322f, 0.398556f, -0.635891f, -1.081747f, -0.330179f, 0.271759f, -1.089819f, -0.746656f, -0.942538f, -1.251568f, -1.730282f, -0.722323f, 0.525964f, -0.436259f, -0.188952f, -0.499550f, 1.502071f, -0.014112f, 1.194050f}; + std::initializer_list X_data{TypeParam(-0.427361f), TypeParam(0.814325f), TypeParam(-1.412076f), TypeParam(-0.099774f), TypeParam(0.074936f), TypeParam(0.590322f), TypeParam(0.398556f), TypeParam(-0.635891f), TypeParam(-1.081747f), TypeParam(-0.330179f), TypeParam(0.271759f), TypeParam(-1.089819f), TypeParam(-0.746656f), TypeParam(-0.942538f), TypeParam(-1.251568f), TypeParam(-1.730282f), TypeParam(-0.722323f), TypeParam(0.525964f), TypeParam(-0.436259f), TypeParam(-0.188952f), TypeParam(-0.499550f), TypeParam(1.502071f), TypeParam(-0.014112f), TypeParam(1.194050f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{-0.102021f, -0.935855f, -0.007380f, -0.996053f, -0.258157f, 0.695455f, -0.834420f, -0.808862f, -0.293012f, -0.328961f, 0.203145f, 0.199219f, 0.608516f, -0.826657f, -0.084685f, 0.671149f, 1.037966f, -0.087535f, -0.694344f, 0.344955f, 0.683373f, -0.749700f, -0.696352f, 0.530398f}; + std::initializer_list Grid_data{TypeParam(-0.102021f), TypeParam(-0.935855f), TypeParam(-0.007380f), TypeParam(-0.996053f), TypeParam(-0.258157f), TypeParam(0.695455f), TypeParam(-0.834420f), TypeParam(-0.808862f), TypeParam(-0.293012f), TypeParam(-0.328961f), TypeParam(0.203145f), TypeParam(0.199219f), TypeParam(0.608516f), TypeParam(-0.826657f), TypeParam(-0.084685f), TypeParam(0.671149f), TypeParam(1.037966f), TypeParam(-0.087535f), TypeParam(-0.694344f), TypeParam(0.344955f), TypeParam(0.683373f), TypeParam(-0.749700f), TypeParam(-0.696352f), TypeParam(0.530398f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{0.154701f, 0.273277f, 0.226316f, -0.467055f, -0.820643f, -0.311691f, 0.084699f, -0.052970f, 0.001158f, 0.679701f, -0.467804f, -0.607116f, -0.871407f, -0.210613f, -1.860685f, -1.059387f, -0.902250f, -0.918798f, -0.360562f, 0.476049f, 1.499304f, -0.418396f, -0.298854f, -0.235927f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(0.154701f), TypeParam(0.273277f), TypeParam(0.226316f), TypeParam(-0.467055f), TypeParam(-0.820643f), TypeParam(-0.311691f), TypeParam(0.084699f), TypeParam(-0.052970f), TypeParam(0.001158f), TypeParam(0.679701f), TypeParam(-0.467804f), TypeParam(-0.607116f), TypeParam(-0.871407f), TypeParam(-0.210613f), TypeParam(-1.860685f), TypeParam(-1.059387f), TypeParam(-0.902250f), TypeParam(-0.918798f), TypeParam(-0.360562f), TypeParam(0.476049f), TypeParam(1.499304f), TypeParam(-0.418396f), TypeParam(-0.298854f), TypeParam(-0.235927f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(16)); } -TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_reflection_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_reflection_align_corners) { OpTester test("GridSample", 16); std::string mode = "bicubic"; std::string padding_mode = "reflection"; int64_t align_corners = 1; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{-1.084082f, -0.128738f, -0.681077f, -1.309896f, 0.660269f, -1.412063f, 1.834581f, 0.456195f, 0.162801f, -0.638266f, 0.897973f, -0.383653f, 0.297945f, 1.809414f, -0.091298f, 1.092744f, -0.102453f, -1.726535f, -0.484632f, 0.712097f, 1.820312f, -0.852073f, -0.341399f, -0.138106f}; + std::initializer_list X_data{TypeParam(-1.084082f), TypeParam(-0.128738f), TypeParam(-0.681077f), TypeParam(-1.309896f), TypeParam(0.660269f), TypeParam(-1.412063f), TypeParam(1.834581f), TypeParam(0.456195f), TypeParam(0.162801f), TypeParam(-0.638266f), TypeParam(0.897973f), TypeParam(-0.383653f), TypeParam(0.297945f), TypeParam(1.809414f), TypeParam(-0.091298f), TypeParam(1.092744f), TypeParam(-0.102453f), TypeParam(-1.726535f), TypeParam(-0.484632f), TypeParam(0.712097f), TypeParam(1.820312f), TypeParam(-0.852073f), TypeParam(-0.341399f), TypeParam(-0.138106f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{-0.501236f, -0.770480f, -0.140656f, -1.129896f, 0.470370f, 0.885106f, 0.288068f, -0.118568f, 0.594968f, -0.761702f, 1.173892f, -1.193212f, -1.149534f, -0.283562f, 0.980213f, 0.120151f, 0.460855f, -0.879608f, 0.437623f, -0.134092f, 0.480988f, 0.847491f, 0.521616f, -0.102077f}; + std::initializer_list Grid_data{TypeParam(-0.501236f), TypeParam(-0.770480f), TypeParam(-0.140656f), TypeParam(-1.129896f), TypeParam(0.470370f), TypeParam(0.885106f), TypeParam(0.288068f), TypeParam(-0.118568f), TypeParam(0.594968f), TypeParam(-0.761702f), TypeParam(1.173892f), TypeParam(-1.193212f), TypeParam(-1.149534f), TypeParam(-0.283562f), TypeParam(0.980213f), TypeParam(0.120151f), TypeParam(0.460855f), TypeParam(-0.879608f), TypeParam(0.437623f), TypeParam(-0.134092f), TypeParam(0.480988f), TypeParam(0.847491f), TypeParam(0.521616f), TypeParam(-0.102077f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{-0.953278f, -0.722872f, -1.065112f, -1.071529f, -0.344328f, -0.233562f, 1.436462f, 1.232983f, -0.181487f, -0.297043f, 0.464837f, 0.396673f, 0.053896f, 0.733510f, 1.541248f, 1.117701f, -1.352406f, 1.131762f, 1.324986f, -0.882173f, 0.469635f, -0.247133f, -0.196824f, -0.393592f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-0.953278f), TypeParam(-0.722872f), TypeParam(-1.065112f), TypeParam(-1.071529f), TypeParam(-0.344328f), TypeParam(-0.233562f), TypeParam(1.436462f), TypeParam(1.232983f), TypeParam(-0.181487f), TypeParam(-0.297043f), TypeParam(0.464837f), TypeParam(0.396673f), TypeParam(0.053896f), TypeParam(0.733510f), TypeParam(1.541248f), TypeParam(1.117701f), TypeParam(-1.352406f), TypeParam(1.131762f), TypeParam(1.324986f), TypeParam(-0.882173f), TypeParam(0.469635f), TypeParam(-0.247133f), TypeParam(-0.196824f), TypeParam(-0.393592f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(16)); } -TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_reflection_no_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_reflection_no_align_corners) { OpTester test("GridSample", 16); std::string mode = "bicubic"; std::string padding_mode = "reflection"; int64_t align_corners = 0; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{-1.122981f, 0.620969f, -0.876394f, -1.774003f, -0.810376f, -1.475962f, 0.667025f, 0.668804f, -0.748346f, 1.422400f, 0.138469f, -0.165945f, 1.266886f, -0.496157f, 0.158060f, 0.488900f, 0.414476f, 0.419527f, 0.238000f, -0.034674f, 0.229435f, 0.234530f, 0.320846f, 0.703888f}; + std::initializer_list X_data{TypeParam(-1.122981f), TypeParam(0.620969f), TypeParam(-0.876394f), TypeParam(-1.774003f), TypeParam(-0.810376f), TypeParam(-1.475962f), TypeParam(0.667025f), TypeParam(0.668804f), TypeParam(-0.748346f), TypeParam(1.422400f), TypeParam(0.138469f), TypeParam(-0.165945f), TypeParam(1.266886f), TypeParam(-0.496157f), TypeParam(0.158060f), TypeParam(0.488900f), TypeParam(0.414476f), TypeParam(0.419527f), TypeParam(0.238000f), TypeParam(-0.034674f), TypeParam(0.229435f), TypeParam(0.234530f), TypeParam(0.320846f), TypeParam(0.703888f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{0.471637f, -0.923628f, -0.909401f, 0.684338f, 0.224360f, 1.092855f, -0.320755f, -0.579618f, -0.111056f, 0.006071f, 0.915173f, -1.195296f, -0.085441f, 0.530823f, -0.660820f, -0.609769f, 0.579921f, -1.149822f, 0.284347f, -0.929024f, 0.596474f, -1.026049f, 0.737766f, -1.135959f}; + std::initializer_list Grid_data{TypeParam(0.471637f), TypeParam(-0.923628f), TypeParam(-0.909401f), TypeParam(0.684338f), TypeParam(0.224360f), TypeParam(1.092855f), TypeParam(-0.320755f), TypeParam(-0.579618f), TypeParam(-0.111056f), TypeParam(0.006071f), TypeParam(0.915173f), TypeParam(-1.195296f), TypeParam(-0.085441f), TypeParam(0.530823f), TypeParam(-0.660820f), TypeParam(-0.609769f), TypeParam(0.579921f), TypeParam(-1.149822f), TypeParam(0.284347f), TypeParam(-0.929024f), TypeParam(0.596474f), TypeParam(-1.026049f), TypeParam(0.737766f), TypeParam(-1.135959f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{0.998063f, -0.689213f, -1.266024f, -0.870706f, -1.217616f, 1.292693f, 0.543307f, 0.219521f, -0.255151f, 0.543599f, 0.062982f, 0.527696f, 0.387590f, 1.352544f, -0.758053f, -0.262859f, -0.820496f, -0.934255f, 0.434353f, 0.262797f, -0.092283f, -0.021089f, -0.106052f, -0.119717f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(0.998063f), TypeParam(-0.689213f), TypeParam(-1.266024f), TypeParam(-0.870706f), TypeParam(-1.217616f), TypeParam(1.292693f), TypeParam(0.543307f), TypeParam(0.219521f), TypeParam(-0.255151f), TypeParam(0.543599f), TypeParam(0.062982f), TypeParam(0.527696f), TypeParam(0.387590f), TypeParam(1.352544f), TypeParam(-0.758053f), TypeParam(-0.262859f), TypeParam(-0.820496f), TypeParam(-0.934255f), TypeParam(0.434353f), TypeParam(0.262797f), TypeParam(-0.092283f), TypeParam(-0.021089f), TypeParam(-0.106052f), TypeParam(-0.119717f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(16)); } -TEST(GridSampleTest, test_grid_sample_20_4D_nearest_zeros_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_nearest_zeros_align_corners) { OpTester test("GridSample", 20); std::string mode = "nearest"; std::string padding_mode = "zeros"; int64_t align_corners = 1; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{0.404710f, -0.654932f, 0.052124f, 0.340055f, -0.212416f, 1.562917f, -0.907159f, -1.566185f, 0.596746f, 1.002548f, -0.820504f, 0.509186f, 0.951389f, 0.773736f, -2.144711f, 0.044147f, 1.290612f, 0.664926f, 0.530731f, -0.423196f, -0.388699f, 0.333224f, 0.293744f, -0.157543f}; + std::initializer_list X_data{TypeParam(0.404710f), TypeParam(-0.654932f), TypeParam(0.052124f), TypeParam(0.340055f), TypeParam(-0.212416f), TypeParam(1.562917f), TypeParam(-0.907159f), TypeParam(-1.566185f), TypeParam(0.596746f), TypeParam(1.002548f), TypeParam(-0.820504f), TypeParam(0.509186f), TypeParam(0.951389f), TypeParam(0.773736f), TypeParam(-2.144711f), TypeParam(0.044147f), TypeParam(1.290612f), TypeParam(0.664926f), TypeParam(0.530731f), TypeParam(-0.423196f), TypeParam(-0.388699f), TypeParam(0.333224f), TypeParam(0.293744f), TypeParam(-0.157543f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{0.528957f, 0.982925f, -0.033286f, -0.806271f, 0.793837f, -0.411498f, 0.621343f, -0.295724f, 0.510113f, 1.079311f, 1.115827f, -1.092078f, -0.793776f, -0.496160f, -0.765241f, 1.151400f, -0.105983f, -0.796009f, -0.533987f, -0.662838f, 0.489587f, -1.046701f, -1.118884f, -1.182913f}; + std::initializer_list Grid_data{TypeParam(0.528957f), TypeParam(0.982925f), TypeParam(-0.033286f), TypeParam(-0.806271f), TypeParam(0.793837f), TypeParam(-0.411498f), TypeParam(0.621343f), TypeParam(-0.295724f), TypeParam(0.510113f), TypeParam(1.079311f), TypeParam(1.115827f), TypeParam(-1.092078f), TypeParam(-0.793776f), TypeParam(-0.496160f), TypeParam(-0.765241f), TypeParam(1.151400f), TypeParam(-0.105983f), TypeParam(-0.796009f), TypeParam(-0.533987f), TypeParam(-0.662838f), TypeParam(0.489587f), TypeParam(-1.046701f), TypeParam(-1.118884f), TypeParam(-1.182913f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{1.562917f, 0.404710f, 0.340055f, 0.340055f, 1.562917f, -0.654932f, 0.509186f, -0.907159f, 1.002548f, 1.002548f, 0.509186f, -1.566185f, -2.144711f, 1.290612f, 0.951389f, 0.951389f, 0.773736f, 0.951389f, -0.388699f, 0.293744f, 0.530731f, 0.530731f, -0.423196f, 0.530731f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(1.562917f), TypeParam(0.404710f), TypeParam(0.340055f), TypeParam(0.340055f), TypeParam(1.562917f), TypeParam(-0.654932f), TypeParam(0.509186f), TypeParam(-0.907159f), TypeParam(1.002548f), TypeParam(1.002548f), TypeParam(0.509186f), TypeParam(-1.566185f), TypeParam(-2.144711f), TypeParam(1.290612f), TypeParam(0.951389f), TypeParam(0.951389f), TypeParam(0.773736f), TypeParam(0.951389f), TypeParam(-0.388699f), TypeParam(0.293744f), TypeParam(0.530731f), TypeParam(0.530731f), TypeParam(-0.423196f), TypeParam(0.530731f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_5D_nearest_zeros_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_nearest_zeros_align_corners) { OpTester test("GridSample", 20); std::string mode = "nearest"; std::string padding_mode = "zeros"; int64_t align_corners = 1; std::initializer_list X_shape{2, 2, 3, 3, 2}; - std::initializer_list X_data{-1.495959f, 0.018231f, 0.345600f, 0.031206f, 0.400390f, 0.425763f, 0.839517f, 1.238945f, 0.523906f, -1.658372f, 0.548335f, -1.398321f, -1.976414f, 1.232491f, -0.545575f, -0.069414f, 0.732245f, -0.150333f, -0.707132f, 0.467497f, 0.278677f, 1.335679f, 1.155313f, -0.056298f, 0.430615f, -0.932645f, -1.505319f, 0.103317f, 1.521579f, 0.365497f, 1.428928f, 0.364333f, 1.683777f, 1.010632f, 0.621895f, 2.284701f, 1.574905f, -0.310514f, 1.495724f, 1.003370f, -1.437482f, 0.043097f, -1.645546f, -1.464643f, 0.350139f, -0.105905f, -0.740495f, 1.157691f, 1.443377f, 0.198399f, -1.105180f, -2.037115f, 2.128767f, -0.204457f, 0.468464f, 1.203629f, -0.362309f, -0.130520f, 1.532353f, 1.547599f, -0.831847f, -1.008509f, 0.023218f, 0.342626f, -0.882915f, 0.560640f, -1.142297f, 1.119107f, 0.385787f, -0.068515f, -0.529550f, -0.233903f}; + std::initializer_list X_data{TypeParam(-1.495959f), TypeParam(0.018231f), TypeParam(0.345600f), TypeParam(0.031206f), TypeParam(0.400390f), TypeParam(0.425763f), TypeParam(0.839517f), TypeParam(1.238945f), TypeParam(0.523906f), TypeParam(-1.658372f), TypeParam(0.548335f), TypeParam(-1.398321f), TypeParam(-1.976414f), TypeParam(1.232491f), TypeParam(-0.545575f), TypeParam(-0.069414f), TypeParam(0.732245f), TypeParam(-0.150333f), TypeParam(-0.707132f), TypeParam(0.467497f), TypeParam(0.278677f), TypeParam(1.335679f), TypeParam(1.155313f), TypeParam(-0.056298f), TypeParam(0.430615f), TypeParam(-0.932645f), TypeParam(-1.505319f), TypeParam(0.103317f), TypeParam(1.521579f), TypeParam(0.365497f), TypeParam(1.428928f), TypeParam(0.364333f), TypeParam(1.683777f), TypeParam(1.010632f), TypeParam(0.621895f), TypeParam(2.284701f), TypeParam(1.574905f), TypeParam(-0.310514f), TypeParam(1.495724f), TypeParam(1.003370f), TypeParam(-1.437482f), TypeParam(0.043097f), TypeParam(-1.645546f), TypeParam(-1.464643f), TypeParam(0.350139f), TypeParam(-0.105905f), TypeParam(-0.740495f), TypeParam(1.157691f), TypeParam(1.443377f), TypeParam(0.198399f), TypeParam(-1.105180f), TypeParam(-2.037115f), TypeParam(2.128767f), TypeParam(-0.204457f), TypeParam(0.468464f), TypeParam(1.203629f), TypeParam(-0.362309f), TypeParam(-0.130520f), TypeParam(1.532353f), TypeParam(1.547599f), TypeParam(-0.831847f), TypeParam(-1.008509f), TypeParam(0.023218f), TypeParam(0.342626f), TypeParam(-0.882915f), TypeParam(0.560640f), TypeParam(-1.142297f), TypeParam(1.119107f), TypeParam(0.385787f), TypeParam(-0.068515f), TypeParam(-0.529550f), TypeParam(-0.233903f)}; std::initializer_list Grid_shape{2, 3, 3, 2, 3}; - std::initializer_list Grid_data{0.812645f, 0.528235f, -0.550793f, -0.856977f, -1.073535f, 0.059526f, 1.163856f, -0.227931f, -0.050518f, -0.872033f, 0.368412f, 0.760780f, -1.183099f, -0.844947f, 0.888849f, 0.284117f, -0.074815f, 0.214510f, -0.182450f, -0.838758f, -1.121316f, 0.789250f, -0.142724f, -0.445665f, -0.309738f, -0.654508f, -0.355420f, -1.030097f, 0.898012f, 0.490011f, -0.605186f, -0.409576f, 0.538365f, -0.444367f, 0.316432f, 0.330410f, -0.755392f, 0.300602f, 0.073421f, 1.048061f, -0.434184f, -0.308482f, 1.033921f, -0.979923f, 0.086698f, 1.156203f, -0.538042f, 1.150419f, 1.064809f, 1.116408f, -0.114508f, 1.085560f, -0.522863f, -0.410766f, 0.453879f, 0.253497f, 0.661531f, 1.140383f, -0.751187f, 0.636872f, 0.401477f, 0.633082f, 0.569007f, -0.448884f, -0.948427f, 0.960462f, -0.684283f, 0.767193f, -1.143172f, -0.207603f, 0.012719f, 0.207628f, 0.096998f, 0.378128f, -0.133613f, 0.293885f, 1.187501f, -0.776462f, -0.065516f, -0.458068f, 1.052916f, 1.027248f, -0.032723f, -0.415959f, -0.741439f, 0.858648f, -0.082636f, 1.130172f, 0.684314f, 1.050365f, 0.949108f, -0.779811f, 0.351243f, -0.497591f, 0.602104f, -0.107892f, 0.103884f, -0.829931f, -1.072471f, 0.451888f, 0.278862f, 0.104235f, 0.815033f, -0.501089f, 0.425977f, -0.660914f, 0.248640f, -0.273958f}; + std::initializer_list Grid_data{TypeParam(0.812645f), TypeParam(0.528235f), TypeParam(-0.550793f), TypeParam(-0.856977f), TypeParam(-1.073535f), TypeParam(0.059526f), TypeParam(1.163856f), TypeParam(-0.227931f), TypeParam(-0.050518f), TypeParam(-0.872033f), TypeParam(0.368412f), TypeParam(0.760780f), TypeParam(-1.183099f), TypeParam(-0.844947f), TypeParam(0.888849f), TypeParam(0.284117f), TypeParam(-0.074815f), TypeParam(0.214510f), TypeParam(-0.182450f), TypeParam(-0.838758f), TypeParam(-1.121316f), TypeParam(0.789250f), TypeParam(-0.142724f), TypeParam(-0.445665f), TypeParam(-0.309738f), TypeParam(-0.654508f), TypeParam(-0.355420f), TypeParam(-1.030097f), TypeParam(0.898012f), TypeParam(0.490011f), TypeParam(-0.605186f), TypeParam(-0.409576f), TypeParam(0.538365f), TypeParam(-0.444367f), TypeParam(0.316432f), TypeParam(0.330410f), TypeParam(-0.755392f), TypeParam(0.300602f), TypeParam(0.073421f), TypeParam(1.048061f), TypeParam(-0.434184f), TypeParam(-0.308482f), TypeParam(1.033921f), TypeParam(-0.979923f), TypeParam(0.086698f), TypeParam(1.156203f), TypeParam(-0.538042f), TypeParam(1.150419f), TypeParam(1.064809f), TypeParam(1.116408f), TypeParam(-0.114508f), TypeParam(1.085560f), TypeParam(-0.522863f), TypeParam(-0.410766f), TypeParam(0.453879f), TypeParam(0.253497f), TypeParam(0.661531f), TypeParam(1.140383f), TypeParam(-0.751187f), TypeParam(0.636872f), TypeParam(0.401477f), TypeParam(0.633082f), TypeParam(0.569007f), TypeParam(-0.448884f), TypeParam(-0.948427f), TypeParam(0.960462f), TypeParam(-0.684283f), TypeParam(0.767193f), TypeParam(-1.143172f), TypeParam(-0.207603f), TypeParam(0.012719f), TypeParam(0.207628f), TypeParam(0.096998f), TypeParam(0.378128f), TypeParam(-0.133613f), TypeParam(0.293885f), TypeParam(1.187501f), TypeParam(-0.776462f), TypeParam(-0.065516f), TypeParam(-0.458068f), TypeParam(1.052916f), TypeParam(1.027248f), TypeParam(-0.032723f), TypeParam(-0.415959f), TypeParam(-0.741439f), TypeParam(0.858648f), TypeParam(-0.082636f), TypeParam(1.130172f), TypeParam(0.684314f), TypeParam(1.050365f), TypeParam(0.949108f), TypeParam(-0.779811f), TypeParam(0.351243f), TypeParam(-0.497591f), TypeParam(0.602104f), TypeParam(-0.107892f), TypeParam(0.103884f), TypeParam(-0.829931f), TypeParam(-1.072471f), TypeParam(0.451888f), TypeParam(0.278862f), TypeParam(0.104235f), TypeParam(0.815033f), TypeParam(-0.501089f), TypeParam(0.425977f), TypeParam(-0.660914f), TypeParam(0.248640f), TypeParam(-0.273958f)}; std::initializer_list Y_shape{2, 2, 3, 3, 2}; - std::initializer_list Y_data{0.425763f, 0.839517f, -1.658372f, -0.545575f, -1.976414f, -1.658372f, -1.495959f, -1.658372f, 0.839517f, 0.548335f, -0.545575f, 0.523906f, 0.523906f, -1.658372f, 1.238945f, 1.232491f, -1.398321f, 1.238945f, -0.056298f, 0.430615f, 0.103317f, 1.683777f, 1.428928f, 0.103317f, -0.707132f, 0.103317f, 0.430615f, 1.521579f, 1.683777f, -1.505319f, -1.505319f, 0.103317f, -0.932645f, 0.364333f, 0.365497f, -0.932645f, -2.037115f, 0.198399f, -0.204457f, 1.443377f, -1.437482f, 0.350139f, -0.105905f, 0.043097f, -1.105180f, -0.105905f, -0.740495f, -0.204457f, -1.464643f, -0.740495f, -0.310514f, -0.105905f, -1.464643f, 0.350139f, -0.068515f, 1.119107f, -0.233903f, -1.142297f, 1.532353f, 0.023218f, 0.342626f, 1.547599f, 0.385787f, 0.342626f, -0.882915f, -0.233903f, -1.008509f, -0.882915f, 1.203629f, 0.342626f, -1.008509f, 0.023218f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(0.425763f), TypeParam(0.839517f), TypeParam(-1.658372f), TypeParam(-0.545575f), TypeParam(-1.976414f), TypeParam(-1.658372f), TypeParam(-1.495959f), TypeParam(-1.658372f), TypeParam(0.839517f), TypeParam(0.548335f), TypeParam(-0.545575f), TypeParam(0.523906f), TypeParam(0.523906f), TypeParam(-1.658372f), TypeParam(1.238945f), TypeParam(1.232491f), TypeParam(-1.398321f), TypeParam(1.238945f), TypeParam(-0.056298f), TypeParam(0.430615f), TypeParam(0.103317f), TypeParam(1.683777f), TypeParam(1.428928f), TypeParam(0.103317f), TypeParam(-0.707132f), TypeParam(0.103317f), TypeParam(0.430615f), TypeParam(1.521579f), TypeParam(1.683777f), TypeParam(-1.505319f), TypeParam(-1.505319f), TypeParam(0.103317f), TypeParam(-0.932645f), TypeParam(0.364333f), TypeParam(0.365497f), TypeParam(-0.932645f), TypeParam(-2.037115f), TypeParam(0.198399f), TypeParam(-0.204457f), TypeParam(1.443377f), TypeParam(-1.437482f), TypeParam(0.350139f), TypeParam(-0.105905f), TypeParam(0.043097f), TypeParam(-1.105180f), TypeParam(-0.105905f), TypeParam(-0.740495f), TypeParam(-0.204457f), TypeParam(-1.464643f), TypeParam(-0.740495f), TypeParam(-0.310514f), TypeParam(-0.105905f), TypeParam(-1.464643f), TypeParam(0.350139f), TypeParam(-0.068515f), TypeParam(1.119107f), TypeParam(-0.233903f), TypeParam(-1.142297f), TypeParam(1.532353f), TypeParam(0.023218f), TypeParam(0.342626f), TypeParam(1.547599f), TypeParam(0.385787f), TypeParam(0.342626f), TypeParam(-0.882915f), TypeParam(-0.233903f), TypeParam(-1.008509f), TypeParam(-0.882915f), TypeParam(1.203629f), TypeParam(0.342626f), TypeParam(-1.008509f), TypeParam(0.023218f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_4D_nearest_zeros_no_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_nearest_zeros_no_align_corners) { OpTester test("GridSample", 20); std::string mode = "nearest"; std::string padding_mode = "zeros"; int64_t align_corners = 0; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{-1.948141f, 1.836740f, -0.418393f, -0.125621f, 1.779137f, -0.028049f, 0.367697f, -0.388847f, -0.939514f, -0.129193f, -0.101240f, -3.087570f, -0.778617f, 1.026859f, 0.624162f, 0.291416f, 0.580998f, -0.185200f, 0.333020f, 0.415896f, 0.011702f, 0.014502f, -0.722870f, -0.201041f}; + std::initializer_list X_data{TypeParam(-1.948141f), TypeParam(1.836740f), TypeParam(-0.418393f), TypeParam(-0.125621f), TypeParam(1.779137f), TypeParam(-0.028049f), TypeParam(0.367697f), TypeParam(-0.388847f), TypeParam(-0.939514f), TypeParam(-0.129193f), TypeParam(-0.101240f), TypeParam(-3.087570f), TypeParam(-0.778617f), TypeParam(1.026859f), TypeParam(0.624162f), TypeParam(0.291416f), TypeParam(0.580998f), TypeParam(-0.185200f), TypeParam(0.333020f), TypeParam(0.415896f), TypeParam(0.011702f), TypeParam(0.014502f), TypeParam(-0.722870f), TypeParam(-0.201041f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{0.818167f, -0.394078f, 0.627076f, -1.124307f, -0.296864f, -0.244061f, -0.423780f, 0.504000f, -0.546789f, -0.139085f, -0.346504f, -1.126900f, -0.198169f, -1.016972f, 0.699725f, 0.641356f, 1.124151f, -0.402963f, 0.061023f, 0.235069f, 1.197862f, 1.099936f, -0.621047f, -1.021083f}; + std::initializer_list Grid_data{TypeParam(0.818167f), TypeParam(-0.394078f), TypeParam(0.627076f), TypeParam(-1.124307f), TypeParam(-0.296864f), TypeParam(-0.244061f), TypeParam(-0.423780f), TypeParam(0.504000f), TypeParam(-0.546789f), TypeParam(-0.139085f), TypeParam(-0.346504f), TypeParam(-1.126900f), TypeParam(-0.198169f), TypeParam(-1.016972f), TypeParam(0.699725f), TypeParam(0.641356f), TypeParam(1.124151f), TypeParam(-0.402963f), TypeParam(0.061023f), TypeParam(0.235069f), TypeParam(1.197862f), TypeParam(1.099936f), TypeParam(-0.621047f), TypeParam(-1.021083f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{1.836740f, 0.000000f, -0.418393f, 1.779137f, -0.418393f, 0.000000f, -0.388847f, 0.000000f, -0.939514f, -0.101240f, -0.939514f, 0.000000f, 0.000000f, -0.185200f, 0.000000f, 0.291416f, 0.000000f, 0.000000f, 0.000000f, -0.201041f, 0.000000f, 0.014502f, 0.000000f, 0.000000f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(1.836740f), TypeParam(0.000000f), TypeParam(-0.418393f), TypeParam(1.779137f), TypeParam(-0.418393f), TypeParam(0.000000f), TypeParam(-0.388847f), TypeParam(0.000000f), TypeParam(-0.939514f), TypeParam(-0.101240f), TypeParam(-0.939514f), TypeParam(0.000000f), TypeParam(0.000000f), TypeParam(-0.185200f), TypeParam(0.000000f), TypeParam(0.291416f), TypeParam(0.000000f), TypeParam(0.000000f), TypeParam(0.000000f), TypeParam(-0.201041f), TypeParam(0.000000f), TypeParam(0.014502f), TypeParam(0.000000f), TypeParam(0.000000f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_5D_nearest_zeros_no_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_nearest_zeros_no_align_corners) { OpTester test("GridSample", 20); std::string mode = "nearest"; std::string padding_mode = "zeros"; int64_t align_corners = 0; std::initializer_list X_shape{2, 2, 3, 3, 2}; - std::initializer_list X_data{0.317302f, 0.629807f, -0.470444f, 0.215051f, 2.234212f, -1.940229f, 0.577203f, -0.166697f, -0.023467f, -0.451050f, -2.199999f, 1.469197f, -1.758133f, -0.570410f, -1.040355f, -0.627640f, 1.398573f, 0.275127f, -0.333592f, -0.677762f, -0.247167f, -0.290725f, -0.986956f, 0.173983f, -0.971920f, 0.225261f, -0.626680f, 1.660835f, 0.972993f, 0.223424f, 2.283593f, -1.145964f, -0.851223f, -2.052948f, -1.351783f, -0.028922f, 0.394421f, 0.057878f, -0.668671f, -0.088841f, 0.560186f, -0.105506f, 0.277478f, 1.047901f, -0.564728f, -0.287761f, 0.653621f, 0.259766f, 1.629452f, -2.337903f, -0.276703f, 0.258084f, -0.552200f, -0.464470f, -0.412042f, -1.047346f, 0.169468f, 1.334588f, 0.580615f, 1.217562f, -2.487876f, -1.218598f, -0.256617f, 1.397251f, 0.694875f, 0.732315f, 0.574448f, 0.673838f, -1.870634f, -0.855206f, 1.068415f, 0.096061f}; + std::initializer_list X_data{TypeParam(0.317302f), TypeParam(0.629807f), TypeParam(-0.470444f), TypeParam(0.215051f), TypeParam(2.234212f), TypeParam(-1.940229f), TypeParam(0.577203f), TypeParam(-0.166697f), TypeParam(-0.023467f), TypeParam(-0.451050f), TypeParam(-2.199999f), TypeParam(1.469197f), TypeParam(-1.758133f), TypeParam(-0.570410f), TypeParam(-1.040355f), TypeParam(-0.627640f), TypeParam(1.398573f), TypeParam(0.275127f), TypeParam(-0.333592f), TypeParam(-0.677762f), TypeParam(-0.247167f), TypeParam(-0.290725f), TypeParam(-0.986956f), TypeParam(0.173983f), TypeParam(-0.971920f), TypeParam(0.225261f), TypeParam(-0.626680f), TypeParam(1.660835f), TypeParam(0.972993f), TypeParam(0.223424f), TypeParam(2.283593f), TypeParam(-1.145964f), TypeParam(-0.851223f), TypeParam(-2.052948f), TypeParam(-1.351783f), TypeParam(-0.028922f), TypeParam(0.394421f), TypeParam(0.057878f), TypeParam(-0.668671f), TypeParam(-0.088841f), TypeParam(0.560186f), TypeParam(-0.105506f), TypeParam(0.277478f), TypeParam(1.047901f), TypeParam(-0.564728f), TypeParam(-0.287761f), TypeParam(0.653621f), TypeParam(0.259766f), TypeParam(1.629452f), TypeParam(-2.337903f), TypeParam(-0.276703f), TypeParam(0.258084f), TypeParam(-0.552200f), TypeParam(-0.464470f), TypeParam(-0.412042f), TypeParam(-1.047346f), TypeParam(0.169468f), TypeParam(1.334588f), TypeParam(0.580615f), TypeParam(1.217562f), TypeParam(-2.487876f), TypeParam(-1.218598f), TypeParam(-0.256617f), TypeParam(1.397251f), TypeParam(0.694875f), TypeParam(0.732315f), TypeParam(0.574448f), TypeParam(0.673838f), TypeParam(-1.870634f), TypeParam(-0.855206f), TypeParam(1.068415f), TypeParam(0.096061f)}; std::initializer_list Grid_shape{2, 3, 3, 2, 3}; - std::initializer_list Grid_data{0.650046f, -0.680891f, -0.200337f, -1.006178f, -0.676990f, 0.500592f, -1.118072f, -0.684288f, 0.899676f, -0.615418f, -0.499387f, -0.336929f, 0.512951f, -0.787164f, 0.120318f, 0.490083f, -0.087112f, 0.216982f, -0.915417f, 0.542519f, 0.448475f, -0.150519f, -0.992244f, 0.479971f, 0.783050f, -0.209890f, 0.565605f, 0.444791f, -0.479961f, -0.083304f, 1.194526f, 0.005665f, -0.955336f, -0.087514f, 0.596991f, -0.391708f, -0.628420f, 0.988534f, 0.634814f, -0.203871f, 0.061307f, -0.126915f, 0.278599f, 0.042647f, -0.726162f, 0.222329f, 0.031386f, 0.077584f, -0.457305f, 0.307467f, -0.970375f, 0.358708f, 0.650272f, -0.132064f, -0.932160f, -0.004362f, 0.001704f, -1.037046f, -0.848754f, 1.109926f, 0.897382f, 0.665044f, 0.831311f, 0.461956f, 0.675346f, 0.794786f, -0.280329f, -0.152546f, 0.855656f, -0.000432f, -0.780824f, -0.930479f, 0.671131f, 0.993983f, 0.931935f, 0.199703f, 0.828337f, -1.101760f, -0.864556f, -1.154677f, 0.966824f, -0.010858f, -0.552558f, 0.406048f, -0.449199f, -0.769613f, 0.462838f, 0.219719f, -0.859342f, -0.790394f, 0.562644f, 0.912452f, 0.097688f, -0.602742f, 0.579449f, 0.209287f, -1.050575f, -0.777654f, 0.262652f, 0.742529f, -0.385517f, 0.580240f, -0.743175f, 1.148320f, 0.855053f, 0.224769f, 0.533871f, 0.417788f}; + std::initializer_list Grid_data{TypeParam(0.650046f), TypeParam(-0.680891f), TypeParam(-0.200337f), TypeParam(-1.006178f), TypeParam(-0.676990f), TypeParam(0.500592f), TypeParam(-1.118072f), TypeParam(-0.684288f), TypeParam(0.899676f), TypeParam(-0.615418f), TypeParam(-0.499387f), TypeParam(-0.336929f), TypeParam(0.512951f), TypeParam(-0.787164f), TypeParam(0.120318f), TypeParam(0.490083f), TypeParam(-0.087112f), TypeParam(0.216982f), TypeParam(-0.915417f), TypeParam(0.542519f), TypeParam(0.448475f), TypeParam(-0.150519f), TypeParam(-0.992244f), TypeParam(0.479971f), TypeParam(0.783050f), TypeParam(-0.209890f), TypeParam(0.565605f), TypeParam(0.444791f), TypeParam(-0.479961f), TypeParam(-0.083304f), TypeParam(1.194526f), TypeParam(0.005665f), TypeParam(-0.955336f), TypeParam(-0.087514f), TypeParam(0.596991f), TypeParam(-0.391708f), TypeParam(-0.628420f), TypeParam(0.988534f), TypeParam(0.634814f), TypeParam(-0.203871f), TypeParam(0.061307f), TypeParam(-0.126915f), TypeParam(0.278599f), TypeParam(0.042647f), TypeParam(-0.726162f), TypeParam(0.222329f), TypeParam(0.031386f), TypeParam(0.077584f), TypeParam(-0.457305f), TypeParam(0.307467f), TypeParam(-0.970375f), TypeParam(0.358708f), TypeParam(0.650272f), TypeParam(-0.132064f), TypeParam(-0.932160f), TypeParam(-0.004362f), TypeParam(0.001704f), TypeParam(-1.037046f), TypeParam(-0.848754f), TypeParam(1.109926f), TypeParam(0.897382f), TypeParam(0.665044f), TypeParam(0.831311f), TypeParam(0.461956f), TypeParam(0.675346f), TypeParam(0.794786f), TypeParam(-0.280329f), TypeParam(-0.152546f), TypeParam(0.855656f), TypeParam(-0.000432f), TypeParam(-0.780824f), TypeParam(-0.930479f), TypeParam(0.671131f), TypeParam(0.993983f), TypeParam(0.931935f), TypeParam(0.199703f), TypeParam(0.828337f), TypeParam(-1.101760f), TypeParam(-0.864556f), TypeParam(-1.154677f), TypeParam(0.966824f), TypeParam(-0.010858f), TypeParam(-0.552558f), TypeParam(0.406048f), TypeParam(-0.449199f), TypeParam(-0.769613f), TypeParam(0.462838f), TypeParam(0.219719f), TypeParam(-0.859342f), TypeParam(-0.790394f), TypeParam(0.562644f), TypeParam(0.912452f), TypeParam(0.097688f), TypeParam(-0.602742f), TypeParam(0.579449f), TypeParam(0.209287f), TypeParam(-1.050575f), TypeParam(-0.777654f), TypeParam(0.262652f), TypeParam(0.742529f), TypeParam(-0.385517f), TypeParam(0.580240f), TypeParam(-0.743175f), TypeParam(1.148320f), TypeParam(0.855053f), TypeParam(0.224769f), TypeParam(0.533871f), TypeParam(0.417788f)}; std::initializer_list Y_shape{2, 2, 3, 3, 2}; - std::initializer_list Y_data{-0.166697f, 0.000000f, 0.000000f, 0.317302f, -0.166697f, -0.451050f, 1.398573f, -1.758133f, -0.627640f, -0.166697f, 0.000000f, 2.234212f, 1.398573f, -0.023467f, 0.215051f, -0.451050f, -0.470444f, 1.469197f, 0.225261f, 0.000000f, 0.000000f, -0.333592f, 0.225261f, 1.660835f, -1.351783f, 2.283593f, -2.052948f, 0.225261f, 0.000000f, -0.986956f, -1.351783f, -0.626680f, -0.290725f, 1.660835f, -0.247167f, 0.223424f, -0.564728f, 0.000000f, -0.464470f, -0.464470f, -0.276703f, 0.394421f, -0.464470f, 0.000000f, 0.000000f, 1.629452f, 1.629452f, 0.057878f, 0.259766f, 0.653621f, 0.000000f, -2.337903f, 0.000000f, -0.464470f, -0.256617f, 0.000000f, 0.096061f, 0.096061f, -1.870634f, -0.412042f, 0.096061f, 0.000000f, 0.000000f, 0.574448f, 0.574448f, -1.047346f, 0.732315f, 0.694875f, 0.000000f, 0.673838f, 0.000000f, 0.096061f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-0.166697f), TypeParam(0.000000f), TypeParam(0.000000f), TypeParam(0.317302f), TypeParam(-0.166697f), TypeParam(-0.451050f), TypeParam(1.398573f), TypeParam(-1.758133f), TypeParam(-0.627640f), TypeParam(-0.166697f), TypeParam(0.000000f), TypeParam(2.234212f), TypeParam(1.398573f), TypeParam(-0.023467f), TypeParam(0.215051f), TypeParam(-0.451050f), TypeParam(-0.470444f), TypeParam(1.469197f), TypeParam(0.225261f), TypeParam(0.000000f), TypeParam(0.000000f), TypeParam(-0.333592f), TypeParam(0.225261f), TypeParam(1.660835f), TypeParam(-1.351783f), TypeParam(2.283593f), TypeParam(-2.052948f), TypeParam(0.225261f), TypeParam(0.000000f), TypeParam(-0.986956f), TypeParam(-1.351783f), TypeParam(-0.626680f), TypeParam(-0.290725f), TypeParam(1.660835f), TypeParam(-0.247167f), TypeParam(0.223424f), TypeParam(-0.564728f), TypeParam(0.000000f), TypeParam(-0.464470f), TypeParam(-0.464470f), TypeParam(-0.276703f), TypeParam(0.394421f), TypeParam(-0.464470f), TypeParam(0.000000f), TypeParam(0.000000f), TypeParam(1.629452f), TypeParam(1.629452f), TypeParam(0.057878f), TypeParam(0.259766f), TypeParam(0.653621f), TypeParam(0.000000f), TypeParam(-2.337903f), TypeParam(0.000000f), TypeParam(-0.464470f), TypeParam(-0.256617f), TypeParam(0.000000f), TypeParam(0.096061f), TypeParam(0.096061f), TypeParam(-1.870634f), TypeParam(-0.412042f), TypeParam(0.096061f), TypeParam(0.000000f), TypeParam(0.000000f), TypeParam(0.574448f), TypeParam(0.574448f), TypeParam(-1.047346f), TypeParam(0.732315f), TypeParam(0.694875f), TypeParam(0.000000f), TypeParam(0.673838f), TypeParam(0.000000f), TypeParam(0.096061f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_4D_nearest_border_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_nearest_border_align_corners) { OpTester test("GridSample", 20); std::string mode = "nearest"; std::string padding_mode = "border"; int64_t align_corners = 1; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{0.660065f, 0.995767f, -0.226389f, 0.590604f, -2.628610f, 0.444899f, 0.023282f, 0.024018f, -0.584701f, 1.988638f, -0.023379f, 0.711650f, -1.062933f, -0.064113f, 1.178346f, -0.652373f, 1.259795f, 1.508661f, -0.079368f, 0.819443f, 0.836356f, -0.362184f, -1.153828f, -0.561180f}; + std::initializer_list X_data{TypeParam(0.660065f), TypeParam(0.995767f), TypeParam(-0.226389f), TypeParam(0.590604f), TypeParam(-2.628610f), TypeParam(0.444899f), TypeParam(0.023282f), TypeParam(0.024018f), TypeParam(-0.584701f), TypeParam(1.988638f), TypeParam(-0.023379f), TypeParam(0.711650f), TypeParam(-1.062933f), TypeParam(-0.064113f), TypeParam(1.178346f), TypeParam(-0.652373f), TypeParam(1.259795f), TypeParam(1.508661f), TypeParam(-0.079368f), TypeParam(0.819443f), TypeParam(0.836356f), TypeParam(-0.362184f), TypeParam(-1.153828f), TypeParam(-0.561180f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{-0.447651f, -0.521958f, 0.673539f, 0.222645f, 1.010165f, 0.451903f, 0.966699f, -0.966970f, 0.964714f, -0.551345f, -0.321222f, 0.007182f, -0.225038f, 0.237367f, 1.069316f, -0.716982f, 0.370785f, -0.964445f, 0.188419f, 0.988574f, 0.809140f, 1.027635f, 0.649589f, -0.099282f}; + std::initializer_list Grid_data{TypeParam(-0.447651f), TypeParam(-0.521958f), TypeParam(0.673539f), TypeParam(0.222645f), TypeParam(1.010165f), TypeParam(0.451903f), TypeParam(0.966699f), TypeParam(-0.966970f), TypeParam(0.964714f), TypeParam(-0.551345f), TypeParam(-0.321222f), TypeParam(0.007182f), TypeParam(-0.225038f), TypeParam(0.237367f), TypeParam(1.069316f), TypeParam(-0.716982f), TypeParam(0.370785f), TypeParam(-0.964445f), TypeParam(0.188419f), TypeParam(0.988574f), TypeParam(0.809140f), TypeParam(1.027635f), TypeParam(0.649589f), TypeParam(-0.099282f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{0.660065f, 0.590604f, 0.590604f, 0.995767f, 0.995767f, -0.226389f, 0.023282f, 1.988638f, 1.988638f, 0.024018f, 0.024018f, -0.584701f, 1.178346f, -0.064113f, -0.064113f, 1.508661f, 1.508661f, -0.652373f, 0.836356f, 0.819443f, 0.819443f, -0.561180f, -0.561180f, -0.362184f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(0.660065f), TypeParam(0.590604f), TypeParam(0.590604f), TypeParam(0.995767f), TypeParam(0.995767f), TypeParam(-0.226389f), TypeParam(0.023282f), TypeParam(1.988638f), TypeParam(1.988638f), TypeParam(0.024018f), TypeParam(0.024018f), TypeParam(-0.584701f), TypeParam(1.178346f), TypeParam(-0.064113f), TypeParam(-0.064113f), TypeParam(1.508661f), TypeParam(1.508661f), TypeParam(-0.652373f), TypeParam(0.836356f), TypeParam(0.819443f), TypeParam(0.819443f), TypeParam(-0.561180f), TypeParam(-0.561180f), TypeParam(-0.362184f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_5D_nearest_border_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_nearest_border_align_corners) { OpTester test("GridSample", 20); std::string mode = "nearest"; std::string padding_mode = "border"; int64_t align_corners = 1; std::initializer_list X_shape{2, 2, 3, 3, 2}; - std::initializer_list X_data{-0.920922f, -0.560469f, -2.244605f, -0.061799f, 0.523656f, 0.110097f, -0.944521f, 0.818932f, 1.069286f, 0.611457f, -0.355875f, 1.664810f, 0.116694f, 2.318200f, 0.681699f, -0.792880f, -0.025672f, -0.592222f, 0.229768f, -0.521888f, 0.570937f, -0.029345f, -0.873323f, 1.721509f, 2.011626f, -0.310838f, 1.121670f, 0.778967f, -0.450894f, 1.030269f, 0.166967f, -0.244737f, 0.227200f, -0.416612f, -0.276513f, 0.714623f, 0.908783f, -1.393580f, -0.983675f, -0.366833f, 1.473970f, 0.624368f, -0.607720f, -0.523833f, -0.124702f, -0.766457f, -0.131027f, 2.227047f, 1.399269f, 0.053366f, -0.295771f, -0.283811f, 0.019280f, -0.104450f, -0.574185f, -2.130628f, 0.617878f, -1.728151f, -0.272528f, 1.299354f, -1.109310f, -1.881107f, -1.300843f, -0.765376f, -0.477722f, -1.230664f, -0.495792f, 1.061688f, 1.244247f, -0.550821f, -0.520524f, 1.541448f}; + std::initializer_list X_data{TypeParam(-0.920922f), TypeParam(-0.560469f), TypeParam(-2.244605f), TypeParam(-0.061799f), TypeParam(0.523656f), TypeParam(0.110097f), TypeParam(-0.944521f), TypeParam(0.818932f), TypeParam(1.069286f), TypeParam(0.611457f), TypeParam(-0.355875f), TypeParam(1.664810f), TypeParam(0.116694f), TypeParam(2.318200f), TypeParam(0.681699f), TypeParam(-0.792880f), TypeParam(-0.025672f), TypeParam(-0.592222f), TypeParam(0.229768f), TypeParam(-0.521888f), TypeParam(0.570937f), TypeParam(-0.029345f), TypeParam(-0.873323f), TypeParam(1.721509f), TypeParam(2.011626f), TypeParam(-0.310838f), TypeParam(1.121670f), TypeParam(0.778967f), TypeParam(-0.450894f), TypeParam(1.030269f), TypeParam(0.166967f), TypeParam(-0.244737f), TypeParam(0.227200f), TypeParam(-0.416612f), TypeParam(-0.276513f), TypeParam(0.714623f), TypeParam(0.908783f), TypeParam(-1.393580f), TypeParam(-0.983675f), TypeParam(-0.366833f), TypeParam(1.473970f), TypeParam(0.624368f), TypeParam(-0.607720f), TypeParam(-0.523833f), TypeParam(-0.124702f), TypeParam(-0.766457f), TypeParam(-0.131027f), TypeParam(2.227047f), TypeParam(1.399269f), TypeParam(0.053366f), TypeParam(-0.295771f), TypeParam(-0.283811f), TypeParam(0.019280f), TypeParam(-0.104450f), TypeParam(-0.574185f), TypeParam(-2.130628f), TypeParam(0.617878f), TypeParam(-1.728151f), TypeParam(-0.272528f), TypeParam(1.299354f), TypeParam(-1.109310f), TypeParam(-1.881107f), TypeParam(-1.300843f), TypeParam(-0.765376f), TypeParam(-0.477722f), TypeParam(-1.230664f), TypeParam(-0.495792f), TypeParam(1.061688f), TypeParam(1.244247f), TypeParam(-0.550821f), TypeParam(-0.520524f), TypeParam(1.541448f)}; std::initializer_list Grid_shape{2, 3, 3, 2, 3}; - std::initializer_list Grid_data{-1.189605f, -0.312072f, 0.459409f, 1.033285f, -1.083635f, 0.572921f, -1.138649f, -1.147562f, -0.751493f, -0.158500f, 0.335153f, -0.912613f, 0.924528f, 1.085165f, 0.073832f, 0.976781f, -0.543258f, -0.474714f, -0.154854f, 0.131118f, -0.837104f, -0.960885f, 0.474040f, 0.345992f, 1.173923f, -0.489256f, 0.423768f, -0.484246f, 0.592379f, -0.066474f, 0.889570f, 0.666682f, 0.998817f, 0.616675f, 0.045084f, 1.034127f, -0.704858f, 1.131824f, 1.172625f, 1.146321f, -0.560545f, -0.635830f, 0.075922f, 0.373677f, 0.601953f, 0.488043f, 1.021787f, -0.300648f, -0.393688f, 0.402240f, 0.334401f, -0.699993f, 0.116070f, -0.911100f, -0.352043f, -0.470968f, 1.051900f, -1.080208f, -0.708510f, -1.174356f, 0.302647f, -0.923627f, 0.388249f, -0.833533f, -0.768697f, -0.613051f, 0.180083f, 1.102657f, 1.124055f, -0.090660f, -1.175396f, -0.396450f, -0.457333f, -0.255235f, 0.458506f, 0.603882f, 0.532050f, 0.342802f, -0.485794f, -0.012730f, 0.152721f, -0.612948f, -0.107348f, -0.149795f, -1.133775f, 0.813507f, -0.121323f, -1.037352f, 0.949408f, -0.645689f, 0.424853f, 1.190055f, 0.055551f, 0.345244f, 0.476794f, 0.906949f, -0.368187f, -0.675263f, -0.093908f, 0.938461f, 0.103178f, 0.833774f, -0.008922f, 0.368184f, 0.041727f, 0.032575f, -1.141943f, -1.049081f}; + std::initializer_list Grid_data{TypeParam(-1.189605f), TypeParam(-0.312072f), TypeParam(0.459409f), TypeParam(1.033285f), TypeParam(-1.083635f), TypeParam(0.572921f), TypeParam(-1.138649f), TypeParam(-1.147562f), TypeParam(-0.751493f), TypeParam(-0.158500f), TypeParam(0.335153f), TypeParam(-0.912613f), TypeParam(0.924528f), TypeParam(1.085165f), TypeParam(0.073832f), TypeParam(0.976781f), TypeParam(-0.543258f), TypeParam(-0.474714f), TypeParam(-0.154854f), TypeParam(0.131118f), TypeParam(-0.837104f), TypeParam(-0.960885f), TypeParam(0.474040f), TypeParam(0.345992f), TypeParam(1.173923f), TypeParam(-0.489256f), TypeParam(0.423768f), TypeParam(-0.484246f), TypeParam(0.592379f), TypeParam(-0.066474f), TypeParam(0.889570f), TypeParam(0.666682f), TypeParam(0.998817f), TypeParam(0.616675f), TypeParam(0.045084f), TypeParam(1.034127f), TypeParam(-0.704858f), TypeParam(1.131824f), TypeParam(1.172625f), TypeParam(1.146321f), TypeParam(-0.560545f), TypeParam(-0.635830f), TypeParam(0.075922f), TypeParam(0.373677f), TypeParam(0.601953f), TypeParam(0.488043f), TypeParam(1.021787f), TypeParam(-0.300648f), TypeParam(-0.393688f), TypeParam(0.402240f), TypeParam(0.334401f), TypeParam(-0.699993f), TypeParam(0.116070f), TypeParam(-0.911100f), TypeParam(-0.352043f), TypeParam(-0.470968f), TypeParam(1.051900f), TypeParam(-1.080208f), TypeParam(-0.708510f), TypeParam(-1.174356f), TypeParam(0.302647f), TypeParam(-0.923627f), TypeParam(0.388249f), TypeParam(-0.833533f), TypeParam(-0.768697f), TypeParam(-0.613051f), TypeParam(0.180083f), TypeParam(1.102657f), TypeParam(1.124055f), TypeParam(-0.090660f), TypeParam(-1.175396f), TypeParam(-0.396450f), TypeParam(-0.457333f), TypeParam(-0.255235f), TypeParam(0.458506f), TypeParam(0.603882f), TypeParam(0.532050f), TypeParam(0.342802f), TypeParam(-0.485794f), TypeParam(-0.012730f), TypeParam(0.152721f), TypeParam(-0.612948f), TypeParam(-0.107348f), TypeParam(-0.149795f), TypeParam(-1.133775f), TypeParam(0.813507f), TypeParam(-0.121323f), TypeParam(-1.037352f), TypeParam(0.949408f), TypeParam(-0.645689f), TypeParam(0.424853f), TypeParam(1.190055f), TypeParam(0.055551f), TypeParam(0.345244f), TypeParam(0.476794f), TypeParam(0.906949f), TypeParam(-0.368187f), TypeParam(-0.675263f), TypeParam(-0.093908f), TypeParam(0.938461f), TypeParam(0.103178f), TypeParam(0.833774f), TypeParam(-0.008922f), TypeParam(0.368184f), TypeParam(0.041727f), TypeParam(0.032575f), TypeParam(-1.141943f), TypeParam(-1.049081f)}; std::initializer_list Y_shape{2, 2, 3, 3, 2}; - std::initializer_list Y_data{1.069286f, 2.318200f, -0.920922f, -2.244605f, 1.664810f, 0.818932f, -2.244605f, 1.069286f, 0.611457f, -0.355875f, -0.592222f, -0.792880f, -0.025672f, -0.560469f, -0.792880f, 1.664810f, 1.069286f, -2.244605f, 1.121670f, -0.244737f, 0.229768f, 0.570937f, 1.030269f, -0.310838f, 0.570937f, 1.121670f, 0.778967f, -0.450894f, 0.714623f, -0.416612f, -0.276513f, -0.521888f, -0.416612f, 1.030269f, 1.121670f, 0.570937f, -0.295771f, 0.908783f, -0.523833f, 0.908783f, -0.104450f, -0.607720f, -0.124702f, 2.227047f, -0.124702f, -0.124702f, -0.131027f, 1.473970f, 2.227047f, -0.283811f, -0.607720f, -0.283811f, -0.124702f, -1.393580f, 1.244247f, -0.574185f, -1.881107f, -0.574185f, 1.541448f, -1.109310f, -1.300843f, -1.230664f, -1.300843f, -1.300843f, -0.477722f, -0.272528f, -1.230664f, -0.550821f, -1.109310f, -0.550821f, -1.300843f, -2.130628f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(1.069286f), TypeParam(2.318200f), TypeParam(-0.920922f), TypeParam(-2.244605f), TypeParam(1.664810f), TypeParam(0.818932f), TypeParam(-2.244605f), TypeParam(1.069286f), TypeParam(0.611457f), TypeParam(-0.355875f), TypeParam(-0.592222f), TypeParam(-0.792880f), TypeParam(-0.025672f), TypeParam(-0.560469f), TypeParam(-0.792880f), TypeParam(1.664810f), TypeParam(1.069286f), TypeParam(-2.244605f), TypeParam(1.121670f), TypeParam(-0.244737f), TypeParam(0.229768f), TypeParam(0.570937f), TypeParam(1.030269f), TypeParam(-0.310838f), TypeParam(0.570937f), TypeParam(1.121670f), TypeParam(0.778967f), TypeParam(-0.450894f), TypeParam(0.714623f), TypeParam(-0.416612f), TypeParam(-0.276513f), TypeParam(-0.521888f), TypeParam(-0.416612f), TypeParam(1.030269f), TypeParam(1.121670f), TypeParam(0.570937f), TypeParam(-0.295771f), TypeParam(0.908783f), TypeParam(-0.523833f), TypeParam(0.908783f), TypeParam(-0.104450f), TypeParam(-0.607720f), TypeParam(-0.124702f), TypeParam(2.227047f), TypeParam(-0.124702f), TypeParam(-0.124702f), TypeParam(-0.131027f), TypeParam(1.473970f), TypeParam(2.227047f), TypeParam(-0.283811f), TypeParam(-0.607720f), TypeParam(-0.283811f), TypeParam(-0.124702f), TypeParam(-1.393580f), TypeParam(1.244247f), TypeParam(-0.574185f), TypeParam(-1.881107f), TypeParam(-0.574185f), TypeParam(1.541448f), TypeParam(-1.109310f), TypeParam(-1.300843f), TypeParam(-1.230664f), TypeParam(-1.300843f), TypeParam(-1.300843f), TypeParam(-0.477722f), TypeParam(-0.272528f), TypeParam(-1.230664f), TypeParam(-0.550821f), TypeParam(-1.109310f), TypeParam(-0.550821f), TypeParam(-1.300843f), TypeParam(-2.130628f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_4D_nearest_border_no_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_nearest_border_no_align_corners) { OpTester test("GridSample", 20); std::string mode = "nearest"; std::string padding_mode = "border"; int64_t align_corners = 0; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{0.950589f, -1.656624f, 0.767704f, -0.650720f, -1.404308f, -0.531582f, -0.280854f, 0.344309f, -0.959146f, -0.115645f, 0.515696f, -0.114243f, 1.971614f, 0.274268f, 0.543080f, -1.758563f, 1.771011f, 0.934901f, 0.695798f, 1.905137f, 1.598307f, 1.108385f, 0.156008f, 1.290824f}; + std::initializer_list X_data{TypeParam(0.950589f), TypeParam(-1.656624f), TypeParam(0.767704f), TypeParam(-0.650720f), TypeParam(-1.404308f), TypeParam(-0.531582f), TypeParam(-0.280854f), TypeParam(0.344309f), TypeParam(-0.959146f), TypeParam(-0.115645f), TypeParam(0.515696f), TypeParam(-0.114243f), TypeParam(1.971614f), TypeParam(0.274268f), TypeParam(0.543080f), TypeParam(-1.758563f), TypeParam(1.771011f), TypeParam(0.934901f), TypeParam(0.695798f), TypeParam(1.905137f), TypeParam(1.598307f), TypeParam(1.108385f), TypeParam(0.156008f), TypeParam(1.290824f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{0.482490f, -0.910951f, -0.001676f, -0.442514f, 0.580438f, 1.039346f, -0.159076f, -0.603960f, -0.922037f, -0.705026f, 0.346468f, 0.275332f, 0.646235f, -0.178307f, 0.616600f, -1.069108f, 0.322583f, 1.164952f, -1.187638f, -0.622953f, 0.768203f, -0.187618f, -0.639652f, 0.732078f}; + std::initializer_list Grid_data{TypeParam(0.482490f), TypeParam(-0.910951f), TypeParam(-0.001676f), TypeParam(-0.442514f), TypeParam(0.580438f), TypeParam(1.039346f), TypeParam(-0.159076f), TypeParam(-0.603960f), TypeParam(-0.922037f), TypeParam(-0.705026f), TypeParam(0.346468f), TypeParam(0.275332f), TypeParam(0.646235f), TypeParam(-0.178307f), TypeParam(0.616600f), TypeParam(-1.069108f), TypeParam(0.322583f), TypeParam(1.164952f), TypeParam(-1.187638f), TypeParam(-0.622953f), TypeParam(0.768203f), TypeParam(-0.187618f), TypeParam(-0.639652f), TypeParam(0.732078f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{-1.656624f, 0.950589f, -0.531582f, 0.950589f, 0.950589f, -0.650720f, 0.344309f, -0.280854f, -0.114243f, -0.280854f, -0.280854f, -0.115645f, -1.758563f, 0.274268f, 0.934901f, 1.971614f, -1.758563f, 1.771011f, 1.108385f, 1.905137f, 1.290824f, 0.695798f, 1.108385f, 0.156008f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-1.656624f), TypeParam(0.950589f), TypeParam(-0.531582f), TypeParam(0.950589f), TypeParam(0.950589f), TypeParam(-0.650720f), TypeParam(0.344309f), TypeParam(-0.280854f), TypeParam(-0.114243f), TypeParam(-0.280854f), TypeParam(-0.280854f), TypeParam(-0.115645f), TypeParam(-1.758563f), TypeParam(0.274268f), TypeParam(0.934901f), TypeParam(1.971614f), TypeParam(-1.758563f), TypeParam(1.771011f), TypeParam(1.108385f), TypeParam(1.905137f), TypeParam(1.290824f), TypeParam(0.695798f), TypeParam(1.108385f), TypeParam(0.156008f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_5D_nearest_border_no_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_nearest_border_no_align_corners) { OpTester test("GridSample", 20); std::string mode = "nearest"; std::string padding_mode = "border"; int64_t align_corners = 0; std::initializer_list X_shape{2, 2, 3, 3, 2}; - std::initializer_list X_data{0.465448f, -0.337086f, -0.870849f, -0.389573f, -0.083941f, 1.306894f, 0.719508f, -0.203690f, -1.143864f, 1.163003f, 0.312170f, -2.008687f, 1.731257f, -0.270431f, 1.095352f, -1.673520f, 0.492743f, 0.521962f, -1.938783f, -0.186813f, -0.836257f, -1.835450f, 0.476500f, -0.123386f, 0.246604f, 1.374159f, -0.158435f, 1.268192f, -0.704226f, -0.195314f, -0.277259f, 0.582961f, -0.340940f, 0.192264f, 0.463124f, -2.719402f, -0.593470f, -1.165777f, 0.566071f, 1.622836f, -0.886798f, 1.874877f, -0.849095f, 0.550185f, 0.604298f, 0.073976f, -0.800372f, -0.097283f, -1.576251f, -0.633278f, -1.776745f, -0.827586f, 0.665697f, 0.884698f, 0.467112f, -0.645219f, -0.510110f, 0.032418f, -1.056009f, -0.206175f, -0.173385f, 0.947787f, 1.937234f, 0.615880f, -0.311580f, 0.770921f, -0.841602f, 1.796220f, 0.479491f, 1.609346f, 1.113868f, -0.453360f}; + std::initializer_list X_data{TypeParam(0.465448f), TypeParam(-0.337086f), TypeParam(-0.870849f), TypeParam(-0.389573f), TypeParam(-0.083941f), TypeParam(1.306894f), TypeParam(0.719508f), TypeParam(-0.203690f), TypeParam(-1.143864f), TypeParam(1.163003f), TypeParam(0.312170f), TypeParam(-2.008687f), TypeParam(1.731257f), TypeParam(-0.270431f), TypeParam(1.095352f), TypeParam(-1.673520f), TypeParam(0.492743f), TypeParam(0.521962f), TypeParam(-1.938783f), TypeParam(-0.186813f), TypeParam(-0.836257f), TypeParam(-1.835450f), TypeParam(0.476500f), TypeParam(-0.123386f), TypeParam(0.246604f), TypeParam(1.374159f), TypeParam(-0.158435f), TypeParam(1.268192f), TypeParam(-0.704226f), TypeParam(-0.195314f), TypeParam(-0.277259f), TypeParam(0.582961f), TypeParam(-0.340940f), TypeParam(0.192264f), TypeParam(0.463124f), TypeParam(-2.719402f), TypeParam(-0.593470f), TypeParam(-1.165777f), TypeParam(0.566071f), TypeParam(1.622836f), TypeParam(-0.886798f), TypeParam(1.874877f), TypeParam(-0.849095f), TypeParam(0.550185f), TypeParam(0.604298f), TypeParam(0.073976f), TypeParam(-0.800372f), TypeParam(-0.097283f), TypeParam(-1.576251f), TypeParam(-0.633278f), TypeParam(-1.776745f), TypeParam(-0.827586f), TypeParam(0.665697f), TypeParam(0.884698f), TypeParam(0.467112f), TypeParam(-0.645219f), TypeParam(-0.510110f), TypeParam(0.032418f), TypeParam(-1.056009f), TypeParam(-0.206175f), TypeParam(-0.173385f), TypeParam(0.947787f), TypeParam(1.937234f), TypeParam(0.615880f), TypeParam(-0.311580f), TypeParam(0.770921f), TypeParam(-0.841602f), TypeParam(1.796220f), TypeParam(0.479491f), TypeParam(1.609346f), TypeParam(1.113868f), TypeParam(-0.453360f)}; std::initializer_list Grid_shape{2, 3, 3, 2, 3}; - std::initializer_list Grid_data{-0.151540f, -0.033291f, -0.597203f, 0.836404f, -0.686848f, -0.485355f, -0.936738f, -1.009057f, 1.065352f, -0.926635f, -0.165670f, -0.347352f, 0.439545f, 0.320963f, -0.919909f, 1.077689f, -1.195359f, 0.118687f, -0.100253f, -0.278089f, 0.817760f, 1.013180f, 0.156316f, -0.423839f, 0.892139f, 0.753924f, 0.215530f, -0.328214f, 0.050592f, 1.069553f, 0.130134f, -0.236478f, -1.015986f, -0.643059f, 0.866682f, -0.042256f, -0.079912f, 0.467233f, -0.789513f, -0.081063f, -0.337505f, 0.627865f, 0.976589f, 0.753489f, 0.894667f, -1.072442f, -0.426020f, 0.142099f, -1.019226f, 0.325527f, -0.786578f, 0.514215f, 0.971223f, -1.026539f, 1.005531f, 0.559922f, -0.791906f, 1.148613f, -1.039306f, -0.807864f, -0.596935f, -0.060766f, 0.215484f, -0.352165f, -1.137417f, -0.138518f, 0.910459f, 0.923925f, 0.600710f, 0.174227f, 0.298169f, -0.925092f, 0.485927f, -1.194283f, -0.495564f, -0.315357f, 0.881199f, -0.034981f, -0.546611f, 0.209651f, -0.995724f, -0.317709f, 0.332343f, -0.079474f, -0.126024f, 0.733410f, -0.911554f, -0.605911f, 1.161566f, 0.238787f, -0.194293f, 0.621583f, 0.721901f, -0.200521f, -0.499850f, -0.196149f, 0.435730f, -0.153196f, 0.698401f, -0.978582f, -0.588758f, 0.914808f, 0.157427f, 0.241646f, 0.394674f, -0.283552f, -0.479889f, 0.344261f}; + std::initializer_list Grid_data{TypeParam(-0.151540f), TypeParam(-0.033291f), TypeParam(-0.597203f), TypeParam(0.836404f), TypeParam(-0.686848f), TypeParam(-0.485355f), TypeParam(-0.936738f), TypeParam(-1.009057f), TypeParam(1.065352f), TypeParam(-0.926635f), TypeParam(-0.165670f), TypeParam(-0.347352f), TypeParam(0.439545f), TypeParam(0.320963f), TypeParam(-0.919909f), TypeParam(1.077689f), TypeParam(-1.195359f), TypeParam(0.118687f), TypeParam(-0.100253f), TypeParam(-0.278089f), TypeParam(0.817760f), TypeParam(1.013180f), TypeParam(0.156316f), TypeParam(-0.423839f), TypeParam(0.892139f), TypeParam(0.753924f), TypeParam(0.215530f), TypeParam(-0.328214f), TypeParam(0.050592f), TypeParam(1.069553f), TypeParam(0.130134f), TypeParam(-0.236478f), TypeParam(-1.015986f), TypeParam(-0.643059f), TypeParam(0.866682f), TypeParam(-0.042256f), TypeParam(-0.079912f), TypeParam(0.467233f), TypeParam(-0.789513f), TypeParam(-0.081063f), TypeParam(-0.337505f), TypeParam(0.627865f), TypeParam(0.976589f), TypeParam(0.753489f), TypeParam(0.894667f), TypeParam(-1.072442f), TypeParam(-0.426020f), TypeParam(0.142099f), TypeParam(-1.019226f), TypeParam(0.325527f), TypeParam(-0.786578f), TypeParam(0.514215f), TypeParam(0.971223f), TypeParam(-1.026539f), TypeParam(1.005531f), TypeParam(0.559922f), TypeParam(-0.791906f), TypeParam(1.148613f), TypeParam(-1.039306f), TypeParam(-0.807864f), TypeParam(-0.596935f), TypeParam(-0.060766f), TypeParam(0.215484f), TypeParam(-0.352165f), TypeParam(-1.137417f), TypeParam(-0.138518f), TypeParam(0.910459f), TypeParam(0.923925f), TypeParam(0.600710f), TypeParam(0.174227f), TypeParam(0.298169f), TypeParam(-0.925092f), TypeParam(0.485927f), TypeParam(-1.194283f), TypeParam(-0.495564f), TypeParam(-0.315357f), TypeParam(0.881199f), TypeParam(-0.034981f), TypeParam(-0.546611f), TypeParam(0.209651f), TypeParam(-0.995724f), TypeParam(-0.317709f), TypeParam(0.332343f), TypeParam(-0.079474f), TypeParam(-0.126024f), TypeParam(0.733410f), TypeParam(-0.911554f), TypeParam(-0.605911f), TypeParam(1.161566f), TypeParam(0.238787f), TypeParam(-0.194293f), TypeParam(0.621583f), TypeParam(0.721901f), TypeParam(-0.200521f), TypeParam(-0.499850f), TypeParam(-0.196149f), TypeParam(0.435730f), TypeParam(-0.153196f), TypeParam(0.698401f), TypeParam(-0.978582f), TypeParam(-0.588758f), TypeParam(0.914808f), TypeParam(0.157427f), TypeParam(0.241646f), TypeParam(0.394674f), TypeParam(-0.283552f), TypeParam(-0.479889f), TypeParam(0.344261f)}; std::initializer_list Y_shape{2, 2, 3, 3, 2}; - std::initializer_list Y_data{-0.870849f, -0.337086f, 1.731257f, -0.870849f, -0.389573f, -0.203690f, 1.095352f, -0.389573f, -2.008687f, 1.095352f, -0.389573f, 0.312170f, -0.083941f, 1.731257f, 0.521962f, 0.719508f, -0.870849f, 1.306894f, -0.836257f, -0.186813f, -0.277259f, -0.836257f, -1.835450f, 1.374159f, -0.340940f, -1.835450f, -0.195314f, -0.340940f, -1.835450f, -0.704226f, 0.476500f, -0.277259f, -2.719402f, 0.246604f, -0.836257f, -0.123386f, 1.874877f, -1.165777f, 0.604298f, -0.849095f, 0.884698f, 1.622836f, -1.165777f, -0.800372f, 0.566071f, 0.604298f, -0.886798f, -0.800372f, 0.665697f, -0.849095f, -0.827586f, -1.576251f, -0.827586f, -1.576251f, -0.206175f, -0.645219f, 1.937234f, -0.173385f, -0.453360f, 0.032418f, -0.645219f, -0.311580f, -0.510110f, 1.937234f, -1.056009f, -0.311580f, 1.113868f, -0.173385f, 1.609346f, -0.841602f, 1.609346f, -0.841602f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-0.870849f), TypeParam(-0.337086f), TypeParam(1.731257f), TypeParam(-0.870849f), TypeParam(-0.389573f), TypeParam(-0.203690f), TypeParam(1.095352f), TypeParam(-0.389573f), TypeParam(-2.008687f), TypeParam(1.095352f), TypeParam(-0.389573f), TypeParam(0.312170f), TypeParam(-0.083941f), TypeParam(1.731257f), TypeParam(0.521962f), TypeParam(0.719508f), TypeParam(-0.870849f), TypeParam(1.306894f), TypeParam(-0.836257f), TypeParam(-0.186813f), TypeParam(-0.277259f), TypeParam(-0.836257f), TypeParam(-1.835450f), TypeParam(1.374159f), TypeParam(-0.340940f), TypeParam(-1.835450f), TypeParam(-0.195314f), TypeParam(-0.340940f), TypeParam(-1.835450f), TypeParam(-0.704226f), TypeParam(0.476500f), TypeParam(-0.277259f), TypeParam(-2.719402f), TypeParam(0.246604f), TypeParam(-0.836257f), TypeParam(-0.123386f), TypeParam(1.874877f), TypeParam(-1.165777f), TypeParam(0.604298f), TypeParam(-0.849095f), TypeParam(0.884698f), TypeParam(1.622836f), TypeParam(-1.165777f), TypeParam(-0.800372f), TypeParam(0.566071f), TypeParam(0.604298f), TypeParam(-0.886798f), TypeParam(-0.800372f), TypeParam(0.665697f), TypeParam(-0.849095f), TypeParam(-0.827586f), TypeParam(-1.576251f), TypeParam(-0.827586f), TypeParam(-1.576251f), TypeParam(-0.206175f), TypeParam(-0.645219f), TypeParam(1.937234f), TypeParam(-0.173385f), TypeParam(-0.453360f), TypeParam(0.032418f), TypeParam(-0.645219f), TypeParam(-0.311580f), TypeParam(-0.510110f), TypeParam(1.937234f), TypeParam(-1.056009f), TypeParam(-0.311580f), TypeParam(1.113868f), TypeParam(-0.173385f), TypeParam(1.609346f), TypeParam(-0.841602f), TypeParam(1.609346f), TypeParam(-0.841602f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_4D_nearest_reflection_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_nearest_reflection_align_corners) { OpTester test("GridSample", 20); std::string mode = "nearest"; std::string padding_mode = "reflection"; int64_t align_corners = 1; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{0.079043f, 0.407494f, 1.038992f, -0.437542f, 0.991216f, 0.409636f, 1.050403f, -0.687172f, -2.021689f, 0.789633f, 0.538178f, 0.414847f, 2.221617f, -0.254833f, -0.179968f, -0.952356f, -1.213159f, 0.499103f, -0.374865f, 0.441938f, -0.114847f, 0.716887f, 1.059090f, 0.438870f}; + std::initializer_list X_data{TypeParam(0.079043f), TypeParam(0.407494f), TypeParam(1.038992f), TypeParam(-0.437542f), TypeParam(0.991216f), TypeParam(0.409636f), TypeParam(1.050403f), TypeParam(-0.687172f), TypeParam(-2.021689f), TypeParam(0.789633f), TypeParam(0.538178f), TypeParam(0.414847f), TypeParam(2.221617f), TypeParam(-0.254833f), TypeParam(-0.179968f), TypeParam(-0.952356f), TypeParam(-1.213159f), TypeParam(0.499103f), TypeParam(-0.374865f), TypeParam(0.441938f), TypeParam(-0.114847f), TypeParam(0.716887f), TypeParam(1.059090f), TypeParam(0.438870f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{0.355147f, -0.222342f, -1.197658f, 0.844060f, 1.188586f, 0.605435f, 1.174232f, 0.327060f, -0.094032f, -0.955794f, -1.048806f, -0.826196f, -0.304468f, 0.698768f, -0.495101f, -0.046607f, -0.016936f, -0.784415f, -0.032484f, 1.158664f, 0.959105f, 0.913943f, -0.118352f, 0.021282f}; + std::initializer_list Grid_data{TypeParam(0.355147f), TypeParam(-0.222342f), TypeParam(-1.197658f), TypeParam(0.844060f), TypeParam(1.188586f), TypeParam(0.605435f), TypeParam(1.174232f), TypeParam(0.327060f), TypeParam(-0.094032f), TypeParam(-0.955794f), TypeParam(-1.048806f), TypeParam(-0.826196f), TypeParam(-0.304468f), TypeParam(0.698768f), TypeParam(-0.495101f), TypeParam(-0.046607f), TypeParam(-0.016936f), TypeParam(-0.784415f), TypeParam(-0.032484f), TypeParam(1.158664f), TypeParam(0.959105f), TypeParam(0.913943f), TypeParam(-0.118352f), TypeParam(0.021282f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{-0.437542f, 0.991216f, 0.409636f, -0.437542f, 0.079043f, 0.079043f, 0.789633f, 0.538178f, 0.414847f, 0.789633f, 1.050403f, 1.050403f, -1.213159f, -0.179968f, 2.221617f, -1.213159f, 0.499103f, -0.179968f, 1.059090f, -0.114847f, -0.374865f, 1.059090f, 0.438870f, -0.114847f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-0.437542f), TypeParam(0.991216f), TypeParam(0.409636f), TypeParam(-0.437542f), TypeParam(0.079043f), TypeParam(0.079043f), TypeParam(0.789633f), TypeParam(0.538178f), TypeParam(0.414847f), TypeParam(0.789633f), TypeParam(1.050403f), TypeParam(1.050403f), TypeParam(-1.213159f), TypeParam(-0.179968f), TypeParam(2.221617f), TypeParam(-1.213159f), TypeParam(0.499103f), TypeParam(-0.179968f), TypeParam(1.059090f), TypeParam(-0.114847f), TypeParam(-0.374865f), TypeParam(1.059090f), TypeParam(0.438870f), TypeParam(-0.114847f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_5D_nearest_reflection_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_nearest_reflection_align_corners) { OpTester test("GridSample", 20); std::string mode = "nearest"; std::string padding_mode = "reflection"; int64_t align_corners = 1; std::initializer_list X_shape{2, 2, 3, 3, 2}; - std::initializer_list X_data{0.189379f, 0.825309f, -0.701365f, 0.787800f, -1.102514f, 0.126954f, 1.824453f, -0.144635f, -1.712534f, 0.361739f, -0.462516f, -2.153102f, 0.536963f, 0.581639f, -1.325014f, -1.314673f, -0.524797f, -1.304159f, -1.093757f, -1.703444f, -0.672976f, 0.505303f, 1.497654f, -0.545441f, -1.334648f, 0.474489f, 0.484384f, 0.434399f, -0.733471f, 0.452991f, 0.324606f, -1.307459f, -0.640603f, -0.450100f, 0.772854f, 1.281813f, -0.481714f, 1.224667f, -0.437546f, 0.371986f, -0.320368f, -1.011020f, -1.199298f, 0.213302f, 1.795444f, 0.409271f, 1.328065f, -1.037527f, 0.224494f, 0.217863f, -0.925740f, 0.344755f, -1.445667f, -0.935542f, -0.427280f, -2.010803f, -1.174929f, 1.434105f, -1.168630f, 0.321896f, -0.561974f, -0.209305f, -1.063838f, 1.451708f, 0.266913f, -0.132535f, 0.798299f, 0.619547f, -0.324459f, 0.255630f, 0.488773f, -0.142060f}; + std::initializer_list X_data{TypeParam(0.189379f), TypeParam(0.825309f), TypeParam(-0.701365f), TypeParam(0.787800f), TypeParam(-1.102514f), TypeParam(0.126954f), TypeParam(1.824453f), TypeParam(-0.144635f), TypeParam(-1.712534f), TypeParam(0.361739f), TypeParam(-0.462516f), TypeParam(-2.153102f), TypeParam(0.536963f), TypeParam(0.581639f), TypeParam(-1.325014f), TypeParam(-1.314673f), TypeParam(-0.524797f), TypeParam(-1.304159f), TypeParam(-1.093757f), TypeParam(-1.703444f), TypeParam(-0.672976f), TypeParam(0.505303f), TypeParam(1.497654f), TypeParam(-0.545441f), TypeParam(-1.334648f), TypeParam(0.474489f), TypeParam(0.484384f), TypeParam(0.434399f), TypeParam(-0.733471f), TypeParam(0.452991f), TypeParam(0.324606f), TypeParam(-1.307459f), TypeParam(-0.640603f), TypeParam(-0.450100f), TypeParam(0.772854f), TypeParam(1.281813f), TypeParam(-0.481714f), TypeParam(1.224667f), TypeParam(-0.437546f), TypeParam(0.371986f), TypeParam(-0.320368f), TypeParam(-1.011020f), TypeParam(-1.199298f), TypeParam(0.213302f), TypeParam(1.795444f), TypeParam(0.409271f), TypeParam(1.328065f), TypeParam(-1.037527f), TypeParam(0.224494f), TypeParam(0.217863f), TypeParam(-0.925740f), TypeParam(0.344755f), TypeParam(-1.445667f), TypeParam(-0.935542f), TypeParam(-0.427280f), TypeParam(-2.010803f), TypeParam(-1.174929f), TypeParam(1.434105f), TypeParam(-1.168630f), TypeParam(0.321896f), TypeParam(-0.561974f), TypeParam(-0.209305f), TypeParam(-1.063838f), TypeParam(1.451708f), TypeParam(0.266913f), TypeParam(-0.132535f), TypeParam(0.798299f), TypeParam(0.619547f), TypeParam(-0.324459f), TypeParam(0.255630f), TypeParam(0.488773f), TypeParam(-0.142060f)}; std::initializer_list Grid_shape{2, 3, 3, 2, 3}; - std::initializer_list Grid_data{-0.034431f, 1.048250f, 0.160255f, -0.446426f, 0.879791f, -0.683555f, 0.039704f, 0.269729f, 0.538601f, -1.107191f, 0.058867f, -0.310704f, 0.778040f, 0.403733f, 0.480956f, 0.721512f, -0.268657f, -0.076883f, 0.962704f, -0.967187f, -0.829464f, 0.087786f, -0.475353f, 0.068725f, 1.060032f, -0.139108f, -1.023162f, -0.545493f, 1.102040f, -0.263627f, -0.526173f, 0.540152f, 0.148556f, -1.058015f, 0.999344f, 0.675750f, 1.043022f, 0.525119f, -0.404585f, -0.391737f, 0.581547f, -0.232625f, 0.235264f, -1.162786f, -0.593187f, 0.445737f, -0.059159f, -0.576901f, -1.046721f, 0.762672f, -0.241271f, -1.179040f, 1.157741f, 0.583952f, -0.717767f, -0.875798f, 1.159575f, 0.005010f, -0.721707f, 0.690536f, -0.249959f, 0.082204f, -0.625120f, -1.016394f, -0.796947f, -0.131764f, -0.868737f, 1.182731f, 0.012988f, -0.459398f, 0.474264f, -1.063883f, -0.613791f, 0.450721f, -1.019595f, 0.598084f, 0.100866f, -1.000569f, -1.190919f, 0.379261f, 0.567202f, -0.239888f, -1.061107f, -0.691616f, 0.127540f, 0.043657f, 0.307172f, 0.212184f, -0.062900f, 0.633272f, 1.164016f, 0.999377f, 1.090411f, -0.405004f, -0.409578f, -0.132722f, 0.354671f, 0.485734f, -0.106963f, -0.775112f, -0.905400f, 1.155262f, -0.322627f, -0.162203f, -0.735432f, -0.594912f, 0.263568f, 0.505424f}; + std::initializer_list Grid_data{TypeParam(-0.034431f), TypeParam(1.048250f), TypeParam(0.160255f), TypeParam(-0.446426f), TypeParam(0.879791f), TypeParam(-0.683555f), TypeParam(0.039704f), TypeParam(0.269729f), TypeParam(0.538601f), TypeParam(-1.107191f), TypeParam(0.058867f), TypeParam(-0.310704f), TypeParam(0.778040f), TypeParam(0.403733f), TypeParam(0.480956f), TypeParam(0.721512f), TypeParam(-0.268657f), TypeParam(-0.076883f), TypeParam(0.962704f), TypeParam(-0.967187f), TypeParam(-0.829464f), TypeParam(0.087786f), TypeParam(-0.475353f), TypeParam(0.068725f), TypeParam(1.060032f), TypeParam(-0.139108f), TypeParam(-1.023162f), TypeParam(-0.545493f), TypeParam(1.102040f), TypeParam(-0.263627f), TypeParam(-0.526173f), TypeParam(0.540152f), TypeParam(0.148556f), TypeParam(-1.058015f), TypeParam(0.999344f), TypeParam(0.675750f), TypeParam(1.043022f), TypeParam(0.525119f), TypeParam(-0.404585f), TypeParam(-0.391737f), TypeParam(0.581547f), TypeParam(-0.232625f), TypeParam(0.235264f), TypeParam(-1.162786f), TypeParam(-0.593187f), TypeParam(0.445737f), TypeParam(-0.059159f), TypeParam(-0.576901f), TypeParam(-1.046721f), TypeParam(0.762672f), TypeParam(-0.241271f), TypeParam(-1.179040f), TypeParam(1.157741f), TypeParam(0.583952f), TypeParam(-0.717767f), TypeParam(-0.875798f), TypeParam(1.159575f), TypeParam(0.005010f), TypeParam(-0.721707f), TypeParam(0.690536f), TypeParam(-0.249959f), TypeParam(0.082204f), TypeParam(-0.625120f), TypeParam(-1.016394f), TypeParam(-0.796947f), TypeParam(-0.131764f), TypeParam(-0.868737f), TypeParam(1.182731f), TypeParam(0.012988f), TypeParam(-0.459398f), TypeParam(0.474264f), TypeParam(-1.063883f), TypeParam(-0.613791f), TypeParam(0.450721f), TypeParam(-1.019595f), TypeParam(0.598084f), TypeParam(0.100866f), TypeParam(-1.000569f), TypeParam(-1.190919f), TypeParam(0.379261f), TypeParam(0.567202f), TypeParam(-0.239888f), TypeParam(-1.061107f), TypeParam(-0.691616f), TypeParam(0.127540f), TypeParam(0.043657f), TypeParam(0.307172f), TypeParam(0.212184f), TypeParam(-0.062900f), TypeParam(0.633272f), TypeParam(1.164016f), TypeParam(0.999377f), TypeParam(1.090411f), TypeParam(-0.405004f), TypeParam(-0.409578f), TypeParam(-0.132722f), TypeParam(0.354671f), TypeParam(0.485734f), TypeParam(-0.106963f), TypeParam(-0.775112f), TypeParam(-0.905400f), TypeParam(1.155262f), TypeParam(-0.322627f), TypeParam(-0.162203f), TypeParam(-0.735432f), TypeParam(-0.594912f), TypeParam(0.263568f), TypeParam(0.505424f)}; std::initializer_list Y_shape{2, 2, 3, 3, 2}; - std::initializer_list Y_data{-0.462516f, -1.102514f, -1.314673f, -1.712534f, 0.361739f, 0.361739f, 0.825309f, 0.361739f, 0.787800f, -0.462516f, -0.462516f, -0.524797f, -2.153102f, -0.462516f, 0.825309f, 0.787800f, -0.462516f, -0.524797f, -0.733471f, 1.497654f, -0.450100f, 0.484384f, 0.434399f, 0.434399f, -1.703444f, 0.434399f, 0.505303f, -0.733471f, -0.733471f, 0.772854f, 0.452991f, -0.733471f, -1.703444f, 0.505303f, -0.733471f, 0.772854f, 0.224494f, 0.217863f, -0.437546f, -1.199298f, 1.328065f, -0.437546f, -0.437546f, 0.371986f, -0.925740f, -0.481714f, 0.409271f, 0.344755f, -0.935542f, 1.795444f, 0.409271f, 0.224494f, -0.437546f, -0.925740f, 0.798299f, 0.619547f, -1.174929f, -0.561974f, 0.266913f, -1.174929f, -1.174929f, 1.434105f, -0.324459f, -0.427280f, 1.451708f, 0.255630f, -0.142060f, -1.063838f, 1.451708f, 0.798299f, -1.174929f, -0.324459f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-0.462516f), TypeParam(-1.102514f), TypeParam(-1.314673f), TypeParam(-1.712534f), TypeParam(0.361739f), TypeParam(0.361739f), TypeParam(0.825309f), TypeParam(0.361739f), TypeParam(0.787800f), TypeParam(-0.462516f), TypeParam(-0.462516f), TypeParam(-0.524797f), TypeParam(-2.153102f), TypeParam(-0.462516f), TypeParam(0.825309f), TypeParam(0.787800f), TypeParam(-0.462516f), TypeParam(-0.524797f), TypeParam(-0.733471f), TypeParam(1.497654f), TypeParam(-0.450100f), TypeParam(0.484384f), TypeParam(0.434399f), TypeParam(0.434399f), TypeParam(-1.703444f), TypeParam(0.434399f), TypeParam(0.505303f), TypeParam(-0.733471f), TypeParam(-0.733471f), TypeParam(0.772854f), TypeParam(0.452991f), TypeParam(-0.733471f), TypeParam(-1.703444f), TypeParam(0.505303f), TypeParam(-0.733471f), TypeParam(0.772854f), TypeParam(0.224494f), TypeParam(0.217863f), TypeParam(-0.437546f), TypeParam(-1.199298f), TypeParam(1.328065f), TypeParam(-0.437546f), TypeParam(-0.437546f), TypeParam(0.371986f), TypeParam(-0.925740f), TypeParam(-0.481714f), TypeParam(0.409271f), TypeParam(0.344755f), TypeParam(-0.935542f), TypeParam(1.795444f), TypeParam(0.409271f), TypeParam(0.224494f), TypeParam(-0.437546f), TypeParam(-0.925740f), TypeParam(0.798299f), TypeParam(0.619547f), TypeParam(-1.174929f), TypeParam(-0.561974f), TypeParam(0.266913f), TypeParam(-1.174929f), TypeParam(-1.174929f), TypeParam(1.434105f), TypeParam(-0.324459f), TypeParam(-0.427280f), TypeParam(1.451708f), TypeParam(0.255630f), TypeParam(-0.142060f), TypeParam(-1.063838f), TypeParam(1.451708f), TypeParam(0.798299f), TypeParam(-1.174929f), TypeParam(-0.324459f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_4D_nearest_reflection_no_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_nearest_reflection_no_align_corners) { OpTester test("GridSample", 20); std::string mode = "nearest"; std::string padding_mode = "reflection"; int64_t align_corners = 0; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{-0.769854f, -0.805659f, 0.813652f, -0.010183f, 0.276463f, -0.771678f, -2.563015f, -1.243904f, 2.365071f, 0.730651f, -0.068795f, -1.495438f, 0.211578f, -1.042373f, 0.884036f, -0.746288f, 1.011368f, 0.194463f, -0.307214f, 0.556053f, 0.629364f, 0.083601f, 0.248627f, -0.822453f}; + std::initializer_list X_data{TypeParam(-0.769854f), TypeParam(-0.805659f), TypeParam(0.813652f), TypeParam(-0.010183f), TypeParam(0.276463f), TypeParam(-0.771678f), TypeParam(-2.563015f), TypeParam(-1.243904f), TypeParam(2.365071f), TypeParam(0.730651f), TypeParam(-0.068795f), TypeParam(-1.495438f), TypeParam(0.211578f), TypeParam(-1.042373f), TypeParam(0.884036f), TypeParam(-0.746288f), TypeParam(1.011368f), TypeParam(0.194463f), TypeParam(-0.307214f), TypeParam(0.556053f), TypeParam(0.629364f), TypeParam(0.083601f), TypeParam(0.248627f), TypeParam(-0.822453f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{0.569884f, 1.163780f, -0.977608f, -0.145509f, 0.651234f, 1.099753f, -0.853766f, 0.509955f, 0.495437f, 0.723445f, -0.827299f, 0.856340f, -0.522676f, -0.738659f, 0.238269f, 1.016568f, -0.794666f, 0.640690f, -0.137431f, 0.383085f, 0.936085f, 0.325824f, -0.996188f, -0.361291f}; + std::initializer_list Grid_data{TypeParam(0.569884f), TypeParam(1.163780f), TypeParam(-0.977608f), TypeParam(-0.145509f), TypeParam(0.651234f), TypeParam(1.099753f), TypeParam(-0.853766f), TypeParam(0.509955f), TypeParam(0.495437f), TypeParam(0.723445f), TypeParam(-0.827299f), TypeParam(0.856340f), TypeParam(-0.522676f), TypeParam(-0.738659f), TypeParam(0.238269f), TypeParam(1.016568f), TypeParam(-0.794666f), TypeParam(0.640690f), TypeParam(-0.137431f), TypeParam(0.383085f), TypeParam(0.936085f), TypeParam(0.325824f), TypeParam(-0.996188f), TypeParam(-0.361291f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{-0.771678f, 0.813652f, -0.771678f, 0.276463f, -0.771678f, 0.276463f, -1.495438f, 2.365071f, -1.495438f, -0.068795f, -1.495438f, -0.068795f, 0.211578f, 0.194463f, 1.011368f, 1.011368f, -0.746288f, 0.211578f, -0.307214f, -0.822453f, 0.248627f, 0.248627f, 0.083601f, -0.307214f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-0.771678f), TypeParam(0.813652f), TypeParam(-0.771678f), TypeParam(0.276463f), TypeParam(-0.771678f), TypeParam(0.276463f), TypeParam(-1.495438f), TypeParam(2.365071f), TypeParam(-1.495438f), TypeParam(-0.068795f), TypeParam(-1.495438f), TypeParam(-0.068795f), TypeParam(0.211578f), TypeParam(0.194463f), TypeParam(1.011368f), TypeParam(1.011368f), TypeParam(-0.746288f), TypeParam(0.211578f), TypeParam(-0.307214f), TypeParam(-0.822453f), TypeParam(0.248627f), TypeParam(0.248627f), TypeParam(0.083601f), TypeParam(-0.307214f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_5D_nearest_reflection_no_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_nearest_reflection_no_align_corners) { OpTester test("GridSample", 20); std::string mode = "nearest"; std::string padding_mode = "reflection"; int64_t align_corners = 0; std::initializer_list X_shape{2, 2, 3, 3, 2}; - std::initializer_list X_data{-0.185898f, 0.403325f, 0.737314f, 0.545995f, -1.010481f, -1.204522f, -0.147342f, 0.232425f, -1.339485f, 0.013892f, -1.098319f, 0.478079f, 0.051159f, -0.906061f, -0.428560f, 0.583460f, 1.137472f, 1.487881f, 1.349931f, -0.118774f, 0.436410f, 1.334689f, -1.115846f, 0.159820f, 0.617671f, 0.546630f, 1.861115f, 0.500044f, 0.623446f, 0.541840f, -0.279259f, -0.573875f, 0.783115f, -1.125017f, -1.166457f, -0.827232f, 0.273074f, 0.702953f, 1.288608f, -1.037043f, 0.021860f, 0.575628f, -0.034170f, 1.400741f, 0.508057f, 0.994702f, -2.267981f, 1.677437f, 0.175134f, 0.712679f, -0.440408f, -1.248550f, 1.618839f, -0.214598f, 0.486398f, -0.478466f, 0.912471f, 0.499651f, -0.886606f, -0.929524f, 0.449260f, 0.017969f, -0.050906f, 1.799695f, -0.033007f, -1.884108f, -1.392415f, -0.852990f, -0.052969f, 0.819434f, 0.089723f, 0.598047f}; + std::initializer_list X_data{TypeParam(-0.185898f), TypeParam(0.403325f), TypeParam(0.737314f), TypeParam(0.545995f), TypeParam(-1.010481f), TypeParam(-1.204522f), TypeParam(-0.147342f), TypeParam(0.232425f), TypeParam(-1.339485f), TypeParam(0.013892f), TypeParam(-1.098319f), TypeParam(0.478079f), TypeParam(0.051159f), TypeParam(-0.906061f), TypeParam(-0.428560f), TypeParam(0.583460f), TypeParam(1.137472f), TypeParam(1.487881f), TypeParam(1.349931f), TypeParam(-0.118774f), TypeParam(0.436410f), TypeParam(1.334689f), TypeParam(-1.115846f), TypeParam(0.159820f), TypeParam(0.617671f), TypeParam(0.546630f), TypeParam(1.861115f), TypeParam(0.500044f), TypeParam(0.623446f), TypeParam(0.541840f), TypeParam(-0.279259f), TypeParam(-0.573875f), TypeParam(0.783115f), TypeParam(-1.125017f), TypeParam(-1.166457f), TypeParam(-0.827232f), TypeParam(0.273074f), TypeParam(0.702953f), TypeParam(1.288608f), TypeParam(-1.037043f), TypeParam(0.021860f), TypeParam(0.575628f), TypeParam(-0.034170f), TypeParam(1.400741f), TypeParam(0.508057f), TypeParam(0.994702f), TypeParam(-2.267981f), TypeParam(1.677437f), TypeParam(0.175134f), TypeParam(0.712679f), TypeParam(-0.440408f), TypeParam(-1.248550f), TypeParam(1.618839f), TypeParam(-0.214598f), TypeParam(0.486398f), TypeParam(-0.478466f), TypeParam(0.912471f), TypeParam(0.499651f), TypeParam(-0.886606f), TypeParam(-0.929524f), TypeParam(0.449260f), TypeParam(0.017969f), TypeParam(-0.050906f), TypeParam(1.799695f), TypeParam(-0.033007f), TypeParam(-1.884108f), TypeParam(-1.392415f), TypeParam(-0.852990f), TypeParam(-0.052969f), TypeParam(0.819434f), TypeParam(0.089723f), TypeParam(0.598047f)}; std::initializer_list Grid_shape{2, 3, 3, 2, 3}; - std::initializer_list Grid_data{-0.118828f, 0.082315f, 0.328488f, -0.834821f, -0.138863f, -0.988801f, -0.976128f, 0.156412f, -1.171383f, 0.319534f, -1.105438f, -0.834991f, -0.248995f, -1.145138f, 0.969159f, 0.983228f, -0.626795f, 0.251376f, 0.613890f, 0.381328f, -0.160747f, -1.131853f, 0.872567f, -1.052516f, -0.222240f, 0.074438f, -0.395210f, -0.438906f, -1.037125f, 0.066119f, -0.136254f, 1.046163f, -0.395065f, 0.927498f, 0.056808f, -0.539139f, -0.285382f, -0.136177f, 0.012430f, -0.197703f, 0.356128f, 0.988219f, 0.188620f, 0.434655f, 0.741024f, 0.258662f, 0.553165f, 0.629461f, 1.123216f, -1.095185f, 0.410630f, -0.054374f, -0.215508f, -0.462650f, 0.721441f, 1.097745f, -0.979308f, 0.648336f, 0.827460f, 0.209729f, 0.014136f, 0.923431f, 0.035578f, -0.299309f, -0.088614f, 0.385002f, 0.300407f, -0.064744f, 0.378800f, 0.323185f, -0.972071f, 0.299012f, 0.734213f, 0.137618f, -0.109532f, 0.919238f, -1.048417f, -0.547724f, -0.542389f, 1.036863f, -1.160666f, 0.119013f, -1.162427f, -0.039461f, 0.447285f, -0.280625f, 1.164882f, 0.003820f, -0.611796f, 0.309439f, 0.624077f, -0.002384f, 1.026569f, -0.759499f, 0.512014f, 0.681403f, 0.596030f, -0.000440f, 0.342557f, -0.941414f, -0.941707f, -0.074588f, -0.150400f, 0.891031f, 0.871352f, 0.813657f, -0.549640f, -0.942044f}; + std::initializer_list Grid_data{TypeParam(-0.118828f), TypeParam(0.082315f), TypeParam(0.328488f), TypeParam(-0.834821f), TypeParam(-0.138863f), TypeParam(-0.988801f), TypeParam(-0.976128f), TypeParam(0.156412f), TypeParam(-1.171383f), TypeParam(0.319534f), TypeParam(-1.105438f), TypeParam(-0.834991f), TypeParam(-0.248995f), TypeParam(-1.145138f), TypeParam(0.969159f), TypeParam(0.983228f), TypeParam(-0.626795f), TypeParam(0.251376f), TypeParam(0.613890f), TypeParam(0.381328f), TypeParam(-0.160747f), TypeParam(-1.131853f), TypeParam(0.872567f), TypeParam(-1.052516f), TypeParam(-0.222240f), TypeParam(0.074438f), TypeParam(-0.395210f), TypeParam(-0.438906f), TypeParam(-1.037125f), TypeParam(0.066119f), TypeParam(-0.136254f), TypeParam(1.046163f), TypeParam(-0.395065f), TypeParam(0.927498f), TypeParam(0.056808f), TypeParam(-0.539139f), TypeParam(-0.285382f), TypeParam(-0.136177f), TypeParam(0.012430f), TypeParam(-0.197703f), TypeParam(0.356128f), TypeParam(0.988219f), TypeParam(0.188620f), TypeParam(0.434655f), TypeParam(0.741024f), TypeParam(0.258662f), TypeParam(0.553165f), TypeParam(0.629461f), TypeParam(1.123216f), TypeParam(-1.095185f), TypeParam(0.410630f), TypeParam(-0.054374f), TypeParam(-0.215508f), TypeParam(-0.462650f), TypeParam(0.721441f), TypeParam(1.097745f), TypeParam(-0.979308f), TypeParam(0.648336f), TypeParam(0.827460f), TypeParam(0.209729f), TypeParam(0.014136f), TypeParam(0.923431f), TypeParam(0.035578f), TypeParam(-0.299309f), TypeParam(-0.088614f), TypeParam(0.385002f), TypeParam(0.300407f), TypeParam(-0.064744f), TypeParam(0.378800f), TypeParam(0.323185f), TypeParam(-0.972071f), TypeParam(0.299012f), TypeParam(0.734213f), TypeParam(0.137618f), TypeParam(-0.109532f), TypeParam(0.919238f), TypeParam(-1.048417f), TypeParam(-0.547724f), TypeParam(-0.542389f), TypeParam(1.036863f), TypeParam(-1.160666f), TypeParam(0.119013f), TypeParam(-1.162427f), TypeParam(-0.039461f), TypeParam(0.447285f), TypeParam(-0.280625f), TypeParam(1.164882f), TypeParam(0.003820f), TypeParam(-0.611796f), TypeParam(0.309439f), TypeParam(0.624077f), TypeParam(-0.002384f), TypeParam(1.026569f), TypeParam(-0.759499f), TypeParam(0.512014f), TypeParam(0.681403f), TypeParam(0.596030f), TypeParam(-0.000440f), TypeParam(0.342557f), TypeParam(-0.941414f), TypeParam(-0.941707f), TypeParam(-0.074588f), TypeParam(-0.150400f), TypeParam(0.891031f), TypeParam(0.871352f), TypeParam(0.813657f), TypeParam(-0.549640f), TypeParam(-0.942044f)}; std::initializer_list Y_shape{2, 2, 3, 3, 2}; - std::initializer_list Y_data{-1.339485f, 0.737314f, 0.737314f, 0.403325f, 0.051159f, 0.232425f, 0.478079f, -1.010481f, 0.737314f, -0.147342f, -1.010481f, 0.545995f, -1.339485f, 1.137472f, 1.487881f, 1.487881f, -0.906061f, 0.737314f, 1.861115f, 0.436410f, 0.436410f, -0.118774f, -0.279259f, 0.546630f, 0.541840f, -1.115846f, 0.436410f, 0.617671f, -1.115846f, 1.334689f, 1.861115f, -1.166457f, -0.827232f, -0.827232f, -0.573875f, 0.436410f, 0.575628f, 1.677437f, 1.677437f, -0.440408f, -1.248550f, 1.400741f, 0.994702f, 0.702953f, 0.021860f, 1.400741f, -1.248550f, 1.400741f, -1.248550f, 1.618839f, -1.248550f, -0.034170f, 1.618839f, 0.702953f, -0.929524f, -1.884108f, -1.884108f, -0.052969f, 0.819434f, 0.017969f, 1.799695f, -0.478466f, -0.886606f, 0.017969f, 0.819434f, 0.017969f, 0.819434f, 0.089723f, 0.819434f, 0.449260f, 0.089723f, -0.478466f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-1.339485f), TypeParam(0.737314f), TypeParam(0.737314f), TypeParam(0.403325f), TypeParam(0.051159f), TypeParam(0.232425f), TypeParam(0.478079f), TypeParam(-1.010481f), TypeParam(0.737314f), TypeParam(-0.147342f), TypeParam(-1.010481f), TypeParam(0.545995f), TypeParam(-1.339485f), TypeParam(1.137472f), TypeParam(1.487881f), TypeParam(1.487881f), TypeParam(-0.906061f), TypeParam(0.737314f), TypeParam(1.861115f), TypeParam(0.436410f), TypeParam(0.436410f), TypeParam(-0.118774f), TypeParam(-0.279259f), TypeParam(0.546630f), TypeParam(0.541840f), TypeParam(-1.115846f), TypeParam(0.436410f), TypeParam(0.617671f), TypeParam(-1.115846f), TypeParam(1.334689f), TypeParam(1.861115f), TypeParam(-1.166457f), TypeParam(-0.827232f), TypeParam(-0.827232f), TypeParam(-0.573875f), TypeParam(0.436410f), TypeParam(0.575628f), TypeParam(1.677437f), TypeParam(1.677437f), TypeParam(-0.440408f), TypeParam(-1.248550f), TypeParam(1.400741f), TypeParam(0.994702f), TypeParam(0.702953f), TypeParam(0.021860f), TypeParam(1.400741f), TypeParam(-1.248550f), TypeParam(1.400741f), TypeParam(-1.248550f), TypeParam(1.618839f), TypeParam(-1.248550f), TypeParam(-0.034170f), TypeParam(1.618839f), TypeParam(0.702953f), TypeParam(-0.929524f), TypeParam(-1.884108f), TypeParam(-1.884108f), TypeParam(-0.052969f), TypeParam(0.819434f), TypeParam(0.017969f), TypeParam(1.799695f), TypeParam(-0.478466f), TypeParam(-0.886606f), TypeParam(0.017969f), TypeParam(0.819434f), TypeParam(0.017969f), TypeParam(0.819434f), TypeParam(0.089723f), TypeParam(0.819434f), TypeParam(0.449260f), TypeParam(0.089723f), TypeParam(-0.478466f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_zeros_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_zeros_align_corners) { OpTester test("GridSample", 20); std::string mode = "linear"; std::string padding_mode = "zeros"; int64_t align_corners = 1; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{0.010274f, 1.493496f, -0.264303f, 0.035897f, -0.751962f, -0.370195f, -0.514836f, 0.399928f, -0.191651f, -0.239505f, -1.931184f, -1.074773f, -0.121908f, 0.050673f, -0.741501f, -0.229127f, -0.360925f, 0.264077f, 1.537180f, 1.603202f, -1.241810f, -0.388456f, -0.609742f, 0.095097f}; + std::initializer_list X_data{TypeParam(0.010274f), TypeParam(1.493496f), TypeParam(-0.264303f), TypeParam(0.035897f), TypeParam(-0.751962f), TypeParam(-0.370195f), TypeParam(-0.514836f), TypeParam(0.399928f), TypeParam(-0.191651f), TypeParam(-0.239505f), TypeParam(-1.931184f), TypeParam(-1.074773f), TypeParam(-0.121908f), TypeParam(0.050673f), TypeParam(-0.741501f), TypeParam(-0.229127f), TypeParam(-0.360925f), TypeParam(0.264077f), TypeParam(1.537180f), TypeParam(1.603202f), TypeParam(-1.241810f), TypeParam(-0.388456f), TypeParam(-0.609742f), TypeParam(0.095097f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{-0.118589f, -0.020968f, -0.893597f, 1.170924f, -0.517539f, 0.698168f, -0.672718f, 0.008056f, 0.410793f, -1.101817f, 0.550440f, -0.918534f, 0.167456f, -0.237959f, 0.687868f, 1.166281f, 0.270439f, -0.034265f, -0.594534f, 0.447403f, -0.577587f, 0.495680f, -0.520113f, 0.813977f}; + std::initializer_list Grid_data{TypeParam(-0.118589f), TypeParam(-0.020968f), TypeParam(-0.893597f), TypeParam(1.170924f), TypeParam(-0.517539f), TypeParam(0.698168f), TypeParam(-0.672718f), TypeParam(0.008056f), TypeParam(0.410793f), TypeParam(-1.101817f), TypeParam(0.550440f), TypeParam(-0.918534f), TypeParam(0.167456f), TypeParam(-0.237959f), TypeParam(0.687868f), TypeParam(1.166281f), TypeParam(0.270439f), TypeParam(-0.034265f), TypeParam(-0.594534f), TypeParam(0.447403f), TypeParam(-0.577587f), TypeParam(0.495680f), TypeParam(-0.520113f), TypeParam(0.813977f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{-0.115313f, -0.606595f, -0.518616f, -0.218999f, 0.948961f, 1.063015f, -0.210622f, -1.563324f, -1.265386f, -0.212304f, 0.117155f, 0.159843f, -0.342175f, 0.138844f, -0.402196f, -0.457139f, -0.432849f, -0.286783f, -0.191760f, -0.012426f, -0.621658f, -0.799488f, -0.763820f, -0.551571f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-0.115313f), TypeParam(-0.606595f), TypeParam(-0.518616f), TypeParam(-0.218999f), TypeParam(0.948961f), TypeParam(1.063015f), TypeParam(-0.210622f), TypeParam(-1.563324f), TypeParam(-1.265386f), TypeParam(-0.212304f), TypeParam(0.117155f), TypeParam(0.159843f), TypeParam(-0.342175f), TypeParam(0.138844f), TypeParam(-0.402196f), TypeParam(-0.457139f), TypeParam(-0.432849f), TypeParam(-0.286783f), TypeParam(-0.191760f), TypeParam(-0.012426f), TypeParam(-0.621658f), TypeParam(-0.799488f), TypeParam(-0.763820f), TypeParam(-0.551571f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_5D_bilinear_zeros_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_bilinear_zeros_align_corners) { OpTester test("GridSample", 20); std::string mode = "linear"; std::string padding_mode = "zeros"; int64_t align_corners = 1; std::initializer_list X_shape{2, 2, 3, 3, 2}; - std::initializer_list X_data{-1.787070f, -0.894227f, -0.113069f, 0.713917f, 0.041566f, -1.847208f, 0.013441f, -1.439041f, 1.051864f, 1.576791f, 1.180527f, -1.457019f, 0.298446f, 1.142738f, -0.961347f, -0.471509f, -0.074154f, 0.047739f, -0.679950f, -2.306940f, -0.552171f, -0.357144f, -0.492247f, -0.455872f, 0.399680f, 0.057915f, -0.362704f, 1.083763f, -0.084941f, -1.691393f, -1.913178f, 0.696366f, 1.172833f, 0.901506f, -1.189840f, -1.197158f, 0.007338f, 0.161468f, -1.048452f, -0.480832f, 0.391235f, 1.056413f, -0.116648f, 0.632195f, 0.840261f, -2.187738f, 0.302910f, -0.956190f, -0.362645f, 0.771747f, 0.524840f, -0.954672f, -1.084612f, -0.525794f, -0.969691f, -1.056405f, -0.364709f, 0.336189f, -0.178281f, 1.015025f, -0.532580f, 0.036602f, -0.434395f, -1.208987f, -1.084039f, 0.642844f, -0.819208f, -0.982898f, -0.109210f, -1.231957f, 1.083089f, -0.870451f}; + std::initializer_list X_data{TypeParam(-1.787070f), TypeParam(-0.894227f), TypeParam(-0.113069f), TypeParam(0.713917f), TypeParam(0.041566f), TypeParam(-1.847208f), TypeParam(0.013441f), TypeParam(-1.439041f), TypeParam(1.051864f), TypeParam(1.576791f), TypeParam(1.180527f), TypeParam(-1.457019f), TypeParam(0.298446f), TypeParam(1.142738f), TypeParam(-0.961347f), TypeParam(-0.471509f), TypeParam(-0.074154f), TypeParam(0.047739f), TypeParam(-0.679950f), TypeParam(-2.306940f), TypeParam(-0.552171f), TypeParam(-0.357144f), TypeParam(-0.492247f), TypeParam(-0.455872f), TypeParam(0.399680f), TypeParam(0.057915f), TypeParam(-0.362704f), TypeParam(1.083763f), TypeParam(-0.084941f), TypeParam(-1.691393f), TypeParam(-1.913178f), TypeParam(0.696366f), TypeParam(1.172833f), TypeParam(0.901506f), TypeParam(-1.189840f), TypeParam(-1.197158f), TypeParam(0.007338f), TypeParam(0.161468f), TypeParam(-1.048452f), TypeParam(-0.480832f), TypeParam(0.391235f), TypeParam(1.056413f), TypeParam(-0.116648f), TypeParam(0.632195f), TypeParam(0.840261f), TypeParam(-2.187738f), TypeParam(0.302910f), TypeParam(-0.956190f), TypeParam(-0.362645f), TypeParam(0.771747f), TypeParam(0.524840f), TypeParam(-0.954672f), TypeParam(-1.084612f), TypeParam(-0.525794f), TypeParam(-0.969691f), TypeParam(-1.056405f), TypeParam(-0.364709f), TypeParam(0.336189f), TypeParam(-0.178281f), TypeParam(1.015025f), TypeParam(-0.532580f), TypeParam(0.036602f), TypeParam(-0.434395f), TypeParam(-1.208987f), TypeParam(-1.084039f), TypeParam(0.642844f), TypeParam(-0.819208f), TypeParam(-0.982898f), TypeParam(-0.109210f), TypeParam(-1.231957f), TypeParam(1.083089f), TypeParam(-0.870451f)}; std::initializer_list Grid_shape{2, 3, 3, 2, 3}; - std::initializer_list Grid_data{0.350638f, -0.554259f, 0.740901f, -1.134597f, -0.450763f, -0.706065f, -0.712365f, -0.727142f, -1.130749f, 0.205940f, -0.237380f, -1.010413f, -0.000494f, -0.199898f, 0.495032f, -0.939943f, -0.337590f, 0.247001f, 0.508664f, 0.090780f, 0.325198f, 1.199561f, -0.415694f, 0.817854f, 1.033666f, -1.061540f, 0.290273f, 0.679739f, -0.187185f, 0.662278f, 0.040817f, 0.913540f, 0.025838f, -0.768267f, 0.911326f, 0.356885f, 1.020923f, 0.297892f, 0.637209f, 0.748214f, 0.202064f, -0.278959f, 0.247841f, -0.836700f, 0.040996f, -0.385697f, 0.075869f, -0.950110f, 0.733227f, -1.107135f, 0.513890f, 0.790272f, -1.099795f, 1.084212f, -0.892061f, -0.235640f, 0.621837f, -0.380523f, 1.069422f, -0.529383f, -0.160661f, -0.784422f, -0.556715f, 1.171015f, 0.902476f, 0.088357f, 0.098667f, -1.018314f, 0.905937f, -0.179914f, -0.500513f, -0.954987f, 0.986618f, 0.569025f, 0.722795f, 0.124254f, -0.814285f, 0.491561f, 0.138395f, 0.402690f, -0.298810f, -0.566298f, 0.985118f, 0.402260f, -0.487031f, 0.107159f, -0.260850f, -0.102620f, 0.672911f, -0.955102f, 1.086040f, 0.807667f, 0.001031f, -0.490841f, 0.244670f, -0.794290f, 0.779461f, -0.634633f, 0.229290f, -1.180597f, 0.574650f, 0.812338f, 0.900697f, 0.097950f, 0.708525f, 0.409153f, 0.804739f, 0.677169f}; + std::initializer_list Grid_data{TypeParam(0.350638f), TypeParam(-0.554259f), TypeParam(0.740901f), TypeParam(-1.134597f), TypeParam(-0.450763f), TypeParam(-0.706065f), TypeParam(-0.712365f), TypeParam(-0.727142f), TypeParam(-1.130749f), TypeParam(0.205940f), TypeParam(-0.237380f), TypeParam(-1.010413f), TypeParam(-0.000494f), TypeParam(-0.199898f), TypeParam(0.495032f), TypeParam(-0.939943f), TypeParam(-0.337590f), TypeParam(0.247001f), TypeParam(0.508664f), TypeParam(0.090780f), TypeParam(0.325198f), TypeParam(1.199561f), TypeParam(-0.415694f), TypeParam(0.817854f), TypeParam(1.033666f), TypeParam(-1.061540f), TypeParam(0.290273f), TypeParam(0.679739f), TypeParam(-0.187185f), TypeParam(0.662278f), TypeParam(0.040817f), TypeParam(0.913540f), TypeParam(0.025838f), TypeParam(-0.768267f), TypeParam(0.911326f), TypeParam(0.356885f), TypeParam(1.020923f), TypeParam(0.297892f), TypeParam(0.637209f), TypeParam(0.748214f), TypeParam(0.202064f), TypeParam(-0.278959f), TypeParam(0.247841f), TypeParam(-0.836700f), TypeParam(0.040996f), TypeParam(-0.385697f), TypeParam(0.075869f), TypeParam(-0.950110f), TypeParam(0.733227f), TypeParam(-1.107135f), TypeParam(0.513890f), TypeParam(0.790272f), TypeParam(-1.099795f), TypeParam(1.084212f), TypeParam(-0.892061f), TypeParam(-0.235640f), TypeParam(0.621837f), TypeParam(-0.380523f), TypeParam(1.069422f), TypeParam(-0.529383f), TypeParam(-0.160661f), TypeParam(-0.784422f), TypeParam(-0.556715f), TypeParam(1.171015f), TypeParam(0.902476f), TypeParam(0.088357f), TypeParam(0.098667f), TypeParam(-1.018314f), TypeParam(0.905937f), TypeParam(-0.179914f), TypeParam(-0.500513f), TypeParam(-0.954987f), TypeParam(0.986618f), TypeParam(0.569025f), TypeParam(0.722795f), TypeParam(0.124254f), TypeParam(-0.814285f), TypeParam(0.491561f), TypeParam(0.138395f), TypeParam(0.402690f), TypeParam(-0.298810f), TypeParam(-0.566298f), TypeParam(0.985118f), TypeParam(0.402260f), TypeParam(-0.487031f), TypeParam(0.107159f), TypeParam(-0.260850f), TypeParam(-0.102620f), TypeParam(0.672911f), TypeParam(-0.955102f), TypeParam(1.086040f), TypeParam(0.807667f), TypeParam(0.001031f), TypeParam(-0.490841f), TypeParam(0.244670f), TypeParam(-0.794290f), TypeParam(0.779461f), TypeParam(-0.634633f), TypeParam(0.229290f), TypeParam(-1.180597f), TypeParam(0.574650f), TypeParam(0.812338f), TypeParam(0.900697f), TypeParam(0.097950f), TypeParam(0.708525f), TypeParam(0.409153f), TypeParam(0.804739f), TypeParam(0.677169f)}; std::initializer_list Y_shape{2, 2, 3, 3, 2}; - std::initializer_list Y_data{0.171946f, -0.411342f, -1.046998f, -0.002345f, 0.246533f, 0.396970f, 0.664278f, 0.199883f, -0.636287f, 0.162358f, -0.061161f, 0.528084f, 0.041846f, 0.750291f, -0.476442f, 0.142258f, -0.067844f, 0.869081f, 0.360025f, -0.406785f, -0.701985f, -0.718142f, 0.519179f, -0.022693f, 0.618451f, 0.708731f, 0.224429f, 0.784241f, -0.812606f, -0.521137f, 0.266524f, 0.190886f, 0.231077f, -0.465330f, 0.204730f, 0.348489f, 0.356190f, 0.256096f, -0.038212f, -0.943162f, 0.258902f, -0.360112f, -0.920536f, 0.126677f, -0.523600f, -0.361337f, -0.154168f, 0.179761f, -1.141155f, -0.423488f, -0.225410f, -0.204886f, -1.162816f, -0.678226f, -0.384409f, -0.146245f, -0.622531f, 0.312188f, -0.828836f, -0.541017f, -0.778291f, -0.602484f, -0.328754f, -0.163964f, -0.508068f, 0.193021f, 0.273133f, -0.217934f, -0.562420f, 0.287725f, -1.097279f, -0.306201f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(0.171946f), TypeParam(-0.411342f), TypeParam(-1.046998f), TypeParam(-0.002345f), TypeParam(0.246533f), TypeParam(0.396970f), TypeParam(0.664278f), TypeParam(0.199883f), TypeParam(-0.636287f), TypeParam(0.162358f), TypeParam(-0.061161f), TypeParam(0.528084f), TypeParam(0.041846f), TypeParam(0.750291f), TypeParam(-0.476442f), TypeParam(0.142258f), TypeParam(-0.067844f), TypeParam(0.869081f), TypeParam(0.360025f), TypeParam(-0.406785f), TypeParam(-0.701985f), TypeParam(-0.718142f), TypeParam(0.519179f), TypeParam(-0.022693f), TypeParam(0.618451f), TypeParam(0.708731f), TypeParam(0.224429f), TypeParam(0.784241f), TypeParam(-0.812606f), TypeParam(-0.521137f), TypeParam(0.266524f), TypeParam(0.190886f), TypeParam(0.231077f), TypeParam(-0.465330f), TypeParam(0.204730f), TypeParam(0.348489f), TypeParam(0.356190f), TypeParam(0.256096f), TypeParam(-0.038212f), TypeParam(-0.943162f), TypeParam(0.258902f), TypeParam(-0.360112f), TypeParam(-0.920536f), TypeParam(0.126677f), TypeParam(-0.523600f), TypeParam(-0.361337f), TypeParam(-0.154168f), TypeParam(0.179761f), TypeParam(-1.141155f), TypeParam(-0.423488f), TypeParam(-0.225410f), TypeParam(-0.204886f), TypeParam(-1.162816f), TypeParam(-0.678226f), TypeParam(-0.384409f), TypeParam(-0.146245f), TypeParam(-0.622531f), TypeParam(0.312188f), TypeParam(-0.828836f), TypeParam(-0.541017f), TypeParam(-0.778291f), TypeParam(-0.602484f), TypeParam(-0.328754f), TypeParam(-0.163964f), TypeParam(-0.508068f), TypeParam(0.193021f), TypeParam(0.273133f), TypeParam(-0.217934f), TypeParam(-0.562420f), TypeParam(0.287725f), TypeParam(-1.097279f), TypeParam(-0.306201f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_zeros_no_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_zeros_no_align_corners) { OpTester test("GridSample", 20); std::string mode = "linear"; std::string padding_mode = "zeros"; int64_t align_corners = 0; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{0.185965f, 0.133937f, -0.763030f, 0.733342f, 1.932445f, -0.582571f, -1.312078f, 0.738952f, 0.444459f, 0.742593f, -0.805960f, -0.202535f, 0.970323f, -0.801176f, 0.277655f, -1.938051f, -1.879800f, 0.287116f, 0.261958f, -0.358247f, -0.107750f, 0.748162f, -0.742330f, 0.344665f}; + std::initializer_list X_data{TypeParam(0.185965f), TypeParam(0.133937f), TypeParam(-0.763030f), TypeParam(0.733342f), TypeParam(1.932445f), TypeParam(-0.582571f), TypeParam(-1.312078f), TypeParam(0.738952f), TypeParam(0.444459f), TypeParam(0.742593f), TypeParam(-0.805960f), TypeParam(-0.202535f), TypeParam(0.970323f), TypeParam(-0.801176f), TypeParam(0.277655f), TypeParam(-1.938051f), TypeParam(-1.879800f), TypeParam(0.287116f), TypeParam(0.261958f), TypeParam(-0.358247f), TypeParam(-0.107750f), TypeParam(0.748162f), TypeParam(-0.742330f), TypeParam(0.344665f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{-0.460252f, 0.734353f, -1.069308f, 1.005361f, 1.198595f, -0.327629f, 0.474026f, 1.196645f, 0.361782f, 0.469280f, 0.440632f, -0.490951f, 0.292918f, -0.639568f, 1.024697f, -0.514217f, 0.274326f, -0.347614f, 0.600117f, 0.019780f, 0.659824f, -0.324940f, -0.704174f, 0.460072f}; + std::initializer_list Grid_data{TypeParam(-0.460252f), TypeParam(0.734353f), TypeParam(-1.069308f), TypeParam(1.005361f), TypeParam(1.198595f), TypeParam(-0.327629f), TypeParam(0.474026f), TypeParam(1.196645f), TypeParam(0.361782f), TypeParam(0.469280f), TypeParam(0.440632f), TypeParam(-0.490951f), TypeParam(0.292918f), TypeParam(-0.639568f), TypeParam(1.024697f), TypeParam(-0.514217f), TypeParam(0.274326f), TypeParam(-0.347614f), TypeParam(0.600117f), TypeParam(0.019780f), TypeParam(0.659824f), TypeParam(-0.324940f), TypeParam(-0.704174f), TypeParam(0.460072f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{1.646426f, 0.409452f, 0.132247f, -0.106052f, -0.009495f, 0.270785f, -0.702581f, -0.170769f, 0.223282f, -0.044740f, 0.006388f, 0.645576f, -0.476802f, -0.504368f, -0.897503f, -1.684608f, -1.162742f, -0.963921f, -0.197266f, -0.050021f, 0.151796f, 0.662485f, 0.175502f, -0.434265f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(1.646426f), TypeParam(0.409452f), TypeParam(0.132247f), TypeParam(-0.106052f), TypeParam(-0.009495f), TypeParam(0.270785f), TypeParam(-0.702581f), TypeParam(-0.170769f), TypeParam(0.223282f), TypeParam(-0.044740f), TypeParam(0.006388f), TypeParam(0.645576f), TypeParam(-0.476802f), TypeParam(-0.504368f), TypeParam(-0.897503f), TypeParam(-1.684608f), TypeParam(-1.162742f), TypeParam(-0.963921f), TypeParam(-0.197266f), TypeParam(-0.050021f), TypeParam(0.151796f), TypeParam(0.662485f), TypeParam(0.175502f), TypeParam(-0.434265f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_5D_bilinear_zeros_no_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_bilinear_zeros_no_align_corners) { OpTester test("GridSample", 20); std::string mode = "linear"; std::string padding_mode = "zeros"; int64_t align_corners = 0; std::initializer_list X_shape{2, 2, 3, 3, 2}; - std::initializer_list X_data{-0.299262f, -0.304887f, 0.906636f, -0.392850f, -0.050410f, 0.548199f, -1.235108f, -0.475848f, 0.635455f, 0.307462f, -1.241370f, -0.538672f, 0.863466f, 0.799983f, -0.090064f, -0.751721f, 0.956040f, -0.117709f, -2.183699f, -0.484444f, 1.105900f, 0.164466f, 0.720736f, 0.168044f, -0.656400f, 1.770106f, -0.544832f, 1.358424f, 0.981648f, -1.759268f, -0.526924f, 1.322339f, 0.148774f, 0.321413f, -1.257438f, -0.383775f, -2.117908f, -0.077921f, -0.197889f, 0.555813f, -1.517724f, 1.419652f, -0.891774f, 1.684663f, -1.524669f, -2.055758f, -0.299843f, -0.644860f, 0.428609f, -1.704372f, 1.257671f, -0.886508f, -0.029344f, -1.718824f, -0.294273f, 1.537690f, -1.366837f, -1.610098f, 0.650240f, -0.288219f, 0.837292f, 0.431683f, -0.405852f, 0.492271f, 0.416507f, 0.971658f, -0.183526f, 0.615709f, -0.081615f, 1.160796f, 1.431487f, 0.485687f}; + std::initializer_list X_data{TypeParam(-0.299262f), TypeParam(-0.304887f), TypeParam(0.906636f), TypeParam(-0.392850f), TypeParam(-0.050410f), TypeParam(0.548199f), TypeParam(-1.235108f), TypeParam(-0.475848f), TypeParam(0.635455f), TypeParam(0.307462f), TypeParam(-1.241370f), TypeParam(-0.538672f), TypeParam(0.863466f), TypeParam(0.799983f), TypeParam(-0.090064f), TypeParam(-0.751721f), TypeParam(0.956040f), TypeParam(-0.117709f), TypeParam(-2.183699f), TypeParam(-0.484444f), TypeParam(1.105900f), TypeParam(0.164466f), TypeParam(0.720736f), TypeParam(0.168044f), TypeParam(-0.656400f), TypeParam(1.770106f), TypeParam(-0.544832f), TypeParam(1.358424f), TypeParam(0.981648f), TypeParam(-1.759268f), TypeParam(-0.526924f), TypeParam(1.322339f), TypeParam(0.148774f), TypeParam(0.321413f), TypeParam(-1.257438f), TypeParam(-0.383775f), TypeParam(-2.117908f), TypeParam(-0.077921f), TypeParam(-0.197889f), TypeParam(0.555813f), TypeParam(-1.517724f), TypeParam(1.419652f), TypeParam(-0.891774f), TypeParam(1.684663f), TypeParam(-1.524669f), TypeParam(-2.055758f), TypeParam(-0.299843f), TypeParam(-0.644860f), TypeParam(0.428609f), TypeParam(-1.704372f), TypeParam(1.257671f), TypeParam(-0.886508f), TypeParam(-0.029344f), TypeParam(-1.718824f), TypeParam(-0.294273f), TypeParam(1.537690f), TypeParam(-1.366837f), TypeParam(-1.610098f), TypeParam(0.650240f), TypeParam(-0.288219f), TypeParam(0.837292f), TypeParam(0.431683f), TypeParam(-0.405852f), TypeParam(0.492271f), TypeParam(0.416507f), TypeParam(0.971658f), TypeParam(-0.183526f), TypeParam(0.615709f), TypeParam(-0.081615f), TypeParam(1.160796f), TypeParam(1.431487f), TypeParam(0.485687f)}; std::initializer_list Grid_shape{2, 3, 3, 2, 3}; - std::initializer_list Grid_data{0.884040f, -0.825214f, 0.496720f, -0.440955f, 1.195811f, 0.169268f, -1.042100f, 0.206524f, 0.145895f, -1.160650f, 0.240829f, 1.144915f, 0.345332f, -0.006382f, -0.248763f, 0.318888f, -0.534619f, 1.181719f, 1.037350f, 0.560600f, -0.446974f, -1.126746f, -0.690807f, 1.166754f, -1.101454f, -1.145775f, -0.086488f, 0.381780f, -1.194351f, -1.114106f, 0.006524f, -0.402521f, 0.836016f, 0.344533f, -1.041627f, -1.081571f, 0.824102f, -0.212785f, -0.524949f, 0.377977f, -0.235842f, 0.573897f, 0.304308f, -0.519568f, -0.961787f, 0.649611f, -0.720973f, -0.132725f, 0.164074f, -0.698360f, 0.653669f, -0.844065f, 0.294728f, 0.128341f, 0.440293f, -1.177701f, 0.069319f, 0.585007f, -0.768260f, 0.296941f, 0.004702f, 1.018020f, -0.254096f, 0.008198f, -0.521925f, -0.295744f, 0.343532f, -1.157334f, 0.910329f, 0.862921f, 0.508195f, 0.898317f, -0.373544f, 0.273330f, 0.061050f, -0.829794f, -0.461335f, -0.426012f, -0.296704f, -1.065526f, -0.843948f, -0.113955f, -0.182548f, -1.089296f, 0.256401f, 0.653393f, 0.999377f, 1.009925f, -0.838519f, -0.384579f, -0.569276f, 0.220093f, 0.321562f, 0.266984f, 0.701244f, 0.633093f, -0.644096f, 0.823778f, 0.809482f, 0.158802f, -1.044029f, -0.735991f, 0.334411f, 0.414891f, 1.118940f, 0.610743f, 0.434932f, -0.040928f}; + std::initializer_list Grid_data{TypeParam(0.884040f), TypeParam(-0.825214f), TypeParam(0.496720f), TypeParam(-0.440955f), TypeParam(1.195811f), TypeParam(0.169268f), TypeParam(-1.042100f), TypeParam(0.206524f), TypeParam(0.145895f), TypeParam(-1.160650f), TypeParam(0.240829f), TypeParam(1.144915f), TypeParam(0.345332f), TypeParam(-0.006382f), TypeParam(-0.248763f), TypeParam(0.318888f), TypeParam(-0.534619f), TypeParam(1.181719f), TypeParam(1.037350f), TypeParam(0.560600f), TypeParam(-0.446974f), TypeParam(-1.126746f), TypeParam(-0.690807f), TypeParam(1.166754f), TypeParam(-1.101454f), TypeParam(-1.145775f), TypeParam(-0.086488f), TypeParam(0.381780f), TypeParam(-1.194351f), TypeParam(-1.114106f), TypeParam(0.006524f), TypeParam(-0.402521f), TypeParam(0.836016f), TypeParam(0.344533f), TypeParam(-1.041627f), TypeParam(-1.081571f), TypeParam(0.824102f), TypeParam(-0.212785f), TypeParam(-0.524949f), TypeParam(0.377977f), TypeParam(-0.235842f), TypeParam(0.573897f), TypeParam(0.304308f), TypeParam(-0.519568f), TypeParam(-0.961787f), TypeParam(0.649611f), TypeParam(-0.720973f), TypeParam(-0.132725f), TypeParam(0.164074f), TypeParam(-0.698360f), TypeParam(0.653669f), TypeParam(-0.844065f), TypeParam(0.294728f), TypeParam(0.128341f), TypeParam(0.440293f), TypeParam(-1.177701f), TypeParam(0.069319f), TypeParam(0.585007f), TypeParam(-0.768260f), TypeParam(0.296941f), TypeParam(0.004702f), TypeParam(1.018020f), TypeParam(-0.254096f), TypeParam(0.008198f), TypeParam(-0.521925f), TypeParam(-0.295744f), TypeParam(0.343532f), TypeParam(-1.157334f), TypeParam(0.910329f), TypeParam(0.862921f), TypeParam(0.508195f), TypeParam(0.898317f), TypeParam(-0.373544f), TypeParam(0.273330f), TypeParam(0.061050f), TypeParam(-0.829794f), TypeParam(-0.461335f), TypeParam(-0.426012f), TypeParam(-0.296704f), TypeParam(-1.065526f), TypeParam(-0.843948f), TypeParam(-0.113955f), TypeParam(-0.182548f), TypeParam(-1.089296f), TypeParam(0.256401f), TypeParam(0.653393f), TypeParam(0.999377f), TypeParam(1.009925f), TypeParam(-0.838519f), TypeParam(-0.384579f), TypeParam(-0.569276f), TypeParam(0.220093f), TypeParam(0.321562f), TypeParam(0.266984f), TypeParam(0.701244f), TypeParam(0.633093f), TypeParam(-0.644096f), TypeParam(0.823778f), TypeParam(0.809482f), TypeParam(0.158802f), TypeParam(-1.044029f), TypeParam(-0.735991f), TypeParam(0.334411f), TypeParam(0.414891f), TypeParam(1.118940f), TypeParam(0.610743f), TypeParam(0.434932f), TypeParam(-0.040928f)}; std::initializer_list Y_shape{2, 2, 3, 3, 2}; - std::initializer_list Y_data{0.222880f, -0.137918f, 0.042779f, 0.027606f, 0.146833f, 0.119531f, 0.062001f, 0.077615f, -0.124874f, -0.020856f, 0.248748f, -0.050235f, -0.185885f, -0.124030f, -0.148987f, -0.345107f, 0.753440f, -0.055873f, 0.674388f, 0.063018f, -0.054480f, -0.034452f, 0.780917f, 0.193151f, -0.140647f, -0.047364f, -0.095816f, -0.046983f, 0.254384f, -0.123703f, 0.191358f, 0.674903f, -0.311971f, 1.032054f, 0.672506f, 0.009147f, 0.281933f, 0.135835f, -0.145082f, -0.392560f, -0.229593f, -0.632284f, -0.936929f, -0.916689f, -0.502247f, -0.108609f, -0.645451f, 0.242939f, -0.165902f, -1.220095f, -0.015084f, -0.300940f, -0.352557f, -0.886474f, 0.109150f, 0.398365f, 0.235757f, 0.358618f, 0.082189f, 0.268617f, 0.077955f, -0.157573f, 0.023048f, -0.346908f, 0.360128f, 0.389098f, 0.122882f, 0.675956f, 0.735857f, 0.354858f, 0.244544f, 0.631102f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(0.222880f), TypeParam(-0.137918f), TypeParam(0.042779f), TypeParam(0.027606f), TypeParam(0.146833f), TypeParam(0.119531f), TypeParam(0.062001f), TypeParam(0.077615f), TypeParam(-0.124874f), TypeParam(-0.020856f), TypeParam(0.248748f), TypeParam(-0.050235f), TypeParam(-0.185885f), TypeParam(-0.124030f), TypeParam(-0.148987f), TypeParam(-0.345107f), TypeParam(0.753440f), TypeParam(-0.055873f), TypeParam(0.674388f), TypeParam(0.063018f), TypeParam(-0.054480f), TypeParam(-0.034452f), TypeParam(0.780917f), TypeParam(0.193151f), TypeParam(-0.140647f), TypeParam(-0.047364f), TypeParam(-0.095816f), TypeParam(-0.046983f), TypeParam(0.254384f), TypeParam(-0.123703f), TypeParam(0.191358f), TypeParam(0.674903f), TypeParam(-0.311971f), TypeParam(1.032054f), TypeParam(0.672506f), TypeParam(0.009147f), TypeParam(0.281933f), TypeParam(0.135835f), TypeParam(-0.145082f), TypeParam(-0.392560f), TypeParam(-0.229593f), TypeParam(-0.632284f), TypeParam(-0.936929f), TypeParam(-0.916689f), TypeParam(-0.502247f), TypeParam(-0.108609f), TypeParam(-0.645451f), TypeParam(0.242939f), TypeParam(-0.165902f), TypeParam(-1.220095f), TypeParam(-0.015084f), TypeParam(-0.300940f), TypeParam(-0.352557f), TypeParam(-0.886474f), TypeParam(0.109150f), TypeParam(0.398365f), TypeParam(0.235757f), TypeParam(0.358618f), TypeParam(0.082189f), TypeParam(0.268617f), TypeParam(0.077955f), TypeParam(-0.157573f), TypeParam(0.023048f), TypeParam(-0.346908f), TypeParam(0.360128f), TypeParam(0.389098f), TypeParam(0.122882f), TypeParam(0.675956f), TypeParam(0.735857f), TypeParam(0.354858f), TypeParam(0.244544f), TypeParam(0.631102f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_border_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_border_align_corners) { OpTester test("GridSample", 20); std::string mode = "linear"; std::string padding_mode = "border"; int64_t align_corners = 1; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{-1.916003f, 0.150784f, -0.179898f, 0.402727f, -0.549764f, 1.772484f, 1.014343f, 0.502823f, 0.976771f, -0.071957f, 0.519875f, 0.408665f, 1.435640f, -0.807775f, -0.181661f, -0.574026f, -0.335351f, -0.155602f, 0.348749f, 1.055618f, 0.737784f, -0.394725f, 0.597608f, 0.006105f}; + std::initializer_list X_data{TypeParam(-1.916003f), TypeParam(0.150784f), TypeParam(-0.179898f), TypeParam(0.402727f), TypeParam(-0.549764f), TypeParam(1.772484f), TypeParam(1.014343f), TypeParam(0.502823f), TypeParam(0.976771f), TypeParam(-0.071957f), TypeParam(0.519875f), TypeParam(0.408665f), TypeParam(1.435640f), TypeParam(-0.807775f), TypeParam(-0.181661f), TypeParam(-0.574026f), TypeParam(-0.335351f), TypeParam(-0.155602f), TypeParam(0.348749f), TypeParam(1.055618f), TypeParam(0.737784f), TypeParam(-0.394725f), TypeParam(0.597608f), TypeParam(0.006105f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{-0.189838f, -1.050410f, -1.072351f, -0.930754f, -0.502573f, 0.186642f, -0.564332f, -0.042774f, -0.143740f, 1.097448f, -0.547044f, 1.127440f, -0.921224f, -1.001202f, 0.390232f, -0.698394f, 0.615509f, -0.663897f, 0.944958f, 1.161950f, 0.076823f, 0.256464f, 1.118784f, 0.711380f}; + std::initializer_list Grid_data{TypeParam(-0.189838f), TypeParam(-1.050410f), TypeParam(-1.072351f), TypeParam(-0.930754f), TypeParam(-0.502573f), TypeParam(0.186642f), TypeParam(-0.564332f), TypeParam(-0.042774f), TypeParam(-0.143740f), TypeParam(1.097448f), TypeParam(-0.547044f), TypeParam(1.127440f), TypeParam(-0.921224f), TypeParam(-1.001202f), TypeParam(0.390232f), TypeParam(-0.698394f), TypeParam(0.615509f), TypeParam(-0.663897f), TypeParam(0.944958f), TypeParam(1.161950f), TypeParam(0.076823f), TypeParam(0.256464f), TypeParam(1.118784f), TypeParam(0.711380f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{-1.078787f, -1.795786f, -0.023270f, -0.113413f, 0.444460f, -0.023826f, 0.807136f, 1.011742f, 0.674182f, 0.754935f, 0.472262f, 0.494688f, 1.347277f, -0.223507f, -0.417529f, -0.160549f, -0.353331f, -0.276367f, 0.376591f, 0.571813f, 0.551111f, 0.022384f, 0.166782f, -0.109583f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-1.078787f), TypeParam(-1.795786f), TypeParam(-0.023270f), TypeParam(-0.113413f), TypeParam(0.444460f), TypeParam(-0.023826f), TypeParam(0.807136f), TypeParam(1.011742f), TypeParam(0.674182f), TypeParam(0.754935f), TypeParam(0.472262f), TypeParam(0.494688f), TypeParam(1.347277f), TypeParam(-0.223507f), TypeParam(-0.417529f), TypeParam(-0.160549f), TypeParam(-0.353331f), TypeParam(-0.276367f), TypeParam(0.376591f), TypeParam(0.571813f), TypeParam(0.551111f), TypeParam(0.022384f), TypeParam(0.166782f), TypeParam(-0.109583f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_5D_bilinear_border_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_bilinear_border_align_corners) { OpTester test("GridSample", 20); std::string mode = "linear"; std::string padding_mode = "border"; int64_t align_corners = 1; std::initializer_list X_shape{2, 2, 3, 3, 2}; - std::initializer_list X_data{-0.332555f, 0.980958f, 0.002632f, -1.976749f, 0.979548f, 1.109773f, -0.534887f, 0.705692f, -0.143637f, -0.600830f, 0.315853f, -0.604687f, -0.300652f, -0.375240f, 0.377196f, -0.140920f, 1.159946f, 2.364598f, 0.320719f, 0.397938f, -0.680097f, -1.201632f, 0.270077f, -0.036712f, -0.972864f, 0.792393f, -1.159168f, -0.016679f, -0.665027f, 0.809646f, -1.684452f, 0.049476f, 0.065748f, 0.279619f, -1.079668f, 0.301309f, 1.010100f, -0.119015f, -0.104838f, 0.916627f, -0.522838f, 0.485269f, -1.221088f, 2.044754f, -0.669823f, 0.128370f, 0.080480f, 0.372679f, -0.046427f, -0.732652f, -0.395790f, 0.012594f, -0.170518f, -0.706783f, -0.862588f, -1.177275f, -1.165262f, 0.914826f, -0.661128f, -0.386656f, -0.599246f, 0.544643f, 0.930679f, -1.146137f, 0.212913f, -0.022433f, 1.692830f, 0.187511f, -0.631569f, -0.311540f, -0.885167f, -0.429959f}; + std::initializer_list X_data{TypeParam(-0.332555f), TypeParam(0.980958f), TypeParam(0.002632f), TypeParam(-1.976749f), TypeParam(0.979548f), TypeParam(1.109773f), TypeParam(-0.534887f), TypeParam(0.705692f), TypeParam(-0.143637f), TypeParam(-0.600830f), TypeParam(0.315853f), TypeParam(-0.604687f), TypeParam(-0.300652f), TypeParam(-0.375240f), TypeParam(0.377196f), TypeParam(-0.140920f), TypeParam(1.159946f), TypeParam(2.364598f), TypeParam(0.320719f), TypeParam(0.397938f), TypeParam(-0.680097f), TypeParam(-1.201632f), TypeParam(0.270077f), TypeParam(-0.036712f), TypeParam(-0.972864f), TypeParam(0.792393f), TypeParam(-1.159168f), TypeParam(-0.016679f), TypeParam(-0.665027f), TypeParam(0.809646f), TypeParam(-1.684452f), TypeParam(0.049476f), TypeParam(0.065748f), TypeParam(0.279619f), TypeParam(-1.079668f), TypeParam(0.301309f), TypeParam(1.010100f), TypeParam(-0.119015f), TypeParam(-0.104838f), TypeParam(0.916627f), TypeParam(-0.522838f), TypeParam(0.485269f), TypeParam(-1.221088f), TypeParam(2.044754f), TypeParam(-0.669823f), TypeParam(0.128370f), TypeParam(0.080480f), TypeParam(0.372679f), TypeParam(-0.046427f), TypeParam(-0.732652f), TypeParam(-0.395790f), TypeParam(0.012594f), TypeParam(-0.170518f), TypeParam(-0.706783f), TypeParam(-0.862588f), TypeParam(-1.177275f), TypeParam(-1.165262f), TypeParam(0.914826f), TypeParam(-0.661128f), TypeParam(-0.386656f), TypeParam(-0.599246f), TypeParam(0.544643f), TypeParam(0.930679f), TypeParam(-1.146137f), TypeParam(0.212913f), TypeParam(-0.022433f), TypeParam(1.692830f), TypeParam(0.187511f), TypeParam(-0.631569f), TypeParam(-0.311540f), TypeParam(-0.885167f), TypeParam(-0.429959f)}; std::initializer_list Grid_shape{2, 3, 3, 2, 3}; - std::initializer_list Grid_data{-0.453992f, 0.394222f, 0.755023f, -0.025610f, 0.658840f, 0.982105f, -0.642922f, -0.265292f, -1.080379f, 0.275464f, 0.855228f, -0.233029f, 0.191483f, 0.383441f, -0.025595f, 0.932929f, 0.174866f, -1.179535f, -0.990943f, -1.188918f, 0.049460f, 0.648682f, -0.158317f, 1.078936f, -0.215883f, 0.245340f, 1.082089f, 0.607310f, -0.038283f, 1.155868f, -0.716957f, 0.446971f, 0.757844f, -0.743030f, -1.127212f, 0.383835f, -0.455267f, -0.605570f, 0.238686f, -0.870514f, 1.079285f, -0.107719f, -0.384303f, 1.003178f, 0.334130f, 0.228627f, -0.573757f, 1.143690f, -0.365482f, 0.998076f, -0.088210f, 0.601965f, 0.843747f, -0.893403f, -0.799804f, -1.186625f, 0.865515f, 1.031983f, -0.438564f, -0.587735f, 0.200868f, 0.646055f, 0.296203f, -0.250092f, -0.763290f, 1.026321f, -0.777136f, -1.159559f, -0.479127f, 0.239290f, 0.446029f, 0.464001f, -0.695158f, -0.460548f, -0.533616f, -0.581111f, -1.010728f, 0.245640f, -0.348981f, -1.155007f, -0.700701f, -0.720655f, -0.517635f, -0.741485f, -0.208103f, 0.430035f, -0.971177f, -0.102798f, -0.345348f, -0.613510f, -0.266458f, -0.508597f, 0.038577f, -0.866220f, 0.227567f, 1.101759f, 0.994334f, -0.538031f, 0.369874f, -1.134245f, 1.010332f, -1.195878f, -1.072351f, -1.077155f, -1.114385f, 0.162516f, -0.317319f, 0.287217f}; + std::initializer_list Grid_data{TypeParam(-0.453992f), TypeParam(0.394222f), TypeParam(0.755023f), TypeParam(-0.025610f), TypeParam(0.658840f), TypeParam(0.982105f), TypeParam(-0.642922f), TypeParam(-0.265292f), TypeParam(-1.080379f), TypeParam(0.275464f), TypeParam(0.855228f), TypeParam(-0.233029f), TypeParam(0.191483f), TypeParam(0.383441f), TypeParam(-0.025595f), TypeParam(0.932929f), TypeParam(0.174866f), TypeParam(-1.179535f), TypeParam(-0.990943f), TypeParam(-1.188918f), TypeParam(0.049460f), TypeParam(0.648682f), TypeParam(-0.158317f), TypeParam(1.078936f), TypeParam(-0.215883f), TypeParam(0.245340f), TypeParam(1.082089f), TypeParam(0.607310f), TypeParam(-0.038283f), TypeParam(1.155868f), TypeParam(-0.716957f), TypeParam(0.446971f), TypeParam(0.757844f), TypeParam(-0.743030f), TypeParam(-1.127212f), TypeParam(0.383835f), TypeParam(-0.455267f), TypeParam(-0.605570f), TypeParam(0.238686f), TypeParam(-0.870514f), TypeParam(1.079285f), TypeParam(-0.107719f), TypeParam(-0.384303f), TypeParam(1.003178f), TypeParam(0.334130f), TypeParam(0.228627f), TypeParam(-0.573757f), TypeParam(1.143690f), TypeParam(-0.365482f), TypeParam(0.998076f), TypeParam(-0.088210f), TypeParam(0.601965f), TypeParam(0.843747f), TypeParam(-0.893403f), TypeParam(-0.799804f), TypeParam(-1.186625f), TypeParam(0.865515f), TypeParam(1.031983f), TypeParam(-0.438564f), TypeParam(-0.587735f), TypeParam(0.200868f), TypeParam(0.646055f), TypeParam(0.296203f), TypeParam(-0.250092f), TypeParam(-0.763290f), TypeParam(1.026321f), TypeParam(-0.777136f), TypeParam(-1.159559f), TypeParam(-0.479127f), TypeParam(0.239290f), TypeParam(0.446029f), TypeParam(0.464001f), TypeParam(-0.695158f), TypeParam(-0.460548f), TypeParam(-0.533616f), TypeParam(-0.581111f), TypeParam(-1.010728f), TypeParam(0.245640f), TypeParam(-0.348981f), TypeParam(-1.155007f), TypeParam(-0.700701f), TypeParam(-0.720655f), TypeParam(-0.517635f), TypeParam(-0.741485f), TypeParam(-0.208103f), TypeParam(0.430035f), TypeParam(-0.971177f), TypeParam(-0.102798f), TypeParam(-0.345348f), TypeParam(-0.613510f), TypeParam(-0.266458f), TypeParam(-0.508597f), TypeParam(0.038577f), TypeParam(-0.866220f), TypeParam(0.227567f), TypeParam(1.101759f), TypeParam(0.994334f), TypeParam(-0.538031f), TypeParam(0.369874f), TypeParam(-1.134245f), TypeParam(1.010332f), TypeParam(-1.195878f), TypeParam(-1.072351f), TypeParam(-1.077155f), TypeParam(-1.114385f), TypeParam(0.162516f), TypeParam(-0.317319f), TypeParam(0.287217f)}; std::initializer_list Y_shape{2, 2, 3, 3, 2}; - std::initializer_list Y_data{0.517362f, 1.168304f, -0.283719f, -0.056944f, -0.345007f, -1.383013f, -0.517978f, -0.099340f, 0.531814f, -0.051495f, 0.570203f, -0.350444f, -0.195512f, 0.335075f, 0.533103f, -0.173681f, 0.110927f, 0.549661f, -0.303447f, -0.209369f, -0.479343f, 0.113517f, -0.222508f, -0.981697f, -1.000072f, 0.163343f, -0.019158f, 0.217390f, -0.442252f, -1.020732f, -0.645033f, -0.481248f, -0.359233f, -0.271288f, -0.165768f, -0.092544f, -0.219889f, 0.671201f, -0.041137f, -0.289275f, -0.022793f, -0.130253f, -0.072692f, -0.451858f, 0.402947f, 0.168711f, 0.110811f, 0.202315f, -0.200036f, -0.331588f, 0.583341f, -0.522838f, 1.010100f, -0.018650f, 1.269564f, -0.168394f, -0.209390f, 0.740205f, -0.675828f, -0.325915f, -0.404694f, 0.067064f, -0.744102f, -0.639736f, -0.416580f, -0.317643f, 0.004590f, -0.665815f, -0.163600f, -0.661128f, -0.862588f, -0.132515f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(0.517362f), TypeParam(1.168304f), TypeParam(-0.283719f), TypeParam(-0.056944f), TypeParam(-0.345007f), TypeParam(-1.383013f), TypeParam(-0.517978f), TypeParam(-0.099340f), TypeParam(0.531814f), TypeParam(-0.051495f), TypeParam(0.570203f), TypeParam(-0.350444f), TypeParam(-0.195512f), TypeParam(0.335075f), TypeParam(0.533103f), TypeParam(-0.173681f), TypeParam(0.110927f), TypeParam(0.549661f), TypeParam(-0.303447f), TypeParam(-0.209369f), TypeParam(-0.479343f), TypeParam(0.113517f), TypeParam(-0.222508f), TypeParam(-0.981697f), TypeParam(-1.000072f), TypeParam(0.163343f), TypeParam(-0.019158f), TypeParam(0.217390f), TypeParam(-0.442252f), TypeParam(-1.020732f), TypeParam(-0.645033f), TypeParam(-0.481248f), TypeParam(-0.359233f), TypeParam(-0.271288f), TypeParam(-0.165768f), TypeParam(-0.092544f), TypeParam(-0.219889f), TypeParam(0.671201f), TypeParam(-0.041137f), TypeParam(-0.289275f), TypeParam(-0.022793f), TypeParam(-0.130253f), TypeParam(-0.072692f), TypeParam(-0.451858f), TypeParam(0.402947f), TypeParam(0.168711f), TypeParam(0.110811f), TypeParam(0.202315f), TypeParam(-0.200036f), TypeParam(-0.331588f), TypeParam(0.583341f), TypeParam(-0.522838f), TypeParam(1.010100f), TypeParam(-0.018650f), TypeParam(1.269564f), TypeParam(-0.168394f), TypeParam(-0.209390f), TypeParam(0.740205f), TypeParam(-0.675828f), TypeParam(-0.325915f), TypeParam(-0.404694f), TypeParam(0.067064f), TypeParam(-0.744102f), TypeParam(-0.639736f), TypeParam(-0.416580f), TypeParam(-0.317643f), TypeParam(0.004590f), TypeParam(-0.665815f), TypeParam(-0.163600f), TypeParam(-0.661128f), TypeParam(-0.862588f), TypeParam(-0.132515f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_border_no_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_border_no_align_corners) { OpTester test("GridSample", 20); std::string mode = "linear"; std::string padding_mode = "border"; int64_t align_corners = 0; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{-0.050553f, -0.825690f, -0.616085f, 0.337113f, 0.370334f, -0.105073f, -0.565382f, 0.396842f, -0.373193f, -0.780451f, -1.932970f, 1.104960f, -2.569945f, 0.661190f, -0.192302f, 0.734279f, 0.351872f, -1.068136f, 0.173665f, -0.778153f, -0.981877f, 1.485344f, 0.431733f, 0.428167f}; + std::initializer_list X_data{TypeParam(-0.050553f), TypeParam(-0.825690f), TypeParam(-0.616085f), TypeParam(0.337113f), TypeParam(0.370334f), TypeParam(-0.105073f), TypeParam(-0.565382f), TypeParam(0.396842f), TypeParam(-0.373193f), TypeParam(-0.780451f), TypeParam(-1.932970f), TypeParam(1.104960f), TypeParam(-2.569945f), TypeParam(0.661190f), TypeParam(-0.192302f), TypeParam(0.734279f), TypeParam(0.351872f), TypeParam(-1.068136f), TypeParam(0.173665f), TypeParam(-0.778153f), TypeParam(-0.981877f), TypeParam(1.485344f), TypeParam(0.431733f), TypeParam(0.428167f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{-0.330875f, 0.589988f, 0.011588f, -1.144325f, -1.038357f, 0.435055f, -1.053243f, -0.957144f, -0.715458f, 1.143742f, -0.341215f, -0.494762f, -0.810255f, 0.767649f, -0.193763f, 0.231402f, 0.286668f, 0.338432f, 0.768106f, 0.062272f, 0.124125f, -0.077928f, -0.932481f, -0.274618f}; + std::initializer_list Grid_data{TypeParam(-0.330875f), TypeParam(0.589988f), TypeParam(0.011588f), TypeParam(-1.144325f), TypeParam(-1.038357f), TypeParam(0.435055f), TypeParam(-1.053243f), TypeParam(-0.957144f), TypeParam(-0.715458f), TypeParam(1.143742f), TypeParam(-0.341215f), TypeParam(-0.494762f), TypeParam(-0.810255f), TypeParam(0.767649f), TypeParam(-0.193763f), TypeParam(0.231402f), TypeParam(0.286668f), TypeParam(0.338432f), TypeParam(0.768106f), TypeParam(0.062272f), TypeParam(0.124125f), TypeParam(-0.077928f), TypeParam(-0.932481f), TypeParam(-0.274618f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{0.204265f, -0.447104f, 0.027635f, -0.050553f, 0.370334f, -0.248695f, -1.306797f, -0.073120f, -1.391077f, -0.565382f, -1.932970f, -0.419110f, 0.351872f, 0.030903f, -0.124253f, 0.565919f, 0.276202f, -1.171718f, 0.431733f, 0.001712f, 0.689913f, 1.386595f, 0.443614f, -0.505878f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(0.204265f), TypeParam(-0.447104f), TypeParam(0.027635f), TypeParam(-0.050553f), TypeParam(0.370334f), TypeParam(-0.248695f), TypeParam(-1.306797f), TypeParam(-0.073120f), TypeParam(-1.391077f), TypeParam(-0.565382f), TypeParam(-1.932970f), TypeParam(-0.419110f), TypeParam(0.351872f), TypeParam(0.030903f), TypeParam(-0.124253f), TypeParam(0.565919f), TypeParam(0.276202f), TypeParam(-1.171718f), TypeParam(0.431733f), TypeParam(0.001712f), TypeParam(0.689913f), TypeParam(1.386595f), TypeParam(0.443614f), TypeParam(-0.505878f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_5D_bilinear_border_no_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_bilinear_border_no_align_corners) { OpTester test("GridSample", 20); std::string mode = "linear"; std::string padding_mode = "border"; int64_t align_corners = 0; std::initializer_list X_shape{2, 2, 3, 3, 2}; - std::initializer_list X_data{-0.727099f, 0.057663f, -0.548384f, 0.078163f, -0.133679f, 0.211872f, 0.271687f, -1.221973f, -2.630687f, -0.558102f, -0.327183f, 0.039894f, 1.222102f, 0.144418f, 0.696676f, -2.231791f, 0.910544f, 2.749837f, -0.354036f, -0.106102f, 2.453576f, 0.332319f, -1.743712f, 1.416859f, 0.260041f, -1.179930f, 0.407328f, 0.375476f, 2.028488f, 0.174825f, -1.467126f, 0.079045f, 0.870076f, -0.895165f, 0.631429f, 0.358222f, 1.484120f, -0.622331f, 0.727481f, 0.644213f, 1.299103f, -0.378573f, 1.360908f, 0.905514f, 0.180065f, 0.972162f, 1.246238f, -0.537204f, -1.241497f, -0.772822f, -0.149044f, -1.642060f, 0.120091f, 0.937023f, 0.422106f, 0.652040f, 0.045585f, -1.089530f, 0.356099f, 0.536075f, -1.840257f, -1.035736f, 0.348653f, 0.187942f, 0.150011f, 0.521798f, 1.271739f, 0.977495f, 0.811927f, 0.641729f, 0.964401f, -0.693074f}; + std::initializer_list X_data{TypeParam(-0.727099f), TypeParam(0.057663f), TypeParam(-0.548384f), TypeParam(0.078163f), TypeParam(-0.133679f), TypeParam(0.211872f), TypeParam(0.271687f), TypeParam(-1.221973f), TypeParam(-2.630687f), TypeParam(-0.558102f), TypeParam(-0.327183f), TypeParam(0.039894f), TypeParam(1.222102f), TypeParam(0.144418f), TypeParam(0.696676f), TypeParam(-2.231791f), TypeParam(0.910544f), TypeParam(2.749837f), TypeParam(-0.354036f), TypeParam(-0.106102f), TypeParam(2.453576f), TypeParam(0.332319f), TypeParam(-1.743712f), TypeParam(1.416859f), TypeParam(0.260041f), TypeParam(-1.179930f), TypeParam(0.407328f), TypeParam(0.375476f), TypeParam(2.028488f), TypeParam(0.174825f), TypeParam(-1.467126f), TypeParam(0.079045f), TypeParam(0.870076f), TypeParam(-0.895165f), TypeParam(0.631429f), TypeParam(0.358222f), TypeParam(1.484120f), TypeParam(-0.622331f), TypeParam(0.727481f), TypeParam(0.644213f), TypeParam(1.299103f), TypeParam(-0.378573f), TypeParam(1.360908f), TypeParam(0.905514f), TypeParam(0.180065f), TypeParam(0.972162f), TypeParam(1.246238f), TypeParam(-0.537204f), TypeParam(-1.241497f), TypeParam(-0.772822f), TypeParam(-0.149044f), TypeParam(-1.642060f), TypeParam(0.120091f), TypeParam(0.937023f), TypeParam(0.422106f), TypeParam(0.652040f), TypeParam(0.045585f), TypeParam(-1.089530f), TypeParam(0.356099f), TypeParam(0.536075f), TypeParam(-1.840257f), TypeParam(-1.035736f), TypeParam(0.348653f), TypeParam(0.187942f), TypeParam(0.150011f), TypeParam(0.521798f), TypeParam(1.271739f), TypeParam(0.977495f), TypeParam(0.811927f), TypeParam(0.641729f), TypeParam(0.964401f), TypeParam(-0.693074f)}; std::initializer_list Grid_shape{2, 3, 3, 2, 3}; - std::initializer_list Grid_data{1.017692f, -0.818194f, 0.525611f, -0.556812f, -0.124601f, 1.120205f, 0.153552f, -1.144168f, 1.103147f, -0.050771f, -0.600881f, -0.633732f, 1.029039f, 0.020253f, 0.662802f, 0.788674f, -0.465758f, 0.101853f, -0.776226f, 1.002064f, -0.634553f, 0.797064f, 0.304043f, 0.740241f, -0.845484f, -0.037319f, 0.621792f, -0.047898f, -0.017218f, 0.584766f, -0.896882f, -0.240587f, 0.546590f, 0.588539f, 1.114539f, -0.237379f, 0.284327f, -0.590432f, -0.201402f, -0.602420f, 0.889284f, 0.007310f, 0.488176f, 0.660055f, 0.223618f, 0.127703f, -0.087830f, -1.016490f, 0.193341f, -0.265853f, -1.008634f, 1.118021f, -0.127930f, -0.598904f, -1.168221f, -1.105256f, 0.456964f, -0.547805f, -0.518368f, -0.694346f, 0.968648f, -0.288466f, 0.777819f, 0.952657f, -0.930362f, 0.895254f, -0.229149f, 1.149323f, 0.612939f, -1.162419f, 0.222934f, 0.421831f, -0.435327f, 0.909973f, -0.993750f, -0.380767f, 1.143396f, 1.171977f, 0.599451f, -0.716336f, -1.032482f, -0.975683f, -0.299985f, 0.679795f, 0.379920f, -0.145729f, 1.079221f, 0.942322f, -0.560859f, -0.519668f, -0.014079f, 0.249021f, -0.008590f, 0.463277f, 0.827937f, -0.216375f, 0.589310f, 0.163207f, 0.460623f, 0.494016f, -0.320739f, -0.535032f, 0.512922f, -0.768302f, 0.630003f, -0.769945f, 0.823242f, 0.481487f}; + std::initializer_list Grid_data{TypeParam(1.017692f), TypeParam(-0.818194f), TypeParam(0.525611f), TypeParam(-0.556812f), TypeParam(-0.124601f), TypeParam(1.120205f), TypeParam(0.153552f), TypeParam(-1.144168f), TypeParam(1.103147f), TypeParam(-0.050771f), TypeParam(-0.600881f), TypeParam(-0.633732f), TypeParam(1.029039f), TypeParam(0.020253f), TypeParam(0.662802f), TypeParam(0.788674f), TypeParam(-0.465758f), TypeParam(0.101853f), TypeParam(-0.776226f), TypeParam(1.002064f), TypeParam(-0.634553f), TypeParam(0.797064f), TypeParam(0.304043f), TypeParam(0.740241f), TypeParam(-0.845484f), TypeParam(-0.037319f), TypeParam(0.621792f), TypeParam(-0.047898f), TypeParam(-0.017218f), TypeParam(0.584766f), TypeParam(-0.896882f), TypeParam(-0.240587f), TypeParam(0.546590f), TypeParam(0.588539f), TypeParam(1.114539f), TypeParam(-0.237379f), TypeParam(0.284327f), TypeParam(-0.590432f), TypeParam(-0.201402f), TypeParam(-0.602420f), TypeParam(0.889284f), TypeParam(0.007310f), TypeParam(0.488176f), TypeParam(0.660055f), TypeParam(0.223618f), TypeParam(0.127703f), TypeParam(-0.087830f), TypeParam(-1.016490f), TypeParam(0.193341f), TypeParam(-0.265853f), TypeParam(-1.008634f), TypeParam(1.118021f), TypeParam(-0.127930f), TypeParam(-0.598904f), TypeParam(-1.168221f), TypeParam(-1.105256f), TypeParam(0.456964f), TypeParam(-0.547805f), TypeParam(-0.518368f), TypeParam(-0.694346f), TypeParam(0.968648f), TypeParam(-0.288466f), TypeParam(0.777819f), TypeParam(0.952657f), TypeParam(-0.930362f), TypeParam(0.895254f), TypeParam(-0.229149f), TypeParam(1.149323f), TypeParam(0.612939f), TypeParam(-1.162419f), TypeParam(0.222934f), TypeParam(0.421831f), TypeParam(-0.435327f), TypeParam(0.909973f), TypeParam(-0.993750f), TypeParam(-0.380767f), TypeParam(1.143396f), TypeParam(1.171977f), TypeParam(0.599451f), TypeParam(-0.716336f), TypeParam(-1.032482f), TypeParam(-0.975683f), TypeParam(-0.299985f), TypeParam(0.679795f), TypeParam(0.379920f), TypeParam(-0.145729f), TypeParam(1.079221f), TypeParam(0.942322f), TypeParam(-0.560859f), TypeParam(-0.519668f), TypeParam(-0.014079f), TypeParam(0.249021f), TypeParam(-0.008590f), TypeParam(0.463277f), TypeParam(0.827937f), TypeParam(-0.216375f), TypeParam(0.589310f), TypeParam(0.163207f), TypeParam(0.460623f), TypeParam(0.494016f), TypeParam(-0.320739f), TypeParam(-0.535032f), TypeParam(0.512922f), TypeParam(-0.768302f), TypeParam(0.630003f), TypeParam(-0.769945f), TypeParam(0.823242f), TypeParam(0.481487f)}; std::initializer_list Y_shape{2, 2, 3, 3, 2}; - std::initializer_list Y_data{-0.144687f, 0.794879f, 0.517780f, -0.372025f, -2.071523f, -0.953122f, -0.143000f, 0.040151f, 0.511071f, -0.723342f, 0.441486f, 0.101130f, -0.668215f, -0.313612f, 0.918245f, -0.165560f, -0.141496f, -0.002992f, -0.187333f, 0.433250f, -0.456623f, -0.082449f, -0.849978f, -0.635311f, -1.562003f, -0.323540f, 0.716348f, 0.089914f, 0.085623f, 0.617075f, -0.522245f, 2.013170f, 0.249061f, 0.948093f, 0.518262f, 0.230788f, -0.422900f, 1.315807f, -1.265941f, -0.772822f, 0.375354f, 0.159706f, 1.190603f, 0.217497f, -0.622331f, -0.640623f, -1.324261f, -0.126419f, 0.497220f, -0.421485f, -0.512049f, 0.218454f, -0.680520f, 0.432900f, 0.292848f, 0.338349f, 0.787015f, 0.977495f, 0.494135f, 0.649655f, 0.367739f, 0.766775f, 0.652040f, 1.018832f, 0.738819f, 0.107251f, 0.287288f, 0.515065f, 0.300961f, -0.279154f, 0.866776f, 0.738188f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-0.144687f), TypeParam(0.794879f), TypeParam(0.517780f), TypeParam(-0.372025f), TypeParam(-2.071523f), TypeParam(-0.953122f), TypeParam(-0.143000f), TypeParam(0.040151f), TypeParam(0.511071f), TypeParam(-0.723342f), TypeParam(0.441486f), TypeParam(0.101130f), TypeParam(-0.668215f), TypeParam(-0.313612f), TypeParam(0.918245f), TypeParam(-0.165560f), TypeParam(-0.141496f), TypeParam(-0.002992f), TypeParam(-0.187333f), TypeParam(0.433250f), TypeParam(-0.456623f), TypeParam(-0.082449f), TypeParam(-0.849978f), TypeParam(-0.635311f), TypeParam(-1.562003f), TypeParam(-0.323540f), TypeParam(0.716348f), TypeParam(0.089914f), TypeParam(0.085623f), TypeParam(0.617075f), TypeParam(-0.522245f), TypeParam(2.013170f), TypeParam(0.249061f), TypeParam(0.948093f), TypeParam(0.518262f), TypeParam(0.230788f), TypeParam(-0.422900f), TypeParam(1.315807f), TypeParam(-1.265941f), TypeParam(-0.772822f), TypeParam(0.375354f), TypeParam(0.159706f), TypeParam(1.190603f), TypeParam(0.217497f), TypeParam(-0.622331f), TypeParam(-0.640623f), TypeParam(-1.324261f), TypeParam(-0.126419f), TypeParam(0.497220f), TypeParam(-0.421485f), TypeParam(-0.512049f), TypeParam(0.218454f), TypeParam(-0.680520f), TypeParam(0.432900f), TypeParam(0.292848f), TypeParam(0.338349f), TypeParam(0.787015f), TypeParam(0.977495f), TypeParam(0.494135f), TypeParam(0.649655f), TypeParam(0.367739f), TypeParam(0.766775f), TypeParam(0.652040f), TypeParam(1.018832f), TypeParam(0.738819f), TypeParam(0.107251f), TypeParam(0.287288f), TypeParam(0.515065f), TypeParam(0.300961f), TypeParam(-0.279154f), TypeParam(0.866776f), TypeParam(0.738188f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_reflection_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_reflection_align_corners) { OpTester test("GridSample", 20); std::string mode = "linear"; std::string padding_mode = "reflection"; int64_t align_corners = 1; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{-0.599439f, 0.317612f, -0.294302f, -0.530613f, 0.754687f, 0.092241f, -1.009405f, -1.155944f, 0.336327f, 0.159353f, -1.134330f, 0.510271f, 0.271972f, 1.301884f, 1.027400f, 1.193876f, 0.304363f, 1.027256f, 0.186801f, 0.719412f, -0.310900f, -1.123812f, -0.312771f, 2.729156f}; + std::initializer_list X_data{TypeParam(-0.599439f), TypeParam(0.317612f), TypeParam(-0.294302f), TypeParam(-0.530613f), TypeParam(0.754687f), TypeParam(0.092241f), TypeParam(-1.009405f), TypeParam(-1.155944f), TypeParam(0.336327f), TypeParam(0.159353f), TypeParam(-1.134330f), TypeParam(0.510271f), TypeParam(0.271972f), TypeParam(1.301884f), TypeParam(1.027400f), TypeParam(1.193876f), TypeParam(0.304363f), TypeParam(1.027256f), TypeParam(0.186801f), TypeParam(0.719412f), TypeParam(-0.310900f), TypeParam(-1.123812f), TypeParam(-0.312771f), TypeParam(2.729156f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{0.853801f, 0.833200f, -0.477474f, 0.131677f, 0.571825f, 0.858708f, -1.120796f, 1.194690f, -0.301706f, 0.488934f, -0.745307f, -0.923452f, -0.812682f, 0.707226f, -0.591920f, 0.697573f, 0.362777f, 0.477332f, -0.266909f, -0.379588f, -0.561456f, -0.670762f, 1.106438f, -0.065215f}; + std::initializer_list Grid_data{TypeParam(0.853801f), TypeParam(0.833200f), TypeParam(-0.477474f), TypeParam(0.131677f), TypeParam(0.571825f), TypeParam(0.858708f), TypeParam(-1.120796f), TypeParam(1.194690f), TypeParam(-0.301706f), TypeParam(0.488934f), TypeParam(-0.745307f), TypeParam(-0.923452f), TypeParam(-0.812682f), TypeParam(0.707226f), TypeParam(-0.591920f), TypeParam(0.697573f), TypeParam(0.362777f), TypeParam(0.477332f), TypeParam(-0.266909f), TypeParam(-0.379588f), TypeParam(-0.561456f), TypeParam(-0.670762f), TypeParam(1.106438f), TypeParam(-0.065215f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{0.031577f, -0.232574f, 0.133168f, 0.515460f, 0.063332f, -0.470541f, 0.353729f, 0.159106f, 0.163701f, -0.770097f, -0.133556f, -0.925350f, 0.568498f, 0.636194f, 0.976680f, 0.921805f, 0.684184f, 1.189063f, -0.133022f, 0.070598f, 0.388079f, -0.232737f, 0.042589f, -0.965013f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(0.031577f), TypeParam(-0.232574f), TypeParam(0.133168f), TypeParam(0.515460f), TypeParam(0.063332f), TypeParam(-0.470541f), TypeParam(0.353729f), TypeParam(0.159106f), TypeParam(0.163701f), TypeParam(-0.770097f), TypeParam(-0.133556f), TypeParam(-0.925350f), TypeParam(0.568498f), TypeParam(0.636194f), TypeParam(0.976680f), TypeParam(0.921805f), TypeParam(0.684184f), TypeParam(1.189063f), TypeParam(-0.133022f), TypeParam(0.070598f), TypeParam(0.388079f), TypeParam(-0.232737f), TypeParam(0.042589f), TypeParam(-0.965013f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_5D_bilinear_reflection_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_bilinear_reflection_align_corners) { OpTester test("GridSample", 20); std::string mode = "linear"; std::string padding_mode = "reflection"; int64_t align_corners = 1; std::initializer_list X_shape{2, 2, 3, 3, 2}; - std::initializer_list X_data{-0.441629f, 0.199148f, 1.214051f, -0.000869f, 0.863692f, -0.067719f, -0.621662f, 0.235179f, 0.691041f, 0.176564f, 0.036477f, -0.085879f, 0.785440f, -1.837889f, -0.300151f, -1.710413f, 0.484432f, 2.160478f, -0.049246f, 0.372475f, -1.060470f, -1.000841f, -0.473439f, 0.963055f, 0.174518f, 0.932434f, 0.039338f, -0.343549f, -1.446623f, -0.673622f, 0.520395f, -0.279228f, -0.367065f, -0.871085f, 0.649273f, -0.835047f, 1.063542f, -1.829784f, 1.476173f, -1.048210f, -1.127299f, 1.204756f, -0.998390f, -1.014054f, -1.032717f, 0.977184f, 0.959897f, -0.749289f, 0.784492f, 1.343993f, 1.291144f, 0.099496f, 2.086763f, 0.529948f, -2.296640f, 0.570701f, 0.491216f, -0.003836f, -0.591929f, -0.076994f, 1.239698f, -0.888840f, 0.623497f, 0.769879f, 2.240972f, -2.081689f, 0.798466f, 1.207944f, -0.486804f, -0.488222f, -0.746382f, -0.220282f}; + std::initializer_list X_data{TypeParam(-0.441629f), TypeParam(0.199148f), TypeParam(1.214051f), TypeParam(-0.000869f), TypeParam(0.863692f), TypeParam(-0.067719f), TypeParam(-0.621662f), TypeParam(0.235179f), TypeParam(0.691041f), TypeParam(0.176564f), TypeParam(0.036477f), TypeParam(-0.085879f), TypeParam(0.785440f), TypeParam(-1.837889f), TypeParam(-0.300151f), TypeParam(-1.710413f), TypeParam(0.484432f), TypeParam(2.160478f), TypeParam(-0.049246f), TypeParam(0.372475f), TypeParam(-1.060470f), TypeParam(-1.000841f), TypeParam(-0.473439f), TypeParam(0.963055f), TypeParam(0.174518f), TypeParam(0.932434f), TypeParam(0.039338f), TypeParam(-0.343549f), TypeParam(-1.446623f), TypeParam(-0.673622f), TypeParam(0.520395f), TypeParam(-0.279228f), TypeParam(-0.367065f), TypeParam(-0.871085f), TypeParam(0.649273f), TypeParam(-0.835047f), TypeParam(1.063542f), TypeParam(-1.829784f), TypeParam(1.476173f), TypeParam(-1.048210f), TypeParam(-1.127299f), TypeParam(1.204756f), TypeParam(-0.998390f), TypeParam(-1.014054f), TypeParam(-1.032717f), TypeParam(0.977184f), TypeParam(0.959897f), TypeParam(-0.749289f), TypeParam(0.784492f), TypeParam(1.343993f), TypeParam(1.291144f), TypeParam(0.099496f), TypeParam(2.086763f), TypeParam(0.529948f), TypeParam(-2.296640f), TypeParam(0.570701f), TypeParam(0.491216f), TypeParam(-0.003836f), TypeParam(-0.591929f), TypeParam(-0.076994f), TypeParam(1.239698f), TypeParam(-0.888840f), TypeParam(0.623497f), TypeParam(0.769879f), TypeParam(2.240972f), TypeParam(-2.081689f), TypeParam(0.798466f), TypeParam(1.207944f), TypeParam(-0.486804f), TypeParam(-0.488222f), TypeParam(-0.746382f), TypeParam(-0.220282f)}; std::initializer_list Grid_shape{2, 3, 3, 2, 3}; - std::initializer_list Grid_data{-0.169044f, 0.178997f, 1.112567f, -0.825642f, -0.359793f, 0.170758f, -0.081412f, 0.319486f, 0.630993f, -0.493702f, 0.093438f, 1.085657f, -0.679024f, -0.813753f, -0.920282f, 0.717311f, -1.100678f, -0.583561f, 0.810473f, -0.719377f, 0.975857f, -0.560957f, 0.189840f, 0.157082f, -0.029434f, 0.747413f, 1.019186f, -0.749235f, 0.673000f, 0.320624f, -0.022362f, -0.839050f, 0.355966f, 0.871005f, -1.030007f, -1.108265f, -1.179701f, 0.277273f, -0.344802f, -0.372753f, 1.117390f, -0.306079f, -0.762057f, 0.107942f, -0.658634f, -0.351593f, 0.633875f, 0.276953f, -0.823465f, 1.142446f, 0.811875f, -0.818022f, 0.522699f, 0.493103f, -0.861061f, -0.843352f, -0.993629f, 0.534540f, 0.209070f, 0.507143f, -0.527071f, 0.902309f, 0.153227f, -0.957513f, -0.302041f, 0.612404f, 0.263859f, -0.183579f, -0.838388f, -0.746482f, 1.035039f, -0.687403f, 0.850371f, -0.401659f, 0.011995f, -1.168548f, -0.390077f, 1.011575f, -1.077360f, 0.603794f, -1.009901f, 0.175023f, -1.087964f, -0.949961f, -0.968757f, -0.416100f, 0.163389f, -0.879807f, 0.304124f, 0.722748f, 0.978239f, 1.062535f, 0.790067f, -0.353356f, -0.110591f, 1.061730f, 0.596951f, -0.318231f, 0.905999f, -1.048710f, 1.027042f, 0.671407f, -0.880154f, -0.978736f, 0.938431f, 1.183815f, 0.104716f, -0.468883f}; + std::initializer_list Grid_data{TypeParam(-0.169044f), TypeParam(0.178997f), TypeParam(1.112567f), TypeParam(-0.825642f), TypeParam(-0.359793f), TypeParam(0.170758f), TypeParam(-0.081412f), TypeParam(0.319486f), TypeParam(0.630993f), TypeParam(-0.493702f), TypeParam(0.093438f), TypeParam(1.085657f), TypeParam(-0.679024f), TypeParam(-0.813753f), TypeParam(-0.920282f), TypeParam(0.717311f), TypeParam(-1.100678f), TypeParam(-0.583561f), TypeParam(0.810473f), TypeParam(-0.719377f), TypeParam(0.975857f), TypeParam(-0.560957f), TypeParam(0.189840f), TypeParam(0.157082f), TypeParam(-0.029434f), TypeParam(0.747413f), TypeParam(1.019186f), TypeParam(-0.749235f), TypeParam(0.673000f), TypeParam(0.320624f), TypeParam(-0.022362f), TypeParam(-0.839050f), TypeParam(0.355966f), TypeParam(0.871005f), TypeParam(-1.030007f), TypeParam(-1.108265f), TypeParam(-1.179701f), TypeParam(0.277273f), TypeParam(-0.344802f), TypeParam(-0.372753f), TypeParam(1.117390f), TypeParam(-0.306079f), TypeParam(-0.762057f), TypeParam(0.107942f), TypeParam(-0.658634f), TypeParam(-0.351593f), TypeParam(0.633875f), TypeParam(0.276953f), TypeParam(-0.823465f), TypeParam(1.142446f), TypeParam(0.811875f), TypeParam(-0.818022f), TypeParam(0.522699f), TypeParam(0.493103f), TypeParam(-0.861061f), TypeParam(-0.843352f), TypeParam(-0.993629f), TypeParam(0.534540f), TypeParam(0.209070f), TypeParam(0.507143f), TypeParam(-0.527071f), TypeParam(0.902309f), TypeParam(0.153227f), TypeParam(-0.957513f), TypeParam(-0.302041f), TypeParam(0.612404f), TypeParam(0.263859f), TypeParam(-0.183579f), TypeParam(-0.838388f), TypeParam(-0.746482f), TypeParam(1.035039f), TypeParam(-0.687403f), TypeParam(0.850371f), TypeParam(-0.401659f), TypeParam(0.011995f), TypeParam(-1.168548f), TypeParam(-0.390077f), TypeParam(1.011575f), TypeParam(-1.077360f), TypeParam(0.603794f), TypeParam(-1.009901f), TypeParam(0.175023f), TypeParam(-1.087964f), TypeParam(-0.949961f), TypeParam(-0.968757f), TypeParam(-0.416100f), TypeParam(0.163389f), TypeParam(-0.879807f), TypeParam(0.304124f), TypeParam(0.722748f), TypeParam(0.978239f), TypeParam(1.062535f), TypeParam(0.790067f), TypeParam(-0.353356f), TypeParam(-0.110591f), TypeParam(1.061730f), TypeParam(0.596951f), TypeParam(-0.318231f), TypeParam(0.905999f), TypeParam(-1.048710f), TypeParam(1.027042f), TypeParam(0.671407f), TypeParam(-0.880154f), TypeParam(-0.978736f), TypeParam(0.938431f), TypeParam(1.183815f), TypeParam(0.104716f), TypeParam(-0.468883f)}; std::initializer_list Y_shape{2, 2, 3, 3, 2}; - std::initializer_list Y_data{-0.414201f, 0.167816f, -0.042305f, -0.423495f, -0.101419f, 0.120192f, -1.543294f, 0.344146f, 0.709278f, 0.248721f, -0.269138f, 0.158159f, 0.659876f, 0.226329f, 0.874509f, 0.240959f, 0.412611f, 0.225904f, -0.448580f, 0.057703f, -0.426538f, -0.401142f, -0.147435f, 0.401852f, -0.355426f, -0.286018f, -0.219687f, -0.564205f, 0.282723f, 0.363522f, -0.543706f, -0.787722f, -0.692217f, -0.594894f, 0.091005f, -0.328214f, 0.919003f, 0.408116f, 0.631220f, 0.303619f, -0.197801f, -0.308153f, 0.094457f, 1.027881f, -0.077622f, -0.597219f, -0.661449f, 0.947805f, 0.279352f, 0.828246f, 0.571205f, 1.646163f, 0.714257f, 0.049881f, -1.680014f, -0.056047f, 0.892393f, 0.250564f, 0.138843f, 0.178706f, 0.161286f, 0.036891f, -0.141908f, -0.510903f, 0.733949f, -0.112944f, -0.581858f, -0.269439f, 0.056781f, 0.200325f, 0.814038f, 0.277386f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-0.414201f), TypeParam(0.167816f), TypeParam(-0.042305f), TypeParam(-0.423495f), TypeParam(-0.101419f), TypeParam(0.120192f), TypeParam(-1.543294f), TypeParam(0.344146f), TypeParam(0.709278f), TypeParam(0.248721f), TypeParam(-0.269138f), TypeParam(0.158159f), TypeParam(0.659876f), TypeParam(0.226329f), TypeParam(0.874509f), TypeParam(0.240959f), TypeParam(0.412611f), TypeParam(0.225904f), TypeParam(-0.448580f), TypeParam(0.057703f), TypeParam(-0.426538f), TypeParam(-0.401142f), TypeParam(-0.147435f), TypeParam(0.401852f), TypeParam(-0.355426f), TypeParam(-0.286018f), TypeParam(-0.219687f), TypeParam(-0.564205f), TypeParam(0.282723f), TypeParam(0.363522f), TypeParam(-0.543706f), TypeParam(-0.787722f), TypeParam(-0.692217f), TypeParam(-0.594894f), TypeParam(0.091005f), TypeParam(-0.328214f), TypeParam(0.919003f), TypeParam(0.408116f), TypeParam(0.631220f), TypeParam(0.303619f), TypeParam(-0.197801f), TypeParam(-0.308153f), TypeParam(0.094457f), TypeParam(1.027881f), TypeParam(-0.077622f), TypeParam(-0.597219f), TypeParam(-0.661449f), TypeParam(0.947805f), TypeParam(0.279352f), TypeParam(0.828246f), TypeParam(0.571205f), TypeParam(1.646163f), TypeParam(0.714257f), TypeParam(0.049881f), TypeParam(-1.680014f), TypeParam(-0.056047f), TypeParam(0.892393f), TypeParam(0.250564f), TypeParam(0.138843f), TypeParam(0.178706f), TypeParam(0.161286f), TypeParam(0.036891f), TypeParam(-0.141908f), TypeParam(-0.510903f), TypeParam(0.733949f), TypeParam(-0.112944f), TypeParam(-0.581858f), TypeParam(-0.269439f), TypeParam(0.056781f), TypeParam(0.200325f), TypeParam(0.814038f), TypeParam(0.277386f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_reflection_no_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_reflection_no_align_corners) { OpTester test("GridSample", 20); std::string mode = "linear"; std::string padding_mode = "reflection"; int64_t align_corners = 0; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{-0.173652f, -1.513725f, -0.704586f, -1.952375f, -0.699404f, -0.806298f, 1.640852f, -0.138969f, -0.695411f, -1.352111f, 0.568797f, -0.564294f, -0.056468f, 0.641604f, -0.438370f, 0.450167f, -1.091401f, 1.669729f, -0.908544f, 0.244467f, 0.172109f, 1.156741f, -0.617128f, 1.155460f}; + std::initializer_list X_data{TypeParam(-0.173652f), TypeParam(-1.513725f), TypeParam(-0.704586f), TypeParam(-1.952375f), TypeParam(-0.699404f), TypeParam(-0.806298f), TypeParam(1.640852f), TypeParam(-0.138969f), TypeParam(-0.695411f), TypeParam(-1.352111f), TypeParam(0.568797f), TypeParam(-0.564294f), TypeParam(-0.056468f), TypeParam(0.641604f), TypeParam(-0.438370f), TypeParam(0.450167f), TypeParam(-1.091401f), TypeParam(1.669729f), TypeParam(-0.908544f), TypeParam(0.244467f), TypeParam(0.172109f), TypeParam(1.156741f), TypeParam(-0.617128f), TypeParam(1.155460f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{0.252250f, -0.151452f, 0.824706f, -0.588292f, -0.591147f, -0.155082f, -0.732938f, 0.457493f, -0.439559f, 0.492330f, 0.696447f, 0.700722f, -0.220298f, 0.654884f, -0.635434f, -1.195619f, -0.114204f, -0.870080f, -0.929674f, 0.305035f, 1.025429f, -0.472240f, -0.067881f, -0.869393f}; + std::initializer_list Grid_data{TypeParam(0.252250f), TypeParam(-0.151452f), TypeParam(0.824706f), TypeParam(-0.588292f), TypeParam(-0.591147f), TypeParam(-0.155082f), TypeParam(-0.732938f), TypeParam(0.457493f), TypeParam(-0.439559f), TypeParam(0.492330f), TypeParam(0.696447f), TypeParam(0.700722f), TypeParam(-0.220298f), TypeParam(0.654884f), TypeParam(-0.635434f), TypeParam(-1.195619f), TypeParam(-0.114204f), TypeParam(-0.870080f), TypeParam(-0.929674f), TypeParam(0.305035f), TypeParam(1.025429f), TypeParam(-0.472240f), TypeParam(-0.067881f), TypeParam(-0.869393f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{-1.538390f, -1.565293f, -0.581079f, -0.701030f, -0.725252f, -0.806298f, -0.850602f, -0.281588f, -0.151944f, 0.172138f, 0.177246f, -0.564294f, -0.316822f, -0.056468f, 0.212846f, -0.737167f, 0.585773f, 0.245182f, -0.111277f, -0.908544f, -0.463717f, -0.189009f, 0.510522f, -0.410307f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-1.538390f), TypeParam(-1.565293f), TypeParam(-0.581079f), TypeParam(-0.701030f), TypeParam(-0.725252f), TypeParam(-0.806298f), TypeParam(-0.850602f), TypeParam(-0.281588f), TypeParam(-0.151944f), TypeParam(0.172138f), TypeParam(0.177246f), TypeParam(-0.564294f), TypeParam(-0.316822f), TypeParam(-0.056468f), TypeParam(0.212846f), TypeParam(-0.737167f), TypeParam(0.585773f), TypeParam(0.245182f), TypeParam(-0.111277f), TypeParam(-0.908544f), TypeParam(-0.463717f), TypeParam(-0.189009f), TypeParam(0.510522f), TypeParam(-0.410307f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_5D_bilinear_reflection_no_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_bilinear_reflection_no_align_corners) { OpTester test("GridSample", 20); std::string mode = "linear"; std::string padding_mode = "reflection"; int64_t align_corners = 0; std::initializer_list X_shape{2, 2, 3, 3, 2}; - std::initializer_list X_data{1.179856f, 1.432512f, 1.016210f, -0.661096f, 0.335863f, 0.565957f, -0.517555f, 2.232456f, -0.615173f, -0.073628f, -0.260768f, -1.952025f, 0.304237f, 0.902323f, -0.485170f, 0.781595f, -1.777093f, -0.274107f, -1.030698f, 0.181435f, 1.947646f, 1.007702f, -0.100718f, 0.154090f, -0.483193f, 1.565921f, -0.932274f, 0.313820f, -0.439116f, -0.411861f, -0.821795f, -1.685022f, -0.013518f, 0.519914f, -0.175407f, -0.507962f, 0.050913f, 0.981904f, 1.087165f, 1.758657f, 0.075954f, -0.481552f, 0.085590f, 0.537831f, -0.419622f, -1.756791f, 1.324879f, -0.267061f, -0.683518f, 0.605393f, 0.041004f, -0.756742f, 0.744950f, -0.508619f, -0.594679f, -1.165646f, -0.699604f, -0.271502f, 0.437731f, -2.206233f, 1.088781f, -0.629873f, -0.904741f, -1.233533f, 2.466710f, -0.117309f, -0.684130f, 0.598811f, 0.288846f, -1.195569f, 0.935300f, 0.962852f}; + std::initializer_list X_data{TypeParam(1.179856f), TypeParam(1.432512f), TypeParam(1.016210f), TypeParam(-0.661096f), TypeParam(0.335863f), TypeParam(0.565957f), TypeParam(-0.517555f), TypeParam(2.232456f), TypeParam(-0.615173f), TypeParam(-0.073628f), TypeParam(-0.260768f), TypeParam(-1.952025f), TypeParam(0.304237f), TypeParam(0.902323f), TypeParam(-0.485170f), TypeParam(0.781595f), TypeParam(-1.777093f), TypeParam(-0.274107f), TypeParam(-1.030698f), TypeParam(0.181435f), TypeParam(1.947646f), TypeParam(1.007702f), TypeParam(-0.100718f), TypeParam(0.154090f), TypeParam(-0.483193f), TypeParam(1.565921f), TypeParam(-0.932274f), TypeParam(0.313820f), TypeParam(-0.439116f), TypeParam(-0.411861f), TypeParam(-0.821795f), TypeParam(-1.685022f), TypeParam(-0.013518f), TypeParam(0.519914f), TypeParam(-0.175407f), TypeParam(-0.507962f), TypeParam(0.050913f), TypeParam(0.981904f), TypeParam(1.087165f), TypeParam(1.758657f), TypeParam(0.075954f), TypeParam(-0.481552f), TypeParam(0.085590f), TypeParam(0.537831f), TypeParam(-0.419622f), TypeParam(-1.756791f), TypeParam(1.324879f), TypeParam(-0.267061f), TypeParam(-0.683518f), TypeParam(0.605393f), TypeParam(0.041004f), TypeParam(-0.756742f), TypeParam(0.744950f), TypeParam(-0.508619f), TypeParam(-0.594679f), TypeParam(-1.165646f), TypeParam(-0.699604f), TypeParam(-0.271502f), TypeParam(0.437731f), TypeParam(-2.206233f), TypeParam(1.088781f), TypeParam(-0.629873f), TypeParam(-0.904741f), TypeParam(-1.233533f), TypeParam(2.466710f), TypeParam(-0.117309f), TypeParam(-0.684130f), TypeParam(0.598811f), TypeParam(0.288846f), TypeParam(-1.195569f), TypeParam(0.935300f), TypeParam(0.962852f)}; std::initializer_list Grid_shape{2, 3, 3, 2, 3}; - std::initializer_list Grid_data{0.625842f, 0.210304f, -0.725943f, -0.553764f, -0.182412f, -0.296478f, -0.254040f, -0.820211f, 0.869312f, 0.622346f, 0.236815f, 0.271706f, 0.140482f, 0.897281f, 0.271537f, 0.182799f, -0.659653f, 0.400310f, -1.122656f, 0.378466f, -1.040147f, -0.496646f, 0.633526f, -0.714734f, 0.955528f, -0.663024f, 1.136629f, 0.369854f, -0.520025f, 0.731855f, -1.062711f, -0.760189f, -0.751812f, 0.157968f, 0.117892f, -1.032129f, 1.157953f, -0.001147f, -0.640796f, 0.028663f, -0.515104f, 0.331070f, 0.434411f, -0.340393f, 0.069958f, 0.714010f, -0.780518f, -0.267586f, -0.177029f, -0.793935f, 0.097737f, 0.044103f, -0.969274f, 0.246164f, 1.145360f, 0.638273f, -0.650926f, 1.098440f, -0.824873f, -0.610135f, 0.529312f, 0.954650f, 1.145143f, 1.033109f, -0.660775f, 0.274592f, -0.753497f, 0.026500f, 0.994206f, 0.590870f, -1.108049f, -0.516447f, -1.012489f, 0.565286f, -0.152334f, -0.877228f, -0.383453f, 0.393797f, 0.111096f, 1.125969f, -0.015932f, 0.377468f, -0.363512f, 0.143194f, 0.042988f, 1.030777f, 0.502813f, -0.683870f, -1.066269f, -1.141727f, -0.435790f, 0.155118f, 1.128919f, -0.117905f, 0.469189f, 0.609870f, -0.919201f, -0.992659f, 0.454699f, 0.559331f, -0.558762f, 0.188050f, -1.174933f, 0.015126f, 0.294147f, 0.011359f, -0.190476f, 0.499476f}; + std::initializer_list Grid_data{TypeParam(0.625842f), TypeParam(0.210304f), TypeParam(-0.725943f), TypeParam(-0.553764f), TypeParam(-0.182412f), TypeParam(-0.296478f), TypeParam(-0.254040f), TypeParam(-0.820211f), TypeParam(0.869312f), TypeParam(0.622346f), TypeParam(0.236815f), TypeParam(0.271706f), TypeParam(0.140482f), TypeParam(0.897281f), TypeParam(0.271537f), TypeParam(0.182799f), TypeParam(-0.659653f), TypeParam(0.400310f), TypeParam(-1.122656f), TypeParam(0.378466f), TypeParam(-1.040147f), TypeParam(-0.496646f), TypeParam(0.633526f), TypeParam(-0.714734f), TypeParam(0.955528f), TypeParam(-0.663024f), TypeParam(1.136629f), TypeParam(0.369854f), TypeParam(-0.520025f), TypeParam(0.731855f), TypeParam(-1.062711f), TypeParam(-0.760189f), TypeParam(-0.751812f), TypeParam(0.157968f), TypeParam(0.117892f), TypeParam(-1.032129f), TypeParam(1.157953f), TypeParam(-0.001147f), TypeParam(-0.640796f), TypeParam(0.028663f), TypeParam(-0.515104f), TypeParam(0.331070f), TypeParam(0.434411f), TypeParam(-0.340393f), TypeParam(0.069958f), TypeParam(0.714010f), TypeParam(-0.780518f), TypeParam(-0.267586f), TypeParam(-0.177029f), TypeParam(-0.793935f), TypeParam(0.097737f), TypeParam(0.044103f), TypeParam(-0.969274f), TypeParam(0.246164f), TypeParam(1.145360f), TypeParam(0.638273f), TypeParam(-0.650926f), TypeParam(1.098440f), TypeParam(-0.824873f), TypeParam(-0.610135f), TypeParam(0.529312f), TypeParam(0.954650f), TypeParam(1.145143f), TypeParam(1.033109f), TypeParam(-0.660775f), TypeParam(0.274592f), TypeParam(-0.753497f), TypeParam(0.026500f), TypeParam(0.994206f), TypeParam(0.590870f), TypeParam(-1.108049f), TypeParam(-0.516447f), TypeParam(-1.012489f), TypeParam(0.565286f), TypeParam(-0.152334f), TypeParam(-0.877228f), TypeParam(-0.383453f), TypeParam(0.393797f), TypeParam(0.111096f), TypeParam(1.125969f), TypeParam(-0.015932f), TypeParam(0.377468f), TypeParam(-0.363512f), TypeParam(0.143194f), TypeParam(0.042988f), TypeParam(1.030777f), TypeParam(0.502813f), TypeParam(-0.683870f), TypeParam(-1.066269f), TypeParam(-1.141727f), TypeParam(-0.435790f), TypeParam(0.155118f), TypeParam(1.128919f), TypeParam(-0.117905f), TypeParam(0.469189f), TypeParam(0.609870f), TypeParam(-0.919201f), TypeParam(-0.992659f), TypeParam(0.454699f), TypeParam(0.559331f), TypeParam(-0.558762f), TypeParam(0.188050f), TypeParam(-1.174933f), TypeParam(0.015126f), TypeParam(0.294147f), TypeParam(0.011359f), TypeParam(-0.190476f), TypeParam(0.499476f)}; std::initializer_list Y_shape{2, 2, 3, 3, 2}; - std::initializer_list Y_data{-0.274014f, 0.145076f, 0.451342f, -0.273219f, -1.128307f, 0.962473f, 0.629978f, 0.370138f, 0.901663f, 0.778787f, 1.179856f, 0.014218f, -0.634683f, 0.585419f, 0.972130f, 1.911376f, 0.389205f, 0.849839f, 0.738424f, 0.054296f, -1.034114f, 0.096287f, -0.408114f, -0.474491f, 0.784791f, 0.001762f, -1.672976f, -1.127656f, -1.030698f, 1.105979f, 0.979492f, -0.258014f, 0.693543f, 1.010218f, -0.008927f, -0.078404f, -0.384825f, 0.944247f, -0.508619f, 0.548774f, 0.068986f, 0.881841f, 0.869967f, -0.274754f, 0.337312f, -0.374188f, 0.161655f, 0.050913f, 0.146763f, 0.119233f, -0.438980f, 0.228062f, -0.187221f, -0.376543f, -2.077576f, -1.120214f, 0.962852f, -0.133462f, 0.314542f, -1.044921f, 1.568017f, -0.060947f, 0.838264f, -0.652863f, 0.978122f, -0.594679f, 0.366536f, 0.596221f, -0.120431f, -0.435362f, -0.328892f, -0.434798f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-0.274014f), TypeParam(0.145076f), TypeParam(0.451342f), TypeParam(-0.273219f), TypeParam(-1.128307f), TypeParam(0.962473f), TypeParam(0.629978f), TypeParam(0.370138f), TypeParam(0.901663f), TypeParam(0.778787f), TypeParam(1.179856f), TypeParam(0.014218f), TypeParam(-0.634683f), TypeParam(0.585419f), TypeParam(0.972130f), TypeParam(1.911376f), TypeParam(0.389205f), TypeParam(0.849839f), TypeParam(0.738424f), TypeParam(0.054296f), TypeParam(-1.034114f), TypeParam(0.096287f), TypeParam(-0.408114f), TypeParam(-0.474491f), TypeParam(0.784791f), TypeParam(0.001762f), TypeParam(-1.672976f), TypeParam(-1.127656f), TypeParam(-1.030698f), TypeParam(1.105979f), TypeParam(0.979492f), TypeParam(-0.258014f), TypeParam(0.693543f), TypeParam(1.010218f), TypeParam(-0.008927f), TypeParam(-0.078404f), TypeParam(-0.384825f), TypeParam(0.944247f), TypeParam(-0.508619f), TypeParam(0.548774f), TypeParam(0.068986f), TypeParam(0.881841f), TypeParam(0.869967f), TypeParam(-0.274754f), TypeParam(0.337312f), TypeParam(-0.374188f), TypeParam(0.161655f), TypeParam(0.050913f), TypeParam(0.146763f), TypeParam(0.119233f), TypeParam(-0.438980f), TypeParam(0.228062f), TypeParam(-0.187221f), TypeParam(-0.376543f), TypeParam(-2.077576f), TypeParam(-1.120214f), TypeParam(0.962852f), TypeParam(-0.133462f), TypeParam(0.314542f), TypeParam(-1.044921f), TypeParam(1.568017f), TypeParam(-0.060947f), TypeParam(0.838264f), TypeParam(-0.652863f), TypeParam(0.978122f), TypeParam(-0.594679f), TypeParam(0.366536f), TypeParam(0.596221f), TypeParam(-0.120431f), TypeParam(-0.435362f), TypeParam(-0.328892f), TypeParam(-0.434798f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_4D_bicubic_zeros_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bicubic_zeros_align_corners) { OpTester test("GridSample", 20); std::string mode = "cubic"; std::string padding_mode = "zeros"; int64_t align_corners = 1; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{0.741614f, -1.612838f, 0.274100f, -0.685296f, -0.032079f, -0.246424f, 0.089412f, -0.776545f, -0.152179f, 0.312533f, -1.503701f, -0.720829f, 0.877575f, 0.407229f, -0.889951f, 0.603605f, -0.140859f, 2.032775f, -0.520668f, 1.063163f, -1.008883f, 0.194195f, -0.303240f, -0.967884f}; + std::initializer_list X_data{TypeParam(0.741614f), TypeParam(-1.612838f), TypeParam(0.274100f), TypeParam(-0.685296f), TypeParam(-0.032079f), TypeParam(-0.246424f), TypeParam(0.089412f), TypeParam(-0.776545f), TypeParam(-0.152179f), TypeParam(0.312533f), TypeParam(-1.503701f), TypeParam(-0.720829f), TypeParam(0.877575f), TypeParam(0.407229f), TypeParam(-0.889951f), TypeParam(0.603605f), TypeParam(-0.140859f), TypeParam(2.032775f), TypeParam(-0.520668f), TypeParam(1.063163f), TypeParam(-1.008883f), TypeParam(0.194195f), TypeParam(-0.303240f), TypeParam(-0.967884f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{-0.932019f, -0.034394f, 0.554511f, 0.484230f, 0.141120f, 0.485083f, -0.836516f, 0.999462f, 0.026764f, 0.775689f, 0.265464f, -0.133497f, 0.514005f, 1.139161f, 1.183700f, -1.010095f, 0.072779f, -0.862052f, 0.699178f, 0.861473f, -0.842637f, -0.069355f, 0.830374f, 0.793568f}; + std::initializer_list Grid_data{TypeParam(-0.932019f), TypeParam(-0.034394f), TypeParam(0.554511f), TypeParam(0.484230f), TypeParam(0.141120f), TypeParam(0.485083f), TypeParam(-0.836516f), TypeParam(0.999462f), TypeParam(0.026764f), TypeParam(0.775689f), TypeParam(0.265464f), TypeParam(-0.133497f), TypeParam(0.514005f), TypeParam(1.139161f), TypeParam(1.183700f), TypeParam(-1.010095f), TypeParam(0.072779f), TypeParam(-0.862052f), TypeParam(0.699178f), TypeParam(0.861473f), TypeParam(-0.842637f), TypeParam(-0.069355f), TypeParam(0.830374f), TypeParam(0.793568f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{0.274192f, -0.348792f, -0.238780f, -0.048938f, -0.195915f, -0.488976f, -0.104505f, -0.351103f, -0.583059f, -1.533095f, -1.141282f, 0.187052f, 1.668728f, 0.345182f, 0.682750f, 1.893112f, -0.775917f, 1.920082f, -0.889375f, 1.071508f, 0.336517f, -0.933740f, -0.981629f, -0.893789f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(0.274192f), TypeParam(-0.348792f), TypeParam(-0.238780f), TypeParam(-0.048938f), TypeParam(-0.195915f), TypeParam(-0.488976f), TypeParam(-0.104505f), TypeParam(-0.351103f), TypeParam(-0.583059f), TypeParam(-1.533095f), TypeParam(-1.141282f), TypeParam(0.187052f), TypeParam(1.668728f), TypeParam(0.345182f), TypeParam(0.682750f), TypeParam(1.893112f), TypeParam(-0.775917f), TypeParam(1.920082f), TypeParam(-0.889375f), TypeParam(1.071508f), TypeParam(0.336517f), TypeParam(-0.933740f), TypeParam(-0.981629f), TypeParam(-0.893789f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_4D_bicubic_zeros_no_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bicubic_zeros_no_align_corners) { OpTester test("GridSample", 20); std::string mode = "cubic"; std::string padding_mode = "zeros"; int64_t align_corners = 0; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{0.333395f, 0.977190f, 0.214232f, 0.363731f, -1.352515f, -0.980304f, -0.354887f, -0.481711f, -0.607915f, -0.309748f, 2.262781f, 0.963363f, 1.997079f, 0.987449f, -0.537662f, 1.011585f, 0.822184f, 0.567108f, 0.135401f, -0.943315f, -0.614181f, 0.030652f, 0.914757f, 0.971777f}; + std::initializer_list X_data{TypeParam(0.333395f), TypeParam(0.977190f), TypeParam(0.214232f), TypeParam(0.363731f), TypeParam(-1.352515f), TypeParam(-0.980304f), TypeParam(-0.354887f), TypeParam(-0.481711f), TypeParam(-0.607915f), TypeParam(-0.309748f), TypeParam(2.262781f), TypeParam(0.963363f), TypeParam(1.997079f), TypeParam(0.987449f), TypeParam(-0.537662f), TypeParam(1.011585f), TypeParam(0.822184f), TypeParam(0.567108f), TypeParam(0.135401f), TypeParam(-0.943315f), TypeParam(-0.614181f), TypeParam(0.030652f), TypeParam(0.914757f), TypeParam(0.971777f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{-0.487111f, 0.913573f, 0.641905f, -0.093110f, 0.512522f, 0.358369f, 0.655341f, -0.964320f, 0.370929f, -1.136512f, -0.789199f, -0.447185f, -0.116915f, -1.132446f, 0.029865f, 0.191588f, -0.476239f, 0.389224f, 1.048588f, -0.204978f, -0.639094f, -1.062994f, -0.876243f, -0.663705f}; + std::initializer_list Grid_data{TypeParam(-0.487111f), TypeParam(0.913573f), TypeParam(0.641905f), TypeParam(-0.093110f), TypeParam(0.512522f), TypeParam(0.358369f), TypeParam(0.655341f), TypeParam(-0.964320f), TypeParam(0.370929f), TypeParam(-1.136512f), TypeParam(-0.789199f), TypeParam(-0.447185f), TypeParam(-0.116915f), TypeParam(-1.132446f), TypeParam(0.029865f), TypeParam(0.191588f), TypeParam(-0.476239f), TypeParam(0.389224f), TypeParam(1.048588f), TypeParam(-0.204978f), TypeParam(-0.639094f), TypeParam(-1.062994f), TypeParam(-0.876243f), TypeParam(-0.663705f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{-1.051920f, 0.501832f, -0.508839f, 0.563480f, 0.297178f, 0.246571f, 1.781955f, -0.353574f, 0.481200f, -0.258839f, -0.145200f, -0.469558f, 0.624262f, 0.351267f, 0.180256f, 0.571859f, 0.903895f, 1.383745f, -0.081406f, 0.133665f, 0.348401f, -0.164219f, 0.138237f, 0.203282f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-1.051920f), TypeParam(0.501832f), TypeParam(-0.508839f), TypeParam(0.563480f), TypeParam(0.297178f), TypeParam(0.246571f), TypeParam(1.781955f), TypeParam(-0.353574f), TypeParam(0.481200f), TypeParam(-0.258839f), TypeParam(-0.145200f), TypeParam(-0.469558f), TypeParam(0.624262f), TypeParam(0.351267f), TypeParam(0.180256f), TypeParam(0.571859f), TypeParam(0.903895f), TypeParam(1.383745f), TypeParam(-0.081406f), TypeParam(0.133665f), TypeParam(0.348401f), TypeParam(-0.164219f), TypeParam(0.138237f), TypeParam(0.203282f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_4D_bicubic_border_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bicubic_border_align_corners) { OpTester test("GridSample", 20); std::string mode = "cubic"; std::string padding_mode = "border"; int64_t align_corners = 1; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{-0.480448f, 0.682093f, 0.237716f, -1.234307f, 2.139750f, 2.410321f, 0.491472f, -0.553422f, 0.032129f, -0.162503f, 0.144036f, -1.889875f, -0.293944f, -1.390146f, -1.552136f, 1.604720f, -1.707202f, 0.182427f, -0.631000f, 0.196649f, 0.427711f, -0.014224f, -1.319834f, -2.703346f}; + std::initializer_list X_data{TypeParam(-0.480448f), TypeParam(0.682093f), TypeParam(0.237716f), TypeParam(-1.234307f), TypeParam(2.139750f), TypeParam(2.410321f), TypeParam(0.491472f), TypeParam(-0.553422f), TypeParam(0.032129f), TypeParam(-0.162503f), TypeParam(0.144036f), TypeParam(-1.889875f), TypeParam(-0.293944f), TypeParam(-1.390146f), TypeParam(-1.552136f), TypeParam(1.604720f), TypeParam(-1.707202f), TypeParam(0.182427f), TypeParam(-0.631000f), TypeParam(0.196649f), TypeParam(0.427711f), TypeParam(-0.014224f), TypeParam(-1.319834f), TypeParam(-2.703346f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{0.503717f, 0.572989f, 0.179517f, -0.060398f, 0.503876f, 0.288627f, -1.148268f, 0.194010f, -0.532910f, -0.636357f, 0.464076f, 0.245386f, 0.203212f, -0.569260f, 0.554489f, 1.126118f, 0.146805f, 0.493232f, -1.052794f, 0.713394f, 0.416866f, 0.540634f, 0.500415f, -0.315629f}; + std::initializer_list Grid_data{TypeParam(0.503717f), TypeParam(0.572989f), TypeParam(0.179517f), TypeParam(-0.060398f), TypeParam(0.503876f), TypeParam(0.288627f), TypeParam(-1.148268f), TypeParam(0.194010f), TypeParam(-0.532910f), TypeParam(-0.636357f), TypeParam(0.464076f), TypeParam(0.245386f), TypeParam(0.203212f), TypeParam(-0.569260f), TypeParam(0.554489f), TypeParam(1.126118f), TypeParam(0.146805f), TypeParam(0.493232f), TypeParam(-1.052794f), TypeParam(0.713394f), TypeParam(0.416866f), TypeParam(0.540634f), TypeParam(0.500415f), TypeParam(-0.315629f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{0.885659f, -0.722912f, -0.180469f, 0.697015f, -0.322127f, -0.292851f, -0.867861f, -0.047527f, -0.447720f, 0.028100f, 0.191874f, -0.378776f, -0.321888f, -0.277691f, -0.037604f, -1.766707f, 0.320836f, 0.415106f, 0.179209f, -2.609096f, -0.929794f, -0.788240f, -1.212243f, 0.337704f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(0.885659f), TypeParam(-0.722912f), TypeParam(-0.180469f), TypeParam(0.697015f), TypeParam(-0.322127f), TypeParam(-0.292851f), TypeParam(-0.867861f), TypeParam(-0.047527f), TypeParam(-0.447720f), TypeParam(0.028100f), TypeParam(0.191874f), TypeParam(-0.378776f), TypeParam(-0.321888f), TypeParam(-0.277691f), TypeParam(-0.037604f), TypeParam(-1.766707f), TypeParam(0.320836f), TypeParam(0.415106f), TypeParam(0.179209f), TypeParam(-2.609096f), TypeParam(-0.929794f), TypeParam(-0.788240f), TypeParam(-1.212243f), TypeParam(0.337704f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_4D_bicubic_border_no_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bicubic_border_no_align_corners) { OpTester test("GridSample", 20); std::string mode = "cubic"; std::string padding_mode = "border"; int64_t align_corners = 0; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{-0.924256f, -2.309784f, 1.272769f, 0.548427f, -1.478527f, -3.472946f, -1.252325f, 0.268589f, 0.326270f, 0.105016f, 0.515184f, -0.951158f, -0.658693f, -2.018776f, 0.981625f, -0.401504f, 1.560519f, -0.129836f, -1.876357f, 0.511516f, -1.825582f, 0.358958f, -0.805392f, -1.409127f}; + std::initializer_list X_data{TypeParam(-0.924256f), TypeParam(-2.309784f), TypeParam(1.272769f), TypeParam(0.548427f), TypeParam(-1.478527f), TypeParam(-3.472946f), TypeParam(-1.252325f), TypeParam(0.268589f), TypeParam(0.326270f), TypeParam(0.105016f), TypeParam(0.515184f), TypeParam(-0.951158f), TypeParam(-0.658693f), TypeParam(-2.018776f), TypeParam(0.981625f), TypeParam(-0.401504f), TypeParam(1.560519f), TypeParam(-0.129836f), TypeParam(-1.876357f), TypeParam(0.511516f), TypeParam(-1.825582f), TypeParam(0.358958f), TypeParam(-0.805392f), TypeParam(-1.409127f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{0.874856f, -1.090775f, 1.169192f, 0.447098f, 0.583418f, 0.267395f, 0.788144f, 1.129706f, -0.102229f, -0.984624f, 1.101916f, -0.253070f, -0.578731f, 0.738703f, 0.669694f, 0.160659f, -0.075327f, -0.229561f, 1.100291f, 0.731142f, 0.714643f, 0.765214f, -0.628031f, 0.250554f}; + std::initializer_list Grid_data{TypeParam(0.874856f), TypeParam(-1.090775f), TypeParam(1.169192f), TypeParam(0.447098f), TypeParam(0.583418f), TypeParam(0.267395f), TypeParam(0.788144f), TypeParam(1.129706f), TypeParam(-0.102229f), TypeParam(-0.984624f), TypeParam(1.101916f), TypeParam(-0.253070f), TypeParam(-0.578731f), TypeParam(0.738703f), TypeParam(0.669694f), TypeParam(0.160659f), TypeParam(-0.075327f), TypeParam(-0.229561f), TypeParam(1.100291f), TypeParam(0.731142f), TypeParam(0.714643f), TypeParam(0.765214f), TypeParam(-0.628031f), TypeParam(0.250554f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{-2.647128f, -2.154235f, -0.768645f, -3.893546f, -1.698376f, -0.114530f, 0.458115f, -0.696657f, -0.370692f, -1.169692f, -0.754730f, 0.320002f, 1.683550f, -0.301499f, -0.176003f, -0.236653f, -0.278257f, 1.480160f, -0.700350f, 0.095525f, -0.891605f, -1.569065f, -1.633715f, -1.535763f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-2.647128f), TypeParam(-2.154235f), TypeParam(-0.768645f), TypeParam(-3.893546f), TypeParam(-1.698376f), TypeParam(-0.114530f), TypeParam(0.458115f), TypeParam(-0.696657f), TypeParam(-0.370692f), TypeParam(-1.169692f), TypeParam(-0.754730f), TypeParam(0.320002f), TypeParam(1.683550f), TypeParam(-0.301499f), TypeParam(-0.176003f), TypeParam(-0.236653f), TypeParam(-0.278257f), TypeParam(1.480160f), TypeParam(-0.700350f), TypeParam(0.095525f), TypeParam(-0.891605f), TypeParam(-1.569065f), TypeParam(-1.633715f), TypeParam(-1.535763f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_4D_bicubic_reflection_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bicubic_reflection_align_corners) { OpTester test("GridSample", 20); std::string mode = "cubic"; std::string padding_mode = "reflection"; int64_t align_corners = 1; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{-0.328038f, -0.658850f, -0.054298f, 0.012663f, -0.077366f, 0.644305f, -1.262985f, 0.922028f, 0.189962f, 0.518836f, 1.168413f, -0.286220f, 0.431207f, -0.295352f, -0.357675f, -0.311715f, 0.839514f, -0.651820f, -0.283934f, 0.430508f, 0.206334f, 0.765966f, -1.144732f, -0.507045f}; + std::initializer_list X_data{TypeParam(-0.328038f), TypeParam(-0.658850f), TypeParam(-0.054298f), TypeParam(0.012663f), TypeParam(-0.077366f), TypeParam(0.644305f), TypeParam(-1.262985f), TypeParam(0.922028f), TypeParam(0.189962f), TypeParam(0.518836f), TypeParam(1.168413f), TypeParam(-0.286220f), TypeParam(0.431207f), TypeParam(-0.295352f), TypeParam(-0.357675f), TypeParam(-0.311715f), TypeParam(0.839514f), TypeParam(-0.651820f), TypeParam(-0.283934f), TypeParam(0.430508f), TypeParam(0.206334f), TypeParam(0.765966f), TypeParam(-1.144732f), TypeParam(-0.507045f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{-0.372000f, -1.056863f, -0.360826f, -0.268314f, 0.691035f, -0.595044f, 0.720198f, 0.166462f, -0.201118f, -1.069416f, 1.184721f, -0.213980f, 0.755038f, -0.620722f, -1.168597f, -0.956522f, -0.614982f, -0.382162f, -0.169456f, 1.000817f, -1.106710f, 0.598940f, 1.009714f, 0.007723f}; + std::initializer_list Grid_data{TypeParam(-0.372000f), TypeParam(-1.056863f), TypeParam(-0.360826f), TypeParam(-0.268314f), TypeParam(0.691035f), TypeParam(-0.595044f), TypeParam(0.720198f), TypeParam(0.166462f), TypeParam(-0.201118f), TypeParam(-1.069416f), TypeParam(1.184721f), TypeParam(-0.213980f), TypeParam(0.755038f), TypeParam(-0.620722f), TypeParam(-1.168597f), TypeParam(-0.956522f), TypeParam(-0.614982f), TypeParam(-0.382162f), TypeParam(-0.169456f), TypeParam(1.000817f), TypeParam(-1.106710f), TypeParam(0.598940f), TypeParam(1.009714f), TypeParam(0.007723f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{-0.403118f, -0.158055f, -0.496030f, 0.161379f, -0.440603f, -0.193607f, -0.746082f, -0.076433f, 0.751030f, 0.360851f, -0.488453f, 0.664305f, -0.259139f, 0.411796f, -0.156648f, 0.281569f, 0.437515f, -0.313812f, 0.573781f, -0.265706f, 0.200380f, -0.906155f, -0.724311f, 0.760352f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(-0.403118f), TypeParam(-0.158055f), TypeParam(-0.496030f), TypeParam(0.161379f), TypeParam(-0.440603f), TypeParam(-0.193607f), TypeParam(-0.746082f), TypeParam(-0.076433f), TypeParam(0.751030f), TypeParam(0.360851f), TypeParam(-0.488453f), TypeParam(0.664305f), TypeParam(-0.259139f), TypeParam(0.411796f), TypeParam(-0.156648f), TypeParam(0.281569f), TypeParam(0.437515f), TypeParam(-0.313812f), TypeParam(0.573781f), TypeParam(-0.265706f), TypeParam(0.200380f), TypeParam(-0.906155f), TypeParam(-0.724311f), TypeParam(0.760352f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } -TEST(GridSampleTest, test_grid_sample_20_4D_bicubic_reflection_no_align_corners) { +TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bicubic_reflection_no_align_corners) { OpTester test("GridSample", 20); std::string mode = "cubic"; std::string padding_mode = "reflection"; int64_t align_corners = 0; std::initializer_list X_shape{2, 2, 3, 2}; - std::initializer_list X_data{-0.290962f, 0.867797f, -0.085436f, -1.597520f, 0.695524f, 0.838739f, 0.513032f, 0.166242f, -0.546135f, -0.780313f, -0.512993f, -0.449479f, 1.594718f, 0.953375f, 0.692587f, -0.798364f, -0.128799f, -0.456210f, 2.098909f, -1.561220f, 1.713821f, -0.701970f, -0.287280f, -1.708048f}; + std::initializer_list X_data{TypeParam(-0.290962f), TypeParam(0.867797f), TypeParam(-0.085436f), TypeParam(-1.597520f), TypeParam(0.695524f), TypeParam(0.838739f), TypeParam(0.513032f), TypeParam(0.166242f), TypeParam(-0.546135f), TypeParam(-0.780313f), TypeParam(-0.512993f), TypeParam(-0.449479f), TypeParam(1.594718f), TypeParam(0.953375f), TypeParam(0.692587f), TypeParam(-0.798364f), TypeParam(-0.128799f), TypeParam(-0.456210f), TypeParam(2.098909f), TypeParam(-1.561220f), TypeParam(1.713821f), TypeParam(-0.701970f), TypeParam(-0.287280f), TypeParam(-1.708048f)}; std::initializer_list Grid_shape{2, 3, 2, 2}; - std::initializer_list Grid_data{0.934471f, 0.728362f, -0.458301f, -1.040800f, 0.157908f, 0.753451f, -0.122762f, 0.100970f, 0.889432f, 0.495471f, 0.897108f, 0.176205f, 0.134514f, -0.287037f, -0.202498f, -0.637759f, 0.802292f, 1.094459f, 0.445338f, 0.034096f, -0.396126f, -1.184798f, -0.222199f, -0.851887f}; + std::initializer_list Grid_data{TypeParam(0.934471f), TypeParam(0.728362f), TypeParam(-0.458301f), TypeParam(-1.040800f), TypeParam(0.157908f), TypeParam(0.753451f), TypeParam(-0.122762f), TypeParam(0.100970f), TypeParam(0.889432f), TypeParam(0.495471f), TypeParam(0.897108f), TypeParam(0.176205f), TypeParam(0.134514f), TypeParam(-0.287037f), TypeParam(-0.202498f), TypeParam(-0.637759f), TypeParam(0.802292f), TypeParam(1.094459f), TypeParam(0.445338f), TypeParam(0.034096f), TypeParam(-0.396126f), TypeParam(-1.184798f), TypeParam(-0.222199f), TypeParam(-0.851887f)}; std::initializer_list Y_shape{2, 2, 3, 2}; - std::initializer_list Y_data{1.037788f, -0.275160f, 0.953595f, -0.518196f, 0.118127f, -1.525148f, -0.413483f, 0.696689f, -0.450182f, -0.696169f, -0.561886f, -0.828986f, 0.343953f, 1.379632f, -0.417260f, -0.781500f, 1.666511f, 1.599268f, 0.106200f, 1.088396f, -2.079140f, -0.612122f, 1.822402f, 1.173807f}; - test.AddInput("X", X_shape, X_data); - test.AddInput("Grid", Grid_shape, Grid_data); + std::initializer_list Y_data{TypeParam(1.037788f), TypeParam(-0.275160f), TypeParam(0.953595f), TypeParam(-0.518196f), TypeParam(0.118127f), TypeParam(-1.525148f), TypeParam(-0.413483f), TypeParam(0.696689f), TypeParam(-0.450182f), TypeParam(-0.696169f), TypeParam(-0.561886f), TypeParam(-0.828986f), TypeParam(0.343953f), TypeParam(1.379632f), TypeParam(-0.417260f), TypeParam(-0.781500f), TypeParam(1.666511f), TypeParam(1.599268f), TypeParam(0.106200f), TypeParam(1.088396f), TypeParam(-2.079140f), TypeParam(-0.612122f), TypeParam(1.822402f), TypeParam(1.173807f)}; + test.AddInput("X", X_shape, X_data); + test.AddInput("Grid", Grid_shape, Grid_data); test.AddAttribute("mode", mode); test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); - test.AddOutput("Y", Y_shape, Y_data); + test.AddOutput("Y", Y_shape, Y_data); RunTests(test, GetExecutionProviders(20)); } diff --git a/onnxruntime/test/providers/cpu/tensor/grid_sample_test_gen.py b/onnxruntime/test/providers/cpu/tensor/grid_sample_test_gen.py index c7e263ca3f654..bf58a5d3fc1d5 100644 --- a/onnxruntime/test/providers/cpu/tensor/grid_sample_test_gen.py +++ b/onnxruntime/test/providers/cpu/tensor/grid_sample_test_gen.py @@ -14,6 +14,17 @@ padding_modes = ["zeros", "border", "reflection"] align_corners_options = [True, False] +print( + """ +template +class GridSampleTest : public ::testing::Test { +}; + +using GridSampleTestTypes = ::testing::Types; +TYPED_TEST_SUITE(GridSampleTest, GridSampleTestTypes); + +""" +) # Loop over the combinations of parameters torch.manual_seed(0) for opset_version in [16, 20]: @@ -42,11 +53,15 @@ input_tensor, grid_tensor, mode=mode, padding_mode=padding_mode, align_corners=align_corners ) - X_data_str = "{" + ", ".join([f"{x:.6f}f" for x in input_tensor.numpy().flatten()]) + "}" - Grid_data_str = "{" + ", ".join([f"{x:.6f}f" for x in grid_tensor.numpy().flatten()]) + "}" + X_data_str = "{" + ", ".join([f"TypeParam({x:.6f}f)" for x in input_tensor.numpy().flatten()]) + "}" + Grid_data_str = ( + "{" + ", ".join([f"TypeParam({x:.6f}f)" for x in grid_tensor.numpy().flatten()]) + "}" + ) Y_shape = output_tensor.shape - Y_data_str = "{" + ", ".join([f"{x:.6f}f" for x in output_tensor.numpy().flatten()]) + "}" + Y_data_str = ( + "{" + ", ".join([f"TypeParam({x:.6f}f)" for x in output_tensor.numpy().flatten()]) + "}" + ) onnx_mode = mode if opset_version >= 20: @@ -58,24 +73,25 @@ onnx_align_corners = 1 if align_corners else 0 test_name = f"test_grid_sample_{opset_version}_{ndim}D_{mode}_{padding_mode}_{'align_corners' if align_corners else 'no_align_corners'}" - print(f"TEST(GridSampleTest, {test_name}) {{") - print(f'OpTester test("GridSample", {opset_version});') - print(f'std::string mode = "{onnx_mode}";') - print(f'std::string padding_mode = "{padding_mode}";') - print(f"int64_t align_corners = {onnx_align_corners};") - print(f"std::initializer_list X_shape {{ {', '.join(map(str, input_shape))} }};") - print(f"std::initializer_list X_data { X_data_str };") - print(f"std::initializer_list Grid_shape {{ {', '.join(map(str, grid_shape))} }};") - print(f"std::initializer_list Grid_data { Grid_data_str };") - print(f"std::initializer_list Y_shape {{ {', '.join(map(str, Y_shape))} }};") - print(f"std::initializer_list Y_data { Y_data_str };") - - print('test.AddInput("X", X_shape, X_data);') - print('test.AddInput("Grid", Grid_shape, Grid_data);') - print('test.AddAttribute("mode", mode);') - print('test.AddAttribute("padding_mode", padding_mode);') - print('test.AddAttribute("align_corners", align_corners);') - print('test.AddOutput("Y", Y_shape, Y_data);') - print(f"RunTests(test, GetExecutionProviders({opset_version}));") + spaces = " " + print(f"TYPED_TEST(GridSampleTest, {test_name}) {{") + print(f'{spaces}OpTester test("GridSample", {opset_version});') + print(f'{spaces}std::string mode = "{onnx_mode}";') + print(f'{spaces}std::string padding_mode = "{padding_mode}";') + print(f"{spaces}int64_t align_corners = {onnx_align_corners};") + print(f"{spaces}std::initializer_list X_shape {{ {', '.join(map(str, input_shape))} }};") + print(f"{spaces}std::initializer_list X_data { X_data_str };") + print(f"{spaces}std::initializer_list Grid_shape {{ {', '.join(map(str, grid_shape))} }};") + print(f"{spaces}std::initializer_list Grid_data { Grid_data_str };") + print(f"{spaces}std::initializer_list Y_shape {{ {', '.join(map(str, Y_shape))} }};") + print(f"{spaces}std::initializer_list Y_data { Y_data_str };") + + print(f'{spaces}test.AddInput("X", X_shape, X_data);') + print(f'{spaces}test.AddInput("Grid", Grid_shape, Grid_data);') + print(f'{spaces}test.AddAttribute("mode", mode);') + print(f'{spaces}test.AddAttribute("padding_mode", padding_mode);') + print(f'{spaces}test.AddAttribute("align_corners", align_corners);') + print(f'{spaces}test.AddOutput("Y", Y_shape, Y_data);') + print(f"{spaces}RunTests(test, GetExecutionProviders({opset_version}));") print("}") print("\n") diff --git a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc index 111520ef03e26..84fb6157b8884 100644 --- a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc @@ -10,6 +10,13 @@ namespace onnxruntime { namespace test { +template +class ResizeOpTest : public ::testing::Test { +}; + +using ResizeOpTestTypes = ::testing::Types; +TYPED_TEST_SUITE(ResizeOpTest, ResizeOpTestTypes); + TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_tf_crop_and_resize) { // TODO: Unskip when fixed #41968513 if (DefaultDmlExecutionProvider().get() != nullptr) { @@ -226,26 +233,26 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_without_e test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kDmlExecutionProvider}); } -TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_4DBilinear) { +TYPED_TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_4DBilinear) { auto run_test = [](bool scales_in_initializer) { OpTester test("Resize", 13); - std::vector roi{}; + std::vector roi{}; std::vector scales{1.0f, 1.0f, 0.6f, 0.6f}; test.AddAttribute("mode", "linear"); constexpr int64_t N = 1, C = 1, H = 2, W = 4; - std::vector X = { - 1.0f, 2.0f, 3.0f, 4.0f, - 5.0f, 6.0f, 7.0f, 8.0f}; + std::vector X = { + TypeParam(1.0f), TypeParam(2.0f), TypeParam(3.0f), TypeParam(4.0f), + TypeParam(5.0f), TypeParam(6.0f), TypeParam(7.0f), TypeParam(8.0f)}; - test.AddInput("X", {N, C, H, W}, X); - test.AddInput("roi", {0}, roi); + test.AddInput("X", {N, C, H, W}, X); + test.AddInput("roi", {0}, roi); test.AddInput("scales", {4}, scales, scales_in_initializer); - std::vector Y = {2.66666651f, 4.3333331f}; + std::vector Y = {TypeParam(2.66666651f), TypeParam(4.3333331f)}; - test.AddOutput("Y", {N, C, static_cast(H * scales[2]), static_cast(W * scales[3])}, Y); + test.AddOutput("Y", {N, C, static_cast(H * scales[2]), static_cast(W * scales[3])}, Y); // QNN: result diff // TRT: Segmentation fault in A100 std::unordered_set excluded_providers({kQnnExecutionProvider}); diff --git a/onnxruntime/test/providers/cpu/tensor/slice_op.test.cc b/onnxruntime/test/providers/cpu/tensor/slice_op.test.cc index 83b308b57f26b..a32d43f296250 100644 --- a/onnxruntime/test/providers/cpu/tensor/slice_op.test.cc +++ b/onnxruntime/test/providers/cpu/tensor/slice_op.test.cc @@ -263,17 +263,33 @@ TEST(SliceTest, Slice3D) { 332.0f, 333.0f}); } -template +template +static std::vector GetTypedArray(std::vector inputs, [[maybe_unused]] T v = T(0.f)) { + std::vector inputs_T(inputs.size()); + if constexpr (std::is_same::value) { + return inputs; + } else if constexpr (std::is_integral_v) { + for (size_t i = 0; i < inputs.size(); i++) { + inputs_T[i] = static_cast(inputs[i]); + } + return inputs_T; + } else { + ConvertFloatToMLFloat16(inputs.data(), inputs_T.data(), inputs.size()); + return inputs_T; + } +} + +template static void TestSlice1DIntData() { - static_assert(std::is_integral_v); - RunSliceTest({6}, - {0, 1, 2, 3, 4, 5}, - {2}, - {4}, - {0}, - {}, - {2}, - {2, 3}); + // static_assert(std::is_integral_v); + RunSliceTest({6}, + GetTypedArray({0.f, 1.f, 2.f, 3.f, 4.f, 5.f}), + {2}, + {4}, + {0}, + {}, + {2}, + GetTypedArray({2.f, 3.f})); } TEST(SliceTest, Slice1D_Int32) { @@ -284,6 +300,21 @@ TEST(SliceTest, Slice1D_Int64) { TestSlice1DIntData(); } +TEST(SliceTest, Slice1D_Float) { + TestSlice1DIntData(); +} + +TEST(SliceTest, Slice1D_Float16) { + TestSlice1DIntData(); +} + +template +class SliceTest : public ::testing::Test { +}; + +using SliceTestTypes = ::testing::Types; +TYPED_TEST_SUITE(SliceTest, SliceTestTypes); + TEST(SliceTest, Slice1D_String) { RunSliceTest({6}, {"0", "1", "2", "3", "4", "5"}, @@ -296,16 +327,16 @@ TEST(SliceTest, Slice1D_String) { } // Only Slice V10 can run the following tests -TEST(SliceTest, Slice1D_WithPositiveSteps) { - RunSliceTest({6}, - {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f}, - {0}, - {6}, - {0}, - {2}, - {3}, - {0.0f, 2.0f, 4.0f}, - true); +TYPED_TEST(SliceTest, Slice1D_WithPositiveSteps) { + RunSliceTest({6}, + GetTypedArray({0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f}), + {0}, + {6}, + {0}, + {2}, + {3}, + GetTypedArray({0.0f, 2.0f, 4.0f}), + true); } // In numpy: diff --git a/onnxruntime/test/providers/cpu/tensor/space_depth_ops_test.cc b/onnxruntime/test/providers/cpu/tensor/space_depth_ops_test.cc index a0c1d675f506f..4954b82690e0f 100644 --- a/onnxruntime/test/providers/cpu/tensor/space_depth_ops_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/space_depth_ops_test.cc @@ -4,10 +4,18 @@ #include "gtest/gtest.h" #include "test/providers/provider_test_utils.h" #include "core/providers/cpu/tensor/space_depth_ops.h" +#include "core/mlas/inc/mlas.h" namespace onnxruntime { namespace test { +template +class TensorOpTest : public ::testing::Test { +}; + +using TensorOpTestTypes = ::testing::Types; +TYPED_TEST_SUITE(TensorOpTest, TensorOpTestTypes); + TEST(TensorOpTest, SpaceToDepthTest_1) { OpTester test("SpaceToDepth"); constexpr int64_t blocksize = 2; @@ -36,8 +44,8 @@ TEST(TensorOpTest, SpaceToDepthTest_1) { 3.1f, 3.3f}; test.AddOutput("output", {N, C * blocksize * blocksize, H / blocksize, W / blocksize}, result); - // TODO: Test is flaky on QNN EP (CPU backend). Reneable when the QnnCPUBackendTests.DISABLED_SpaceToDepth_Flaky test - // is fixed. + // TODO: Test is flaky on QNN EP (CPU backend). + // Re-enable when the QnnCPUBackendTests.DISABLED_SpaceToDepth_Flaky test is fixed. test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kQnnExecutionProvider}); } @@ -103,8 +111,8 @@ TEST(TensorOpTest, SpaceToDepthTest_2) { 88., 103., 106., 68., 71., 86., 89., 104., 107.}; test.AddOutput("output", {2, 27, 1, 2}, result); - // TODO: Test is flaky on QNN EP (CPU backend). Reneable when the QnnCPUBackendTests.DISABLED_SpaceToDepth_Flaky2 - // test is fixed. + // TODO: Test is flaky on QNN EP (CPU backend). + // Re-enable when the QnnCPUBackendTests.DISABLED_SpaceToDepth_Flaky2 test is fixed. test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kQnnExecutionProvider}); } @@ -259,7 +267,7 @@ TEST(TensorOpTest, DepthToSpaceTest_2) { test.Run(); } -TEST(TensorOpTest, DepthToSpaceTest_3) { +TYPED_TEST(TensorOpTest, DepthToSpaceTest_3) { OpTester test("DepthToSpace", 11); // create an opset 11 model with missing default attribute constexpr int64_t blocksize = 2; test.AddAttribute("blocksize", blocksize); @@ -281,8 +289,6 @@ TEST(TensorOpTest, DepthToSpaceTest_3) { 132., 133., 134., 135., 136., 137., 138., 139., 140., 141., 142., 143.}; - test.AddInput("input", {N, C, H, W}, X); - const std::vector result = { 0., 18., 1., 19., 36., 54., 37., 55., 2., 20., 3., 21., 38., 56., 39., 57., 4., 22., 5., 23., 40., 58., @@ -298,11 +304,24 @@ TEST(TensorOpTest, DepthToSpaceTest_3) { 102., 85., 103., 120., 138., 121., 139., 86., 104., 87., 105., 122., 140., 123., 141., 88., 106., 89., 107., 124., 142., 125., 143.}; - test.AddOutput("output", {2, 3, 6, 4}, result); - test.Run(); + + if constexpr (std::is_same::value) { + test.AddInput("input", {N, C, H, W}, X); + test.AddOutput("output", {2, 3, 6, 4}, result); + } else { + std::vector X_fp16(X.size()); + std::vector result_fp16(result.size()); + ConvertFloatToMLFloat16(result.data(), result_fp16.data(), result.size()); + ConvertFloatToMLFloat16(X.data(), X_fp16.data(), X.size()); + test.AddOutput("output", {2, 3, 6, 4}, result_fp16); + test.AddInput("input", {N, C, H, W}, X_fp16); + } + // TODO: Test is flaky on QNN EP (CPU backend). + // Re-enable when the QnnCPUBackendTests.DISABLED_SpaceToDepth_Flaky test is fixed. + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kQnnExecutionProvider}); } -TEST(TensorOpTest, DepthToSpaceTest_4) { +TYPED_TEST(TensorOpTest, DepthToSpaceTest_4) { OpTester test("DepthToSpace", 11); // create an opset 11 model with attribute present = "DCR" mode constexpr int64_t blocksize = 2; test.AddAttribute("blocksize", blocksize); @@ -325,8 +344,6 @@ TEST(TensorOpTest, DepthToSpaceTest_4) { 132., 133., 134., 135., 136., 137., 138., 139., 140., 141., 142., 143.}; - test.AddInput("input", {N, C, H, W}, X); - const std::vector result = { 0., 18., 1., 19., 36., 54., 37., 55., 2., 20., 3., 21., 38., 56., 39., 57., 4., 22., 5., 23., 40., 58., @@ -342,11 +359,25 @@ TEST(TensorOpTest, DepthToSpaceTest_4) { 102., 85., 103., 120., 138., 121., 139., 86., 104., 87., 105., 122., 140., 123., 141., 88., 106., 89., 107., 124., 142., 125., 143.}; - test.AddOutput("output", {2, 3, 6, 4}, result); - test.Run(); + + if constexpr (std::is_same::value) { + test.AddInput("input", {N, C, H, W}, X); + test.AddOutput("output", {2, 3, 6, 4}, result); + } else { + std::vector X_fp16(X.size()); + std::vector result_fp16(result.size()); + ConvertFloatToMLFloat16(X.data(), X_fp16.data(), X.size()); + ConvertFloatToMLFloat16(result.data(), result_fp16.data(), result.size()); + test.AddInput("input", {N, C, H, W}, X_fp16); + test.AddOutput("output", {2, 3, 6, 4}, result_fp16); + } + + // TODO: Test is flaky on QNN EP (CPU backend). + // Re-enable when the QnnCPUBackendTests.DISABLED_SpaceToDepth_Flaky2 test is fixed. + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kQnnExecutionProvider}); } -TEST(TensorOpTest, DepthToSpaceTest_5) { +TYPED_TEST(TensorOpTest, DepthToSpaceTest_5) { OpTester test("DepthToSpace", 11); // create an opset 11 model with attribute present = "CRD" mode constexpr int64_t blocksize = 2; test.AddAttribute("blocksize", blocksize); @@ -362,15 +393,25 @@ TEST(TensorOpTest, DepthToSpaceTest_5) { 27., 28., 29., 30., 31., 32.}; - test.AddInput("input", {N, C, H, W}, X); - const std::vector result = {0., 9., 1., 10., 2., 11., 18., 27., 19., 28., 20., 29., 3., 12., 4., 13., 5., 14., 21., 30., 22., 31., 23., 32.}; - test.AddOutput("output", {1, 1, 4, 6}, result); - test.Run(); + if constexpr (std::is_same::value) { + test.AddInput("input", {N, C, H, W}, X); + test.AddOutput("output", {1, 1, 4, 6}, result); + } else { + std::vector X_fp16(X.size()); + std::vector result_fp16(result.size()); + ConvertFloatToMLFloat16(X.data(), X_fp16.data(), X.size()); + ConvertFloatToMLFloat16(result.data(), result_fp16.data(), result.size()); + test.AddInput("input", {N, C, H, W}, X_fp16); + test.AddOutput("output", {1, 1, 4, 6}, result_fp16); + } + // TODO: Test is flaky on QNN EP (CPU backend). + // Re-enable when the QnnCPUBackendTests.DISABLED_SpaceToDepth_Flaky2 test is fixed. + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kQnnExecutionProvider}); } TEST(TensorOpTest, DepthToSpaceTest_CRD_Batched) { diff --git a/onnxruntime/test/providers/cpu/tensor/split_op_test.cc b/onnxruntime/test/providers/cpu/tensor/split_op_test.cc index 066302a4a37d1..48872404f08bd 100644 --- a/onnxruntime/test/providers/cpu/tensor/split_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/split_op_test.cc @@ -178,6 +178,17 @@ TEST(SplitOperatorTest, Axis0UnequalSplitFloat) { RunTest(axis, splits, input, outputs, {kTensorrtExecutionProvider}, false, true); } +template +std::vector GetTypedArray(std::vector inputs, [[maybe_unused]] T v = T(0.f)) { + if constexpr (std::is_same::value) { + return inputs; + } else { + std::vector inputs_fp16(inputs.size()); + ConvertFloatToMLFloat16(inputs.data(), inputs_fp16.data(), inputs.size()); + return inputs_fp16; + } +} + TEST(SplitOperatorTest, Axis0UnequalSplitString) { constexpr int64_t axis = 0; std::vector outputs; @@ -222,6 +233,26 @@ TEST(SplitOperatorTest, Axis1EqualSplitFloat) { RunTest(axis, {}, input, outputs, {kTensorrtExecutionProvider}); } +TEST(SplitOperatorTest, Axis1EqualSplitFloat16) { + constexpr int64_t axis = 1; + std::vector> outputs; + + // input shape and data + ShapeAndData input = {{2, 4}, + GetTypedArray({1.f, 2.f, 3.f, 4.f, + 5.f, 6.f, 7.f, 8.f})}; + + outputs.push_back({{2, 2}, + GetTypedArray({1.f, 2.f, + 5.f, 6.f})}); + + outputs.push_back({{2, 2}, + GetTypedArray({3.f, 4.f, + 7.f, 8.f})}); + RunTest(axis, {}, input, outputs, {kTensorrtExecutionProvider}, false, true); + RunTest(axis, {}, input, outputs, {kTensorrtExecutionProvider}); +} + TEST(SplitOperatorTest, Axis1EqualSplitString) { constexpr int64_t axis = 1; std::vector outputs; diff --git a/onnxruntime/test/providers/cpu/tensor/transpose_test.cc b/onnxruntime/test/providers/cpu/tensor/transpose_test.cc index 01dba55ceb8ed..3b46dc3f5d6a2 100644 --- a/onnxruntime/test/providers/cpu/tensor/transpose_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/transpose_test.cc @@ -12,6 +12,13 @@ namespace onnxruntime { namespace test { +template +class TransposeOpTest : public ::testing::Test { +}; + +using TransposeOpTestTypes = ::testing::Types; +TYPED_TEST_SUITE(TransposeOpTest, TransposeOpTestTypes); + TEST(TransposeOpTest, IsTransposeReshapeTest) { std::vector input_dims{1, 2, 3, 4, 1}; std::vector perm{0, 1, 2, 3, 4}; @@ -62,18 +69,27 @@ void TransposeTest(const std::vector& input_shape, } } +template +std::vector GetTypedArray(std::vector inputs, [[maybe_unused]] T v = T(0.f)) { + if constexpr (std::is_same::value) { + return inputs; + } else { + std::vector inputs_fp16(inputs.size()); + ConvertFloatToMLFloat16(inputs.data(), inputs_fp16.data(), inputs.size()); + return inputs_fp16; + } +} + // Test 2 dimensional transpose, with no permutation attribute specified -TEST(TransposeOpTest, TwoDimNoAttr) { +TYPED_TEST(TransposeOpTest, TwoDimNoAttr) { std::vector input_shape({2, 3}); - std::vector input_vals = { - 1.0f, 2.0f, 3.0f, - 4.0f, 5.0f, 6.0f}; + std::vector input_vals = GetTypedArray({1.0f, 2.0f, 3.0f, + 4.0f, 5.0f, 6.0f}); std::vector expected_shape({3, 2}); - std::vector expected_vals = { - 1.0f, 4.0f, - 2.0f, 5.0f, - 3.0f, 6.0f}; + std::vector expected_vals = GetTypedArray({1.0f, 4.0f, + 2.0f, 5.0f, + 3.0f, 6.0f}); TransposeTest(input_shape, input_vals, nullptr, expected_shape, expected_vals, {kTensorrtExecutionProvider}, {7, 21}); // TensorRT: SegFault error diff --git a/onnxruntime/test/util/test_utils.cc b/onnxruntime/test/util/test_utils.cc index b118c8faec0f7..f6d5d133262c4 100644 --- a/onnxruntime/test/util/test_utils.cc +++ b/onnxruntime/test/util/test_utils.cc @@ -59,6 +59,11 @@ void VerifyOutput(const std::string& output_name, ::testing::Pointwise(::testing::FloatNear(fp32_abs_err), tensor.DataAsSpan())); break; } + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: { + EXPECT_THAT(expected_tensor.DataAsSpan(), + ::testing::Pointwise(::testing::FloatNear(fp32_abs_err), tensor.DataAsSpan())); + break; + } default: ORT_THROW("Unhandled data type. Please add 'case' statement for ", element_type); } diff --git a/tools/ci_build/github/apple/coreml_supported_mlprogram_ops.md b/tools/ci_build/github/apple/coreml_supported_mlprogram_ops.md index bb4cfb2e09dcc..ae0769e7fb93c 100644 --- a/tools/ci_build/github/apple/coreml_supported_mlprogram_ops.md +++ b/tools/ci_build/github/apple/coreml_supported_mlprogram_ops.md @@ -20,6 +20,7 @@ Keep in sync with doco generated from /docs/execution-providers/CoreML-Execution |ai.onnx:MaxPool|Only 2D Pool is supported currently. 3D and 5D support can be added if needed.| |ai.onnx:Mul|| |ai.onnx:Pow|Only supports cases when both inputs are fp32.| +|ai.onnx:Reciprocal|this ask for a `epislon` (default 1e-4) where onnx don't provide| |ai.onnx:Relu|| |ai.onnx:Reshape|| |ai.onnx:Resize|See [resize_op_builder.cc](https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc) implementation. There are too many permutations to describe the valid combinations.| @@ -27,5 +28,6 @@ Keep in sync with doco generated from /docs/execution-providers/CoreML-Execution |ai.onnx:Split|If provided, `splits` must be constant.| |ai.onnx:Sub|| |ai.onnx:Sigmoid|| +|ai.onnx:Sqrt|| |ai.onnx:Tanh|| |ai.onnx:Transpose||