diff --git a/source/tnn/device/cuda/acc/cuda_splitv_layer_acc.cu b/source/tnn/device/cuda/acc/cuda_splitv_layer_acc.cu index 7088bc11d..04c7ffd72 100644 --- a/source/tnn/device/cuda/acc/cuda_splitv_layer_acc.cu +++ b/source/tnn/device/cuda/acc/cuda_splitv_layer_acc.cu @@ -19,6 +19,8 @@ namespace TNN_NS { DECLARE_CUDA_ACC(SplitV, LAYER_SPLITV); +int SPLITV_GRID_DIM_Y_MAX = 65535; + template<int THREAD_PER_BLOCK, int ELE_PER_THREAD> __global__ void splitv_separate_kernel( const float * __restrict__ src, float * dst, @@ -44,6 +46,32 @@ __global__ void splitv_separate_kernel( } +// Cases when real_grid_dim_y > SPLITV_GRID_DIM_Y_MAX +template<int THREAD_PER_BLOCK, int ELE_PER_THREAD> +__global__ void splitv_separate_ylarge_kernel( + const float * __restrict__ src, float * dst, + const int inner_size, const int in_stride, + const int split_start, const int split_end, const int real_grid_dim_y, const int GRID_DIM_Y_MAX) +{ + for (int block_idx_y = blockIdx.y; block_idx_y < real_grid_dim_y; block_idx_y += GRID_DIM_Y_MAX) { + int block_offset = blockIdx.x * THREAD_PER_BLOCK * ELE_PER_THREAD; + + const int split_size = split_end - split_start; + const int size = split_size * inner_size; + const float* src_offsetted = src + (blockIdx.z * real_grid_dim_y + block_idx_y) * in_stride; + float* dst_offsetted = dst + (blockIdx.z * real_grid_dim_y + block_idx_y) * size; + + #pragma unroll + for (int i = 0; i < ELE_PER_THREAD ; i++) { + int index = block_offset + i * THREAD_PER_BLOCK + threadIdx.x; + if (index < size) { + int input_index = index + split_start * inner_size; + dst_offsetted[index] = __ldg(src_offsetted + input_index); + } + } + } +} + Status CudaSplitVLayerAcc::Init(Context *context, LayerParam *param, LayerResource *resource, const std::vector<Blob *> &inputs, const std::vector<Blob *> &outputs) { CudaLayerAcc::Init(context, param, resource, inputs, outputs); @@ -88,8 +116,15 @@ Status CudaSplitVLayerAcc::Forward(const std::vector<Blob *> &inputs, const std: griddim.z = DimsVectorUtils::Count(dims, 0, min(axis, 1)); float* output_data = static_cast<float*>(output_blob->GetHandle().base); - splitv_separate_kernel<THREAD_PER_BLOCK, ELE_PER_THREAD><<<griddim, THREAD_PER_BLOCK, 0, context_->GetStream()>>> + if (griddim.y <= SPLITV_GRID_DIM_Y_MAX) { + splitv_separate_kernel<THREAD_PER_BLOCK, ELE_PER_THREAD><<<griddim, THREAD_PER_BLOCK, 0, context_->GetStream()>>> (input_data, output_data, inner_size, in_stride, split_begin, split_end); + } else { + int real_grid_dim_y = griddim.y; + griddim.y = SPLITV_GRID_DIM_Y_MAX; + splitv_separate_ylarge_kernel<THREAD_PER_BLOCK, ELE_PER_THREAD><<<griddim, THREAD_PER_BLOCK, 0, context_->GetStream()>>> + (input_data, output_data, inner_size, in_stride, split_begin, split_end, real_grid_dim_y, SPLITV_GRID_DIM_Y_MAX); + } split_begin = split_end; } }