From 20d94648bbb106c74c43ef4023e142dee0342155 Mon Sep 17 00:00:00 2001
From: Julius Tischbein
Date: Wed, 11 Sep 2024 01:51:00 +0200
Subject: [PATCH] ConvTranspose using CUDNN Frontend with NHWC support (#21752)

### Description
Added the cuDNN Frontend and used it for the NHWC ConvTranspose op, including an option for bias fusion. Similar to this [Conv PR](https://github.com/microsoft/onnxruntime/pull/19470).

### Backward compatible
If ORT is built with cuDNN 8, the cuDNN frontend is not built into the binary and the old kernels (using the cuDNN backend APIs) are used.

### Major Changes
For cuDNN 9, the cuDNN frontend is used to fuse the data-gradient convolution and the bias when the provider option fuse_conv_bias=1 is set.

### Potential Issues
The cuDNN frontend uses TF32 by default. TF32 can be disabled with the use_tf32 CUDA provider option, but if the cuDNN frontend encounters issues while building an operation graph, it falls back to using TF32.

### Follow ups
This is one of the PRs that target enabling NHWC (here, the ConvTranspose operation) in the CUDA EP by default if the device supports it. Other changes will follow to make that possible:
(1) Enable prefer_nhwc by default for devices with sm >= 70.
(2) Make fuse_conv_bias=1 the default after more testing.
(3) Add other NHWC operators (like Resize or UpSample).

### Motivation and Context
The new cuDNN Frontend library provides the functionality to fuse operations and provides new heuristics for kernel selection. Here it fuses the convolution data-gradient operation (ConvTranspose) with the pointwise bias operation.

### Minor Change
Fixes a small bug in the CUDA convolution operation that occurred when `GetCudnnConv1dPadToNc1d` was enabled.
---
 .../providers/cuda/cuda_execution_provider.cc |   3 +-
 onnxruntime/core/providers/cuda/nn/conv.cc    |   2 +-
 .../core/providers/cuda/nn/conv_transpose.cc  | 626 +++++++++++-------
 .../core/providers/cuda/nn/conv_transpose.h   |  29 +
 .../core/providers/cuda/nn/conv_transpose_8.h | 266 ++++++++
 5 files changed, 702 insertions(+), 224 deletions(-)
 create mode 100644 onnxruntime/core/providers/cuda/nn/conv_transpose_8.h

diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
index b54c572556220..82b29c7b0562e 100644
--- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
+++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -2473,7 +2473,8 @@ static bool RNNNeedFallbackToCPU(const onnxruntime::Node& node,
   return false;
 }
 
-static bool ConvTransposeNeedFallbackToCPU(const onnxruntime::Node& node, const logging::Logger& logger,
+static bool ConvTransposeNeedFallbackToCPU([[maybe_unused]] const onnxruntime::Node& node,
+                                           [[maybe_unused]] const logging::Logger& logger,
                                            [[maybe_unused]] const GraphViewer& graph_viewer,
                                            [[maybe_unused]] const bool prefer_nhwc) {
   const auto& node_attributes = node.GetAttributes();
diff --git a/onnxruntime/core/providers/cuda/nn/conv.cc b/onnxruntime/core/providers/cuda/nn/conv.cc
index 95ba698b707ac..cc76198dc3ae9 100644
--- a/onnxruntime/core/providers/cuda/nn/conv.cc
+++ b/onnxruntime/core/providers/cuda/nn/conv.cc
@@ -385,7 +385,7 @@ Status Conv<T, NHWC>::UpdateState(OpKernelContext* context, bool bias_expected
     if (cuda_ep->GetCudnnConv1dPadToNc1d()) {
       x_dims_cudnn.insert(x_dims_cudnn.begin() + 2, 1);
       y_dims_cudnn.insert(y_dims_cudnn.begin() + 2, 1);
-      w_dims_cudnn.insert(w_dims.begin() + 2, 1);
+      w_dims_cudnn.insert(w_dims_cudnn.begin() + 2, 1);
       pads.insert(pads.begin() + kernel_rank, 0);
       pads.insert(pads.begin(), 0);
kernel_shape.insert(kernel_shape.begin(), 1); diff --git a/onnxruntime/core/providers/cuda/nn/conv_transpose.cc b/onnxruntime/core/providers/cuda/nn/conv_transpose.cc index bac99d6a81ed2..d4876e1714861 100644 --- a/onnxruntime/core/providers/cuda/nn/conv_transpose.cc +++ b/onnxruntime/core/providers/cuda/nn/conv_transpose.cc @@ -7,6 +7,11 @@ #include "conv_transpose.h" #include "core/providers/cuda/tensor/transpose.h" +#if CUDNN_MAJOR < 9 +// if compiled with cuDNN 8 we want to use the legacy cuDNN API +#include "conv_transpose_8.h" +#endif + // To suppress FP static analyzer warnings: // https://msdata.visualstudio.com/Vienna/_workitems/edit/1944928 and // https://msdata.visualstudio.com/Vienna/_workitems/edit/1944950 @@ -38,48 +43,42 @@ REGISTER_KERNEL_TYPED(float, kMSInternalNHWCDomain, true) REGISTER_KERNEL_TYPED(MLFloat16, kMSInternalNHWCDomain, true) #endif -template -Status ConvTranspose::ComputeInternal(OpKernelContext* context) const { - return DoConvTranspose(context, false); -} - +// First input (in this case X) is in case NHWC == true also in NHWC format, the other inputs in NCHW template Status ConvTranspose::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, bool& is_packed, - PrePackedWeights* prepacked_weights) { + [[maybe_unused]] PrePackedWeights* prepacked_weights) { is_packed = false; // only layout of weight input is adjusted via PrePack - if constexpr (NHWC) { // InputTensors::IN_W - if (input_idx == 1) { + if constexpr (NHWC) { + if (is_nhwc_domain_ && input_idx == 1) { // InputTensors::IN_W + // Transpose from {M, C/group, kH, kW} to {M, kH, kW, C/group} auto orig_shape = tensor.Shape(); - const auto rank = orig_shape.NumDimensions(); - - InlinedVector perm; - TensorShapeVector new_dims; - - // Input is { N, C, ...}. Output is { N, M, ...}. 'input channels' is C. 'output channels' is M. - // Transpose the output channels related dimension (M/group) to be last. Leave the input channels as-is. 
- if (rank == 3) { - // Transpose from {C, M/group, k1} to {C, k1, M/group} - perm = {0, 2, 1}; - new_dims = TensorShapeVector{orig_shape[0], orig_shape[2], orig_shape[1]}; - } else if (rank == 4) { - // Transpose from {C, M/group, kH, kW} to {C, kH, kW, M/group} - perm = {0, 2, 3, 1}; - new_dims = TensorShapeVector{orig_shape[0], orig_shape[2], orig_shape[3], orig_shape[1]}; - } else if (rank == 5) { - // Transpose from {C, M/group, k1, k2, k3} to {C, k1, k2, k3, M/group} - perm = {0, 2, 3, 4, 1}; - new_dims = TensorShapeVector{orig_shape[0], orig_shape[2], orig_shape[3], orig_shape[4], orig_shape[1]}; - } + auto shape_size = orig_shape.GetDims().size(); + + InlinedVector perm; + perm.push_back(0); + for (size_t i = 2; i < shape_size; i++) perm.push_back(i); + perm.push_back(1); + gsl::span permutation(perm.data(), shape_size); - gsl::span permutation(perm.data(), rank); - W_ = Tensor::Create(tensor.DataType(), TensorShape(new_dims), std::move(alloc)); + TensorShapeVector nhwc_dims; + for (size_t i = 0; i < shape_size; i++) { + nhwc_dims.push_back(orig_shape[perm[i]]); + } - ORT_RETURN_IF_ERROR(cuda::Transpose::DoTranspose(GetDeviceProp(), DefaultCudaStream(), DefaultCublasHandle(), - permutation, tensor, *W_)); + W_ = Tensor::Create(tensor.DataType(), TensorShape(nhwc_dims), std::move(alloc)); + auto status = cuda::Transpose::DoTranspose(GetDeviceProp(), + DefaultCudaStream(), + DefaultCublasHandle(), + permutation, tensor, *W_); + if (!status.IsOK()) { + return status; + } CUDA_CALL_THROW(cudaStreamSynchronize(DefaultCudaStream())); is_packed = true; + } else { + W_already_nhwc = true; } } else { ORT_UNUSED_PARAMETER(tensor); @@ -91,236 +90,419 @@ Status ConvTranspose::PrePack(const Tensor& tensor, int input_idx, Allo return Status::OK(); } -template -Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dynamic_padding) const { - typedef typename ToCudaType::MappedType CudaT; +#if CUDNN_MAJOR >= 9 +#if !defined(__CUDACC__) + +template +Status ConvTranspose::CreateCudnnFeExecutionPlan(const onnxruntime::TensorShapeVector& x_dims, + const onnxruntime::TensorShapeVector& w_dims, + const Tensor* B, + const TensorShapeVector& y_dims, + cudnnContext* handle, + const cudnn_frontend::HeurMode_t heur_mode, + const std::vector& pads, + const std::vector& strides, + const std::vector& dilations, + const bool fuse_bias, + const bool fuse_act, + const bool w_in_nhwc, + const bool use_tf32) const { + s_.bias_fused = fuse_bias; + s_.act_fused = fuse_act; + s_.variant_pack.clear(); // clear variant pack, as stored pointers to tensors change + s_.cudnn_fe_graph = std::make_unique(); + cudnn_frontend::DataType_t data_type = CudnnFeTensor::GetDataType(); + s_.cudnn_fe_graph->set_io_data_type(data_type).set_intermediate_data_type(data_type); + if (data_type == cudnn_frontend::DataType_t::HALF) { + s_.cudnn_fe_graph->set_compute_data_type(cudnn_frontend::DataType_t::FLOAT); + } else { + s_.cudnn_fe_graph->set_compute_data_type(data_type); + } - const Tensor* X = context->Input(0); - const TensorShape& x_shape = X->Shape(); - auto x_dims = x_shape.AsShapeVector(); - auto x_data = reinterpret_cast(X->Data()); - - auto x_dimensions = X->Shape().NumDimensions(); - if (x_dimensions < 3 || x_dimensions > 5) { - // TODO: the error message should tell which operator raises it. 
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input X must be 3-, 4- or 5-dimensional.", - " X: ", X->Shape().ToString().c_str()); + s_.cudnn_fe_X = s_.cudnn_fe_graph->tensor(CudnnFeTensor(x_dims, "x", data_type, Layout == LAYOUT_NHWC).Get()); + s_.cudnn_fe_W = s_.cudnn_fe_graph->tensor(CudnnFeTensor(w_dims, "w", data_type, w_in_nhwc).Get()); + + auto conv_options = cudnn_frontend::graph::Conv_dgrad_attributes() + .set_pre_padding(std::vector(pads.begin(), + pads.begin() + pads.size() / 2)) + .set_post_padding(std::vector(pads.begin() + pads.size() / 2, pads.end())) + .set_stride(strides) + .set_dilation(dilations); + s_.cudnn_fe_conv_Y = s_.cudnn_fe_graph->conv_dgrad(s_.cudnn_fe_X, s_.cudnn_fe_W, conv_options); + auto cudnn_fe_y_tensor = CudnnFeTensor(y_dims, "y", data_type, Layout == LAYOUT_NHWC).Get(); + + if (B == nullptr) { + s_.cudnn_fe_Y = s_.cudnn_fe_conv_Y; + } else { + int64_t bias_size; + if (B != nullptr) { + bias_size = B->Shape()[0]; + } else { + bias_size = w_dims[0]; + } + + if (fuse_bias) { + onnxruntime::TensorShapeVector b_dims; + for (size_t i = 0; i < x_dims.size(); i++) { + b_dims.push_back(i == 1 ? bias_size : 1); + } + auto bias_tensor = CudnnFeTensor(b_dims, "b", data_type, Layout == LAYOUT_NHWC).Get(); + auto bias_options = cudnn_frontend::graph::Pointwise_attributes().set_mode(cudnn_frontend::PointwiseMode_t::ADD); + s_.cudnn_fe_B = s_.cudnn_fe_graph->tensor(bias_tensor); + s_.cudnn_fe_Y = s_.cudnn_fe_graph->pointwise(s_.cudnn_fe_conv_Y, s_.cudnn_fe_B, bias_options); + } else { + s_.cudnn_fe_Y = s_.cudnn_fe_conv_Y; + + TensorShapeVector b_dims(y_dims.size(), 1); + TensorShapeVector b_strides(y_dims.size(), 1); + b_dims[1] = bias_size; + b_strides[0] = bias_size; + + ORT_RETURN_IF_ERROR(s_.b_tensor.Set(b_dims, CudnnTensor::GetDataType(), b_strides)); + ORT_RETURN_IF_ERROR(s_.y_tensor.Set(y_dims, CudnnTensor::GetDataType(), cudnn_fe_y_tensor.get_stride())); + + /* Creating an own CUDNN Frontend graph for the bias addition. + s_.cudnn_fe_bias_graph = std::make_unique(); + s_.cudnn_fe_bias_graph->set_io_data_type(data_type) + .set_compute_data_type(data_type == cudnn_frontend::DataType_t::HALF ? + cudnn_frontend::DataType_t::FLOAT : data_type) + .set_intermediate_data_type(data_type); + s_.cudnn_fe_bias_X = s_.cudnn_fe_bias_graph->tensor(CudnnFeTensor(y_dims, "x", data_type).Get()); + + s_.cudnn_fe_B = s_.cudnn_fe_bias_graph->tensor(bias_tensor); + s_.cudnn_fe_bias_Y = s_.cudnn_fe_bias_graph->pointwise(s_.cudnn_fe_bias_X, s_.cudnn_fe_B, bias_options); + s_.cudnn_fe_bias_Y->set_output(true); + + CUDNN_FE_RETURN_IF_ERROR(s_.cudnn_fe_bias_graph->validate()); + CUDNN_FE_RETURN_IF_ERROR(s_.cudnn_fe_bias_graph->build_operation_graph(handle)); + CUDNN_FE_RETURN_IF_ERROR(s_.cudnn_fe_bias_graph->create_execution_plans({heur_mode})); + CUDNN_FE_RETURN_IF_ERROR(s_.cudnn_fe_bias_graph->check_support(handle)); + CUDNN_FE_RETURN_IF_ERROR(s_.cudnn_fe_bias_graph->build_plans(handle));*/ + } + } + if (fuse_act && s_.cudnn_fe_act_attr.has_value()) { + auto& activation_attr = s_.cudnn_fe_act_attr.value(); + s_.cudnn_fe_Y = s_.cudnn_fe_graph->pointwise(s_.cudnn_fe_Y, activation_attr); } - // use pre-packed W if available - const Tensor* W = W_ ? 
W_.get() : context->Input(1); + s_.cudnn_fe_Y->set_dim(cudnn_fe_y_tensor.get_dim()); + s_.cudnn_fe_Y->set_stride(cudnn_fe_y_tensor.get_stride()); + s_.cudnn_fe_Y->set_output(true); + + try { + CUDNN_FE_CALL_THROW(s_.cudnn_fe_graph->validate()); + CUDNN_FE_CALL_THROW(s_.cudnn_fe_graph->build_operation_graph(handle)); + CUDNN_FE_CALL_THROW(s_.cudnn_fe_graph->create_execution_plans({heur_mode})); + } catch (const std::exception& ex) { + std::string message = MakeString("Failed to initialize CUDNN Frontend", ex.what(), + "with the cudnn frontend json:\n", s_.cudnn_fe_graph->print()); + return Status(common::StatusCategory::ONNXRUNTIME, common::StatusCode::EP_FAIL, message); + } - const TensorShape& w_shape = W->Shape(); - TensorShapeVector w_dims = w_shape.AsShapeVector(); - auto w_data = reinterpret_cast(W->Data()); + if (!use_tf32) s_.cudnn_fe_graph->deselect_numeric_notes({cudnn_frontend::NumericalNote_t::TENSOR_CORE}); + + try { + CUDNN_FE_CALL_THROW(s_.cudnn_fe_graph->check_support(handle)); + CUDNN_FE_CALL_THROW(s_.cudnn_fe_graph->build_plans(handle)); + } catch (const std::exception& ex) { + if (!fuse_bias && !fuse_act && use_tf32) { + std::string message = MakeString("OP not supported by CUDNN Frontend", ex.what(), + "with the cudnn frontend json:\n", s_.cudnn_fe_graph->print()); + return Status(common::StatusCategory::ONNXRUNTIME, common::StatusCode::EP_FAIL, message); + } + + // Try fallback. + return CreateCudnnFeExecutionPlan(x_dims, w_dims, B, y_dims, handle, heur_mode, + pads, strides, dilations, false, false, w_in_nhwc, true); + } + + s_.workspace_bytes = s_.cudnn_fe_graph->get_workspace_size(); + return Status::OK(); +} + +#endif + +template +Status ConvTranspose::UpdateState(OpKernelContext* context, bool dynamic_padding) const { + constexpr bool channels_last = Layout == LAYOUT_NHWC; size_t num_inputs = OpKernel::Node().InputDefs().size(); bool has_bias = dynamic_padding ? num_inputs == 4 : num_inputs == 3; - CudaT* y_data = nullptr; + // set X + const Tensor* X = context->Input(0); + const TensorShape& x_shape = X->Shape(); + // X incl. x_dims is in NHWC Format iff. NHWC == true + const auto x_dims = x_shape.AsShapeVector(); + + s_.x_data = reinterpret_cast(X->Data()); + s_.element_size = X->DataType()->Size(); + + // set W + bool w_in_nhwc; + const Tensor* W; + if (!W_) { + W = context->Input(1); + w_in_nhwc = false; + // Dims and memory layout are in NCHW format + } else { + W = W_.get(); + w_in_nhwc = channels_last; + // W got prepacked, therefore if NHWC == true, then dims and memory layout are in NHWC + } + const TensorShape& w_shape = W->Shape(); + onnxruntime::TensorShapeVector w_dims = w_shape.AsShapeVector(); + s_.w_data = reinterpret_cast(W->Data()); + + // set B + // Always in NCHW format + const Tensor* B = nullptr; + if (has_bias) { + B = context->Input(dynamic_padding ? 3 : 2); + s_.b_data = reinterpret_cast(B->Data()); + } else { + s_.b_data = nullptr; + } - const auto* cuda_ep = static_cast(Info().GetExecutionProvider()); + const Tensor* Pads = dynamic_padding ? context->Input(2) : nullptr; - // convert 1D to 2D - if (x_dimensions == 3) { - // we can either add a fake H or W dimension with a value of 1. to be consistent with the Conv behavior we use - // GetCudnnConv1dPadToNc1d to determine which is added. - // see Conv::UpdateState in /onnxruntime/core/providers/cuda/nn/conv.cc for more details. - if (cuda_ep->GetCudnnConv1dPadToNc1d()) { - // add fake H dimension - const auto insert_at = NHWC ? 
1 : 2; + bool input_dims_changed = (s_.last_x_dims != x_dims); + bool w_dims_changed = (s_.last_w_dims != w_dims); + if (input_dims_changed || w_dims_changed) { + if (input_dims_changed) + s_.last_x_dims = gsl::make_span(x_dims); - // NCHW: N, C, d1 -> N, C, 1, d1 - // NHWC: N, d1, C -> N, 1, d1, C - x_dims.insert(x_dims.begin() + insert_at, 1); + if (w_dims_changed) { + s_.last_w_dims = gsl::make_span(w_dims); + } - // 'M' is channels dim in CUDA implementation - // NCHW: C, M/g, k1 -> C, M/g, 1, k1 - // NHWC: C, k1, M/g -> C, 1, k1, M/g - w_dims.insert(w_dims.begin() + insert_at, 1); - } else { - // add fake W dimension - const auto insert_at = NHWC ? 2 : 3; + // The following code is from ConvTransposeAttributes::PrepareForCompute - // NCHW: N, C, d1 -> N, C, d1, 1 - // NHWC: N, d1, C -> N, d1, 1, C - x_dims.insert(x_dims.begin() + insert_at, 1); + const int rank = static_cast(X->Shape().NumDimensions()); + TensorShape input_shape = X->Shape().Slice(channels_last ? 1 : 2, channels_last ? rank - 1 : rank); + const int64_t num_input_channels = channels_last ? X->Shape()[rank - 1] : X->Shape()[1]; + const int64_t N = X->Shape()[0]; + const int64_t num_output_channels_multiplier = w_in_nhwc ? w_shape[rank - 1] : w_shape[1]; + const int64_t num_output_channels = num_output_channels_multiplier * conv_transpose_attrs_.group; - // NCHW: C, M/g, k1 -> C, M/g, k1, 1 - // NHWC: C, k1, M/g -> C, k1, 1, M/g - w_dims.insert(w_dims.begin() + insert_at, 1); + if (conv_transpose_attrs_.group <= 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "group count is <= 0", + " group: ", conv_transpose_attrs_.group); } - } - { - std::lock_guard lock(s_.mutex); - // CUDNN_CONFIG_RETURN_IF_ERROR(cudnnSetStream(CudnnHandle(), Stream(context))); - // TODO: add a global cache if need to handle cases for multiple frames running simultaneously with different batch_size - bool input_dims_changed = (s_.last_x_dims.AsShapeVector() != x_dims); - bool w_dims_changed = (s_.last_w_dims.AsShapeVector() != w_dims); - if (input_dims_changed || w_dims_changed) { - if (input_dims_changed) { - s_.last_x_dims = gsl::make_span(x_dims); - } + if (X->Shape().NumDimensions() != w_shape.NumDimensions()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "X num_dims does not match W num_dims.", + " X: ", X->Shape().ToString().c_str(), + " W: ", w_shape.ToString().c_str()); + } - if (w_dims_changed) { - s_.last_w_dims = gsl::make_span(w_dims); - s_.cached_benchmark_results.clear(); - } + if (w_shape[0] != num_input_channels) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "filter number not equal to input channel number.", + " filter_number: ", w_shape[0], + " num_input_channels: ", num_input_channels); + } - ConvTransposeAttributes::Prepare p; - // PrePack moves the M/group dimension of W to the end, with 'M' being interpreted as 'output channels' - const bool transposed_input_channels = false; - ORT_RETURN_IF_ERROR( - conv_transpose_attrs_.PrepareForCompute(context, has_bias, p, dynamic_padding, &w_shape, NHWC, transposed_input_channels)); - - auto y_dims = p.Y->Shape().AsShapeVector(); - if (x_dimensions == 3) { - if (cuda_ep->GetCudnnConv1dPadToNc1d()) { - // add fake H dimension of 1 - // NCHW: N, M, d1 -> N, M, 1, d1 or - // NHWC: N, d1, M -> N, 1, d1, M - y_dims.insert(y_dims.begin() + (NHWC ? 
1 : 2), 1); - p.kernel_shape.insert(p.kernel_shape.begin(), 1); - p.pads.insert(p.pads.begin(), 0); - p.pads.insert(p.pads.begin() + 2, 0); - p.strides.insert(p.strides.begin(), 1); - p.dilations.insert(p.dilations.begin(), 1); - } else { - // add fake W dimension of 1 - // NCHW: N, M, d1 -> N, M, d1, 1 or - // NHWC: N, d1, M -> N, d1, 1, M - y_dims.insert(y_dims.begin() + (NHWC ? 2 : 3), 1); - p.kernel_shape.push_back(1); - p.pads.insert(p.pads.begin() + 1, 0); - p.pads.push_back(0); - p.strides.push_back(1); - p.dilations.push_back(1); - } - } + // it looks like num_output_channels is really k*group similar to how in the conv case + // num_input_channels is k*group. hence removing the check for num_output_channels here. - s_.y_dims = gsl::make_span(y_dims); + if (num_input_channels % conv_transpose_attrs_.group != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input channels is not divisible by group.", + " num_input_channels: ", num_input_channels, + " group: ", conv_transpose_attrs_.group); + } - if (w_dims_changed) { - if constexpr (NHWC) { - ORT_RETURN_IF_ERROR(s_.w_desc.Set(CUDNN_TENSOR_NHWC, CudnnTensor::GetDataType(), - static_cast(w_dims[0]), static_cast(w_dims[3]), - static_cast(w_dims[1]), static_cast(w_dims[2]))); - } else { - ORT_RETURN_IF_ERROR(s_.w_desc.Set(w_dims, CudnnTensor::GetDataType())); - } - } + TensorShapeVector kernel_shape; + ORT_RETURN_IF_ERROR(conv_transpose_attrs_.ComputeKernelShape(w_shape, kernel_shape, w_in_nhwc)); - // Special case when there is a dim value of 0 in the shape. - // Return only after we have cached the following for subsequent runs : - // 1) `w_dims` in the `w_desc` - // 2) `y_dims` in s_.y_dims - if (p.Y->Shape().Size() == 0) { - return Status::OK(); - } + const size_t kernel_rank = kernel_shape.size(); - if constexpr (NHWC) { - ORT_RETURN_IF_ERROR(s_.x_tensor.Set(CUDNN_TENSOR_NHWC, CudnnTensor::GetDataType(), - static_cast(x_dims[0]), static_cast(x_dims[3]), - static_cast(x_dims[1]), static_cast(x_dims[2]))); - ORT_RETURN_IF_ERROR(s_.y_tensor.Set(CUDNN_TENSOR_NHWC, CudnnTensor::GetDataType(), - static_cast(y_dims[0]), static_cast(y_dims[3]), - static_cast(y_dims[1]), static_cast(y_dims[2]))); - } else { - ORT_RETURN_IF_ERROR(s_.x_tensor.Set(x_dims, CudnnTensor::GetDataType())); - ORT_RETURN_IF_ERROR(s_.y_tensor.Set(y_dims, CudnnTensor::GetDataType())); + TensorShapeVector local_output_padding(conv_transpose_attrs_.output_padding); + if (local_output_padding.empty()) { + local_output_padding.resize(kernel_shape.size(), 0); + } + ConvPadVector pads; + pads.reserve(2 * (input_shape.NumDimensions())); + if (dynamic_padding) { + for (int64_t i = 0; i < Pads->Shape().SizeFromDimension(0); ++i) { + pads.push_back(Pads->Data()[i]); } + } else { + pads.assign(conv_transpose_attrs_.pads.begin(), conv_transpose_attrs_.pads.end()); + } + if (pads.empty()) { + pads.resize(kernel_shape.size() * 2, 0); + } + TensorShapeVector dilations(conv_transpose_attrs_.dilations); + if (dilations.empty()) { + dilations.resize(kernel_shape.size(), 1); + } + TensorShapeVector strides(conv_transpose_attrs_.strides); + if (strides.empty()) { + strides.resize(kernel_shape.size(), 1); + } - cudnnConvolutionMode_t mode = CUDNN_CROSS_CORRELATION; - ORT_RETURN_IF_ERROR(s_.conv_desc.Set(p.kernel_shape.size(), p.pads, p.strides, p.dilations, - gsl::narrow_cast(conv_transpose_attrs_.group), mode, - CudnnTensor::GetDataType(), - UseTF32())); - - if (has_bias) { - const auto& b_shape = p.B->Shape(); - ORT_RETURN_IF_NOT(b_shape.NumDimensions() == 1, "bias should be 
1D"); - TensorShapeVector b_dims(2 + p.kernel_shape.size()); - b_dims[0] = 1; // N - b_dims[NHWC ? 3 : 1] = b_shape[0]; // C - for (size_t i = 0; i < p.kernel_shape.size(); i++) { - b_dims[(NHWC ? 1 : 2) + i] = 1; - } - - ORT_RETURN_IF_ERROR(s_.b_tensor.Set(b_dims, CudnnTensor::GetDataType(), NHWC)); - } + TensorShapeVector y_dims; - y_data = reinterpret_cast(p.Y->MutableData()); - - if (!s_.cached_benchmark_results.contains(x_dims)) { - IAllocatorUniquePtr algo_search_workspace = - GetScratchBuffer(AlgoSearchWorkspaceSize, context->GetComputeStream()); - - // set math type to tensor core before algorithm search - if constexpr (std::is_same::value) { - CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_TENSOR_OP_MATH)); - } else if constexpr (std::is_same::value) { - if (!UseTF32()) { - CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_FMA_MATH)); - } - } - - cudnnConvolutionBwdDataAlgoPerf_t perf; - int algo_count = 1; - CUDNN_RETURN_IF_ERROR(cudnnFindConvolutionBackwardDataAlgorithmEx( - GetCudnnHandle(context), s_.w_desc, w_data, s_.x_tensor, x_data, s_.conv_desc, s_.y_tensor, y_data, 1, - &algo_count, &perf, algo_search_workspace.get(), AlgoSearchWorkspaceSize)); - s_.cached_benchmark_results.insert(x_dims, {perf.algo, perf.memory, perf.mathType}); - } + conv_transpose_attrs_.ComputePadsAndOutputShape(input_shape, num_output_channels, kernel_shape, + strides, dilations, local_output_padding, N, &pads, &y_dims, channels_last); - const auto& perf = s_.cached_benchmark_results.at(x_dims); - CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, perf.mathType)); - s_.algo = perf.algo; - s_.workspace_bytes = perf.memory; - } + s_.y_dims = gsl::make_span(y_dims); + s_.Y = context->Output(0, s_.y_dims); - // The following block will be executed in case there has been no change in the shapes of the - // input and the filter compared to the previous run - if (!y_data) { - auto y_dims = s_.y_dims.AsShapeVector(); - if (x_dimensions == 3) { - if (cuda_ep->GetCudnnConv1dPadToNc1d()) { - // erase the fake H dimension - y_dims.erase(y_dims.begin() + (NHWC ? 1 : 2)); - } else { - // erase the fake W dimension - y_dims.erase(y_dims.begin() + (NHWC ? 2 : 3)); - } - } + s_.y_data = reinterpret_cast(s_.Y->MutableData()); + const CUDAExecutionProvider* cuda_ep = + static_cast(this->Info().GetExecutionProvider()); - Tensor* Y = context->Output(0, TensorShape(y_dims)); - y_data = reinterpret_cast(Y->MutableData()); + TensorShapeVector x_dims_cudnn{x_dims.begin(), x_dims.end()}; + TensorShapeVector y_dims_cudnn{y_dims.begin(), y_dims.end()}; + TensorShapeVector w_dims_cudnn{w_dims.begin(), w_dims.end()}; - // Bail out early if one of the output dimensions is zero. - if (Y->Shape().Size() == 0) { - return Status::OK(); + if constexpr (channels_last) { + x_dims_cudnn.insert(x_dims_cudnn.begin() + 1, *(x_dims_cudnn.end() - 1)); + y_dims_cudnn.insert(y_dims_cudnn.begin() + 1, *(y_dims_cudnn.end() - 1)); + x_dims_cudnn.erase(x_dims_cudnn.end() - 1); + y_dims_cudnn.erase(y_dims_cudnn.end() - 1); + + if (w_in_nhwc) { + w_dims_cudnn.insert(w_dims_cudnn.begin() + 1, *(w_dims_cudnn.end() - 1)); + w_dims_cudnn.erase(w_dims_cudnn.end() - 1); } } - const auto alpha = Consts::One; - const auto beta = Consts::Zero; + if (kernel_rank < 2) { + // TODO: Explore padding the provided input shape [N, C, D] to [N, C, 1, D] + // especially for EXHAUSTIVE algo search which may result in a better algo selection. 
+ // ORTModule uses different algo search options (HEURISTIC, and use max workspace size) compared to + // inference build (EXHAUSTIVE, 32M workspace size). We observed better perf when we pad input shape + // [N,C,D] to [N,C,1,D], expecially on A100, and especially for ConvGrad. + // PyTorch also pads to [N,C,1,D]. For inference build, we still pad it to [N, C, D, 1] as this seems + // to be the sweet spot for all algo search options: EXHAUSTIVE, HEURISTIC, and DEFAULT. + // See PR #7348 and #7702 for more context. + if (cuda_ep->GetCudnnConv1dPadToNc1d()) { + x_dims_cudnn.insert(x_dims_cudnn.begin() + 2, 1); + y_dims_cudnn.insert(y_dims_cudnn.begin() + 2, 1); + w_dims_cudnn.insert(w_dims_cudnn.begin() + 2, 1); + pads.insert(pads.begin() + kernel_rank, 0); + pads.insert(pads.begin(), 0); + kernel_shape.insert(kernel_shape.begin(), 1); + strides.insert(strides.begin(), 1); + dilations.insert(dilations.begin(), 1); + } else { + x_dims_cudnn.push_back(1); + y_dims_cudnn.push_back(1); + w_dims_cudnn.push_back(1); + pads.insert(pads.begin() + kernel_rank, 0); + pads.insert(pads.end(), 0); + kernel_shape.push_back(1); + strides.push_back(1); + dilations.push_back(1); + } + } + + // We must delay returning early until here so that the weight dims have been cached properly + if (s_.Y->Shape().Size() == 0) { + return Status::OK(); + } - IAllocatorUniquePtr workspace = GetScratchBuffer(s_.workspace_bytes, context->GetComputeStream()); + auto handle = GetCudnnHandle(context); + + int cudnn_conv_algo = cuda_ep->GetCudnnConvAlgo(); +#if !defined(__CUDACC__) + cudnn_frontend::HeurMode_t heur_mode; + switch (cudnn_conv_algo) { + case 0: + heur_mode = cudnn_frontend::HeurMode_t::B; + break; + case 1: + heur_mode = cudnn_frontend::HeurMode_t::A; + break; + case 2: + heur_mode = cudnn_frontend::HeurMode_t::FALLBACK; + break; + default: + heur_mode = cudnn_frontend::HeurMode_t::A; + break; + } - CUDNN_RETURN_IF_ERROR(cudnnConvolutionBackwardData(GetCudnnHandle(context), &alpha, s_.w_desc, w_data, s_.x_tensor, - x_data, s_.conv_desc, s_.algo, workspace.get(), - s_.workspace_bytes, &beta, s_.y_tensor, y_data)); + auto use_tf32 = cuda_ep->UseTF32(); + const auto fuse_bias = cuda_ep->IsFuseConvBias() || is_fused_node_; + const auto fuse_act = is_fused_node_; + + ORT_RETURN_IF_ERROR(CreateCudnnFeExecutionPlan(x_dims_cudnn, w_dims_cudnn, B, y_dims_cudnn, handle, heur_mode, + std::vector(pads.begin(), + pads.end()), + std::vector(strides.begin(), + strides.end()), + std::vector(dilations.begin(), + dilations.end()), + fuse_bias, fuse_act, w_in_nhwc, use_tf32)); +#endif + } else { + // set Y + s_.Y = context->Output(0, s_.y_dims); + if (s_.Y->Shape().Size() == 0) { + return Status::OK(); + } + s_.y_data = reinterpret_cast(s_.Y->MutableData()); + } + return Status::OK(); +} - if (has_bias) { - const Tensor* B = dynamic_padding ? 
context->Input(3) : context->Input(2); - auto b_data = reinterpret_cast(B->Data()); - CUDNN_RETURN_IF_ERROR( - cudnnAddTensor(GetCudnnHandle(context), &alpha, s_.b_tensor, b_data, &alpha, s_.y_tensor, y_data)); +template +Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dynamic_padding) const { + std::lock_guard lock(s_.mutex); + ORT_RETURN_IF_ERROR(UpdateState(context, dynamic_padding)); + if (s_.Y->Shape().Size() == 0) { + return Status::OK(); + } + const auto alpha = onnxruntime::cuda::Consts::One; + auto cudnn_handle = GetCudnnHandle(context); +#if !defined(__CUDACC__) + s_.variant_pack.insert_or_assign(s_.cudnn_fe_X, const_cast(s_.x_data)); + s_.variant_pack.insert_or_assign(s_.cudnn_fe_W, const_cast(s_.w_data)); + s_.variant_pack.insert_or_assign(s_.cudnn_fe_Y, s_.y_data); + if (s_.bias_fused && s_.b_data != nullptr) { + s_.variant_pack.insert_or_assign(s_.cudnn_fe_B, const_cast(s_.b_data)); + } + if (s_.bias_fused && s_.z_data != nullptr) { + s_.variant_pack.insert_or_assign(s_.cudnn_fe_Z, const_cast(s_.z_data)); + if (Layout == LAYOUT_NCHW && s_.z_data == s_.y_data) { + // memset Z if it's required for a succesful fusion + CUDA_RETURN_IF_ERROR(cudaMemset(s_.y_data, 0, s_.Y->SizeInBytes())); } } + auto ws = GetWorkSpace(context->GetComputeStream()); + + CUDNN_FE_RETURN_IF_ERROR(s_.cudnn_fe_graph->execute(cudnn_handle, + s_.variant_pack, + ws.get())); + + if (!s_.bias_fused && s_.z_data != nullptr) { + CUDNN_RETURN_IF_ERROR(cudnnAddTensor(cudnn_handle, &alpha, s_.z_tensor, s_.z_data, + &alpha, s_.y_tensor, s_.y_data)); + } + if (!s_.bias_fused && s_.b_data != nullptr) { + CUDNN_RETURN_IF_ERROR(cudnnAddTensor(cudnn_handle, &alpha, s_.b_tensor, s_.b_data, + &alpha, s_.y_tensor, s_.y_data)); + + /* For the standalone bias addition graph. 
+ s_.variant_pack_bias.insert_or_assign(s_.cudnn_fe_bias_X, s_.y_data); + s_.variant_pack_bias.insert_or_assign(s_.cudnn_fe_bias_Y, s_.y_data); + s_.variant_pack_bias.insert_or_assign(s_.cudnn_fe_B, const_cast(s_.b_data)); + CUDNN_FE_RETURN_IF_ERROR(s_.cudnn_fe_bias_graph->execute(cudnn_handle, + s_.variant_pack_bias, + GetWorkSpace(context->GetComputeStream()).get()));*/ + } +#endif return Status::OK(); } +#endif + +template +Status ConvTranspose::ComputeInternal(OpKernelContext* context) const { + return DoConvTranspose(context, false); +} } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/nn/conv_transpose.h b/onnxruntime/core/providers/cuda/nn/conv_transpose.h index 71ad3ee6e2147..1a6957164d22f 100644 --- a/onnxruntime/core/providers/cuda/nn/conv_transpose.h +++ b/onnxruntime/core/providers/cuda/nn/conv_transpose.h @@ -18,6 +18,8 @@ namespace cuda { template class ConvTranspose : public CudaKernel { public: + using CudaT = typename ToCudaType::MappedType; + ConvTranspose(const OpKernelInfo& info) : CudaKernel(info), conv_transpose_attrs_(info) {}; Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, bool& is_packed, [[maybe_unused]] PrePackedWeights* prepacked_weights) override; @@ -29,6 +31,33 @@ class ConvTranspose : public CudaKernel { mutable CudnnConvState s_; std::unique_ptr W_; + + bool is_nhwc_domain_; // prepack is only needed for the Conv in kMSInternalNHWCDomain + bool is_fused_node_ = false; // ensures the node is fused although the session option is not set + bool W_already_nhwc = false; // In case NHWC == true and Conv is not in kMSInternalNHWCDomain + + protected: + inline IAllocatorUniquePtr GetWorkSpace(onnxruntime::Stream* stream) const { + return GetScratchBuffer(s_.workspace_bytes, stream); + } + + Status UpdateState(OpKernelContext* context, bool bias_expected) const; + +#if !defined(__CUDACC__) && CUDNN_MAJOR >= 9 + Status CreateCudnnFeExecutionPlan(const onnxruntime::TensorShapeVector& x_dims, + const onnxruntime::TensorShapeVector& w_dims, + const Tensor* B, + const TensorShapeVector& y_dims, + cudnnContext* handle, + const cudnn_frontend::HeurMode_t heur_mode, + const std::vector& pads, + const std::vector& strides, + const std::vector& dilations, + const bool fuse_bias, + const bool fuse_act, + const bool w_in_nhwc, + const bool use_tf32) const; +#endif }; } // namespace cuda diff --git a/onnxruntime/core/providers/cuda/nn/conv_transpose_8.h b/onnxruntime/core/providers/cuda/nn/conv_transpose_8.h new file mode 100644 index 0000000000000..b46d41b887e41 --- /dev/null +++ b/onnxruntime/core/providers/cuda/nn/conv_transpose_8.h @@ -0,0 +1,266 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) 2023 NVIDIA Corporation. +// Licensed under the MIT License. 
+ +#include + +#include "conv_transpose.h" +#include + +#include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/cuda_kernel.h" +#include "core/providers/cuda/cudnn_common.h" +#include "core/providers/cuda/nn/conv.h" +#include "core/providers/cpu/nn/conv_transpose_attributes.h" + +#include "core/providers/cuda/tensor/transpose.h" + +// To suppress FP static analyzer warnings: +// https://msdata.visualstudio.com/Vienna/_workitems/edit/1944928 and +// https://msdata.visualstudio.com/Vienna/_workitems/edit/1944950 +#ifdef _WIN32 +#pragma warning(push) +#pragma warning(disable : 26110) +#pragma warning(disable : 26117) +#endif + +namespace onnxruntime { +namespace cuda { + +template +Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dynamic_padding) const { + const Tensor* X = context->Input(0); + const TensorShape& x_shape = X->Shape(); + auto x_dims = x_shape.AsShapeVector(); + auto x_data = reinterpret_cast(X->Data()); + + auto x_dimensions = X->Shape().NumDimensions(); + if (x_dimensions < 3 || x_dimensions > 5) { + // TODO: the error message should tell which operator raises it. + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input X must be 3-, 4- or 5-dimensional.", + " X: ", X->Shape().ToString().c_str()); + } + + // use pre-packed W if available + const Tensor* W = W_ ? W_.get() : context->Input(1); + + const TensorShape& w_shape = W->Shape(); + TensorShapeVector w_dims = w_shape.AsShapeVector(); + auto w_data = reinterpret_cast(W->Data()); + + size_t num_inputs = OpKernel::Node().InputDefs().size(); + bool has_bias = dynamic_padding ? num_inputs == 4 : num_inputs == 3; + + CudaT* y_data = nullptr; + + const auto* cuda_ep = static_cast(Info().GetExecutionProvider()); + + // convert 1D to 2D + if (x_dimensions == 3) { + // we can either add a fake H or W dimension with a value of 1. to be consistent with the Conv behavior we use + // GetCudnnConv1dPadToNc1d to determine which is added. + // see Conv::UpdateState in /onnxruntime/core/providers/cuda/nn/conv.cc for more details. + if (cuda_ep->GetCudnnConv1dPadToNc1d()) { + // add fake H dimension + const auto insert_at = NHWC ? 1 : 2; + + // NCHW: N, C, d1 -> N, C, 1, d1 + // NHWC: N, d1, C -> N, 1, d1, C + x_dims.insert(x_dims.begin() + insert_at, 1); + + // 'M' is channels dim in CUDA implementation + // NCHW: C, M/g, k1 -> C, M/g, 1, k1 + // NHWC: C, k1, M/g -> C, 1, k1, M/g + w_dims.insert(w_dims.begin() + insert_at, 1); + } else { + // add fake W dimension + const auto insert_at = NHWC ? 
2 : 3; + + // NCHW: N, C, d1 -> N, C, d1, 1 + // NHWC: N, d1, C -> N, d1, 1, C + x_dims.insert(x_dims.begin() + insert_at, 1); + + // NCHW: C, M/g, k1 -> C, M/g, k1, 1 + // NHWC: C, k1, M/g -> C, k1, 1, M/g + w_dims.insert(w_dims.begin() + insert_at, 1); + } + } + + { + std::lock_guard lock(s_.mutex); + // CUDNN_CONFIG_RETURN_IF_ERROR(cudnnSetStream(CudnnHandle(), Stream(context))); + // TODO: add a global cache if need to handle cases for multiple frames running simultaneously with + // different batch_size + bool input_dims_changed = (s_.last_x_dims.AsShapeVector() != x_dims); + bool w_dims_changed = (s_.last_w_dims.AsShapeVector() != w_dims); + if (input_dims_changed || w_dims_changed) { + if (input_dims_changed) { + s_.last_x_dims = gsl::make_span(x_dims); + } + + if (w_dims_changed) { + s_.last_w_dims = gsl::make_span(w_dims); + s_.cached_benchmark_results.clear(); + } + + ConvTransposeAttributes::Prepare p; + // PrePack moves the M/group dimension of W to the end, with 'M' being interpreted as 'output channels' + const bool transposed_input_channels = false; + ORT_RETURN_IF_ERROR( + conv_transpose_attrs_.PrepareForCompute(context, has_bias, p, dynamic_padding, + &w_shape, NHWC, transposed_input_channels)); + + auto y_dims = p.Y->Shape().AsShapeVector(); + if (x_dimensions == 3) { + if (cuda_ep->GetCudnnConv1dPadToNc1d()) { + // add fake H dimension of 1 + // NCHW: N, M, d1 -> N, M, 1, d1 or + // NHWC: N, d1, M -> N, 1, d1, M + y_dims.insert(y_dims.begin() + (NHWC ? 1 : 2), 1); + p.kernel_shape.insert(p.kernel_shape.begin(), 1); + p.pads.insert(p.pads.begin(), 0); + p.pads.insert(p.pads.begin() + 2, 0); + p.strides.insert(p.strides.begin(), 1); + p.dilations.insert(p.dilations.begin(), 1); + } else { + // add fake W dimension of 1 + // NCHW: N, M, d1 -> N, M, d1, 1 or + // NHWC: N, d1, M -> N, d1, 1, M + y_dims.insert(y_dims.begin() + (NHWC ? 2 : 3), 1); + p.kernel_shape.push_back(1); + p.pads.insert(p.pads.begin() + 1, 0); + p.pads.push_back(0); + p.strides.push_back(1); + p.dilations.push_back(1); + } + } + + s_.y_dims = gsl::make_span(y_dims); + + if (w_dims_changed) { + if constexpr (NHWC) { + ORT_RETURN_IF_ERROR(s_.w_desc.Set(CUDNN_TENSOR_NHWC, CudnnTensor::GetDataType(), + static_cast(w_dims[0]), static_cast(w_dims[3]), + static_cast(w_dims[1]), static_cast(w_dims[2]))); + } else { + ORT_RETURN_IF_ERROR(s_.w_desc.Set(w_dims, CudnnTensor::GetDataType())); + } + } + + // Special case when there is a dim value of 0 in the shape. 
+ // Return only after we have cached the following for subsequent runs : + // 1) `w_dims` in the `w_desc` + // 2) `y_dims` in s_.y_dims + if (p.Y->Shape().Size() == 0) { + return Status::OK(); + } + + if constexpr (NHWC) { + ORT_RETURN_IF_ERROR(s_.x_tensor.Set(CUDNN_TENSOR_NHWC, CudnnTensor::GetDataType(), + static_cast(x_dims[0]), static_cast(x_dims[3]), + static_cast(x_dims[1]), static_cast(x_dims[2]))); + ORT_RETURN_IF_ERROR(s_.y_tensor.Set(CUDNN_TENSOR_NHWC, CudnnTensor::GetDataType(), + static_cast(y_dims[0]), static_cast(y_dims[3]), + static_cast(y_dims[1]), static_cast(y_dims[2]))); + } else { + ORT_RETURN_IF_ERROR(s_.x_tensor.Set(x_dims, CudnnTensor::GetDataType())); + ORT_RETURN_IF_ERROR(s_.y_tensor.Set(y_dims, CudnnTensor::GetDataType())); + } + + cudnnConvolutionMode_t mode = CUDNN_CROSS_CORRELATION; + ORT_RETURN_IF_ERROR(s_.conv_desc.Set(p.kernel_shape.size(), p.pads, p.strides, p.dilations, + gsl::narrow_cast(conv_transpose_attrs_.group), mode, + CudnnTensor::GetDataType(), + UseTF32())); + + if (has_bias) { + const auto& b_shape = p.B->Shape(); + ORT_RETURN_IF_NOT(b_shape.NumDimensions() == 1, "bias should be 1D"); + TensorShapeVector b_dims(2 + p.kernel_shape.size()); + b_dims[0] = 1; // N + b_dims[NHWC ? 3 : 1] = b_shape[0]; // C + for (size_t i = 0; i < p.kernel_shape.size(); i++) { + b_dims[(NHWC ? 1 : 2) + i] = 1; + } + + ORT_RETURN_IF_ERROR(s_.b_tensor.Set(b_dims, CudnnTensor::GetDataType(), NHWC)); + } + + y_data = reinterpret_cast(p.Y->MutableData()); + + if (!s_.cached_benchmark_results.contains(x_dims)) { + IAllocatorUniquePtr algo_search_workspace = + GetScratchBuffer(AlgoSearchWorkspaceSize, context->GetComputeStream()); + + // set math type to tensor core before algorithm search + if constexpr (std::is_same::value) { + CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_TENSOR_OP_MATH)); + } else if constexpr (std::is_same::value) { + if (!UseTF32()) { + CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_FMA_MATH)); + } + } + + cudnnConvolutionBwdDataAlgoPerf_t perf; + int algo_count = 1; + CUDNN_RETURN_IF_ERROR(cudnnFindConvolutionBackwardDataAlgorithmEx( + GetCudnnHandle(context), s_.w_desc, w_data, s_.x_tensor, x_data, s_.conv_desc, s_.y_tensor, y_data, 1, + &algo_count, &perf, algo_search_workspace.get(), AlgoSearchWorkspaceSize)); + s_.cached_benchmark_results.insert(x_dims, {perf.algo, perf.memory, perf.mathType}); + } + + const auto& perf = s_.cached_benchmark_results.at(x_dims); + CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, perf.mathType)); + s_.algo = perf.algo; + s_.workspace_bytes = perf.memory; + } + + // The following block will be executed in case there has been no change in the shapes of the + // input and the filter compared to the previous run + if (!y_data) { + auto y_dims = s_.y_dims.AsShapeVector(); + if (x_dimensions == 3) { + if (cuda_ep->GetCudnnConv1dPadToNc1d()) { + // erase the fake H dimension + y_dims.erase(y_dims.begin() + (NHWC ? 1 : 2)); + } else { + // erase the fake W dimension + y_dims.erase(y_dims.begin() + (NHWC ? 2 : 3)); + } + } + + Tensor* Y = context->Output(0, TensorShape(y_dims)); + y_data = reinterpret_cast(Y->MutableData()); + + // Bail out early if one of the output dimensions is zero. 
+ if (Y->Shape().Size() == 0) { + return Status::OK(); + } + } + + const auto alpha = Consts::One; + const auto beta = Consts::Zero; + + IAllocatorUniquePtr workspace = GetScratchBuffer(s_.workspace_bytes, context->GetComputeStream()); + + CUDNN_RETURN_IF_ERROR(cudnnConvolutionBackwardData(GetCudnnHandle(context), &alpha, s_.w_desc, w_data, + s_.x_tensor, x_data, s_.conv_desc, s_.algo, workspace.get(), + s_.workspace_bytes, &beta, s_.y_tensor, y_data)); + + if (has_bias) { + const Tensor* B = dynamic_padding ? context->Input(3) : context->Input(2); + auto b_data = reinterpret_cast(B->Data()); + CUDNN_RETURN_IF_ERROR( + cudnnAddTensor(GetCudnnHandle(context), &alpha, s_.b_tensor, b_data, &alpha, s_.y_tensor, y_data)); + } + } + + return Status::OK(); +} + +} // namespace cuda +} // namespace onnxruntime + +#ifdef _WIN32 +#pragma warning(pop) +#endif
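
---

The description above references three CUDA EP provider options (prefer_nhwc, fuse_conv_bias, use_tf32). As a rough, hedged illustration of how a user would exercise the new NHWC ConvTranspose path, here is a minimal sketch using the ONNX Runtime C++ API. It is not part of the patch: the model path is a placeholder, the option keys are the ones referenced in this PR, and whether "prefer_nhwc" and "fuse_conv_bias" are accepted depends on the ORT/cuDNN build.

```cpp
// Minimal usage sketch (assumptions: a CUDA-enabled ORT build, "model.onnx" placeholder path).
#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "convtranspose_nhwc_demo");
  Ort::SessionOptions session_options;

  const OrtApi& api = Ort::GetApi();
  OrtCUDAProviderOptionsV2* cuda_options = nullptr;
  Ort::ThrowOnError(api.CreateCUDAProviderOptions(&cuda_options));

  // Provider options are passed as string key/value pairs:
  //   prefer_nhwc=1    -> use the channels-last (NHWC) kernels, including ConvTranspose.
  //   fuse_conv_bias=1 -> let the cuDNN frontend fuse the bias addition into the conv graph.
  //   use_tf32=0       -> opt out of TF32 for float32 convolutions.
  const char* keys[] = {"prefer_nhwc", "fuse_conv_bias", "use_tf32"};
  const char* values[] = {"1", "1", "0"};
  Ort::ThrowOnError(api.UpdateCUDAProviderOptions(cuda_options, keys, values, 3));

  session_options.AppendExecutionProvider_CUDA_V2(*cuda_options);
  api.ReleaseCUDAProviderOptions(cuda_options);

  Ort::Session session(env, ORT_TSTR("model.onnx"), session_options);
  // ... create input tensors and call session.Run(...) as usual ...
  return 0;
}
```

With these options, a model containing ConvTranspose should take the kMSInternalNHWCDomain path added by this PR (weights transposed once in PrePack, data-gradient convolution plus bias executed as a single fused cuDNN frontend graph on cuDNN 9); on a cuDNN 8 build the legacy backend-API kernels in conv_transpose_8.h are used instead.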