diff --git a/source/tnn/device/cuda/acc/cuda_splitv_layer_acc.cu b/source/tnn/device/cuda/acc/cuda_splitv_layer_acc.cu
index 7088bc11d..04c7ffd72 100644
--- a/source/tnn/device/cuda/acc/cuda_splitv_layer_acc.cu
+++ b/source/tnn/device/cuda/acc/cuda_splitv_layer_acc.cu
@@ -19,6 +19,8 @@ namespace TNN_NS {
 
 DECLARE_CUDA_ACC(SplitV, LAYER_SPLITV);
 
+constexpr int SPLITV_GRID_DIM_Y_MAX = 65535;  // CUDA hardware limit on gridDim.y
+
 template<int THREAD_PER_BLOCK, int ELE_PER_THREAD>
 __global__ void splitv_separate_kernel(
     const float * __restrict__ src, float * dst,
@@ -44,6 +46,32 @@ __global__ void splitv_separate_kernel(
 
 }
 
+// Fallback for real_grid_dim_y > SPLITV_GRID_DIM_Y_MAX: each block grid-strides over the y dimension.
+template<int THREAD_PER_BLOCK, int ELE_PER_THREAD>
+__global__ void splitv_separate_ylarge_kernel(
+    const float * __restrict__ src, float * dst,
+    const int inner_size, const int in_stride,
+    const int split_start, const int split_end, const int real_grid_dim_y, const int GRID_DIM_Y_MAX)
+{
+    // Loop-invariant values, hoisted out of the grid-stride loop.
+    const int block_offset = blockIdx.x * THREAD_PER_BLOCK * ELE_PER_THREAD;
+    const int split_size   = split_end - split_start;
+    const int size         = split_size * inner_size;
+    for (int block_idx_y = blockIdx.y; block_idx_y < real_grid_dim_y; block_idx_y += GRID_DIM_Y_MAX) {
+        const float* src_offsetted = src + (blockIdx.z * real_grid_dim_y + block_idx_y) * in_stride;
+        float* dst_offsetted       = dst + (blockIdx.z * real_grid_dim_y + block_idx_y) * size;
+
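+        // Each thread copies up to ELE_PER_THREAD elements of this split slice from src to dst.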
+        #pragma unroll
+        for (int i = 0; i < ELE_PER_THREAD; i++) {
+            int index = block_offset + i * THREAD_PER_BLOCK + threadIdx.x;
+            if (index < size) {
+                int input_index = index + split_start * inner_size;
+                dst_offsetted[index] = __ldg(src_offsetted + input_index);
+            }
+        }
+    }
+}
+
 Status CudaSplitVLayerAcc::Init(Context *context, LayerParam *param, LayerResource *resource,
         const std::vector<Blob *> &inputs, const std::vector<Blob *> &outputs) {
     CudaLayerAcc::Init(context, param, resource, inputs, outputs);
@@ -88,8 +116,15 @@ Status CudaSplitVLayerAcc::Forward(const std::vector<Blob *> &inputs, const std:
         griddim.z = DimsVectorUtils::Count(dims, 0, min(axis, 1));
 
         float* output_data = static_cast<float*>(output_blob->GetHandle().base);
-        splitv_separate_kernel<THREAD_PER_BLOCK, ELE_PER_THREAD><<<griddim, THREAD_PER_BLOCK, 0, context_->GetStream()>>>
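+        // griddim.y can exceed the CUDA limit of 65535; use the grid-striding kernel in that case.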
+        if (griddim.y <= SPLITV_GRID_DIM_Y_MAX) {
+          splitv_separate_kernel<THREAD_PER_BLOCK, ELE_PER_THREAD><<<griddim, THREAD_PER_BLOCK, 0, context_->GetStream()>>>
             (input_data, output_data, inner_size, in_stride, split_begin, split_end);
+        } else {
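+          // Cap griddim.y at the hardware limit; the launched kernel grid-strides over the remaining slices.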
+          int real_grid_dim_y = griddim.y;
+          griddim.y = SPLITV_GRID_DIM_Y_MAX;
+          splitv_separate_ylarge_kernel<THREAD_PER_BLOCK, ELE_PER_THREAD><<<griddim, THREAD_PER_BLOCK, 0, context_->GetStream()>>>
+            (input_data, output_data, inner_size, in_stride, split_begin, split_end, real_grid_dim_y, SPLITV_GRID_DIM_Y_MAX);
+        }
         split_begin = split_end;
       }
     }