From 12130ab0c9b69546da8d1675f9d38d5dbb60d93c Mon Sep 17 00:00:00 2001
From: Lingxiao Ma
Date: Thu, 20 May 2021 11:16:01 +0800
Subject: [PATCH] Optimize bert training (#253)

* reduction axis split optimization for column reduce kernel emitter

* enable bert fusion optimizations in bert training example

* fix bug in reduce kernel emitter

* format code

* enable bert fusion optimizations in bert training example
---
 .../core/kernels/cuda_gpu/kernels/reduce.hpp | 159 ++++++++++++++----
 src/python/example/bert.py                   |   2 +
 2 files changed, 131 insertions(+), 30 deletions(-)

diff --git a/src/nnfusion/core/kernels/cuda_gpu/kernels/reduce.hpp b/src/nnfusion/core/kernels/cuda_gpu/kernels/reduce.hpp
index 97353e5e8..2ac920816 100644
--- a/src/nnfusion/core/kernels/cuda_gpu/kernels/reduce.hpp
+++ b/src/nnfusion/core/kernels/cuda_gpu/kernels/reduce.hpp
@@ -69,6 +69,39 @@ namespace nnfusion
                 rank = input_shape.size();
                 out_rank = rank - reduce_rank;
                 reduce_op = CudaOpMap<T>::op;
+                reduction_split_factor = 32;
+
+                {
+                    // calculate reduction_split_number
+                    reduction_split_number = 1;
+                    if (reduce_op == "add")
+                    {
+                        // currently, reduction_split only supports add reduce_op
+                        uint32_t reduction_loop_size = 1;
+                        Shape reduce_flag(rank, 0);
+                        for (auto a : reduce_axis)
+                        {
+                            reduce_flag[a] = 1;
+                        }
+                        for (int i = 0; i < rank; i++)
+                        {
+                            if (reduce_flag[i] != 0)
+                            {
+                                reduction_loop_size = input_shape[i];
+                            }
+                        }
+                        if ((reduction_loop_size % reduction_split_factor) == 0)
+                        {
+                            reduction_split_number =
+                                reduction_loop_size / reduction_split_factor;
+                        }
+                        else
+                        {
+                            reduction_split_number =
+                                reduction_loop_size / reduction_split_factor + 1;
+                        }
+                    }
+                }
 
                 // use to determine if it is RowReduction
@@ -230,6 +263,13 @@ if (thread_idx == 0) output0[block_idx] = val;
                 uint32_t block_size_x = 64;
                 uint32_t aligned_grid_size_x = align_to_block_size(nthreads, block_size_x);
 
+                // memset for grid reduction atomicAdd
+                if (reduction_split_number > 1)
+                {
+                    m_context->outputs[0]->set_memset(true, 0);
+                    m_context->outputs[0]->set_persistent(true);
+                }
+
                 auto expand_vector_uint32 = [](string name, vector<uint32_t>& d) {
                     stringstream ss;
                     for (int i = 0; i < d.size(); i++)
@@ -262,50 +302,107 @@ if (thread_idx == 0) output0[block_idx] = val;
                 lu << output_type << " r" << init_value << ";\n";
-                int64_t last_r_idx = static_cast<int64_t>(reduce_rank) - 1;
-                for (int64_t j = 0; j < last_r_idx; j++)
+                if (reduction_split_number > 1)
                 {
-                    lu << "for(int idx" << j << " = 0; idx" << j << "< reduce_shape" << j
-                       << "; idx" << j << "++)\n";
-                    lu.block_begin();
-                }
-                {
-                    lu << "uint32_t reduce_idx = in_idx;\n";
+                    int64_t last_r_idx = static_cast<int64_t>(reduce_rank) - 1;
                     for (int64_t j = 0; j < last_r_idx; j++)
                     {
-                        lu << "reduce_idx += idx" << j << " * reduce_strides" << j << ";\n";
+                        lu << "for(int idx" << j << " = 0; idx" << j << "< reduce_shape"
+                           << j << "; idx" << j << "++)\n";
+                        lu.block_begin();
                     }
-                    lu << "int idx" << last_r_idx << " = 0;\n";
-                    lu << "uint32_t step = reduce_strides" << last_r_idx << ";\n";
-                    // Unroll last reduction axis.
-                    uint32_t unroll_num = 8;
-                    uint32_t unroll_shift = 3;
-                    lu << "for(; idx" << last_r_idx << " < (reduce_shape" << last_r_idx
-                       << " >> " << unroll_shift << "); idx" << last_r_idx << "++)\n";
-                    lu.block_begin();
                     {
-                        for (int k = 0; k < unroll_num; k++)
+                        lu << "uint32_t reduce_idx = in_idx;\n";
+                        for (int64_t j = 0; j < last_r_idx; j++)
+                        {
+                            lu << "reduce_idx += idx" << j << " * reduce_strides" << j
+                               << ";\n";
+                        }
+                        lu << "int idx" << last_r_idx << " = 0;\n";
+                        lu << "uint32_t step = reduce_strides" << last_r_idx << ";\n";
+                        lu << "idx" << last_r_idx << " += " << reduction_split_factor
+                           << " * blockIdx.y;\n";
+                        lu << "int idx_end = min(idx" << last_r_idx << " + "
+                           << reduction_split_factor << ", (int)reduce_shape" << last_r_idx
+                           << ");\n";
+                        /* // Unroll last reduction axis.
+                        uint32_t unroll_num = 8;
+                        uint32_t unroll_shift = 3;
+                        lu << "for(; idx" << last_r_idx << " < (reduce_shape" << last_r_idx
+                           << " >> " << unroll_shift << "); idx" << last_r_idx << "++)\n";
+                        lu.block_begin();
+                        {
+                            for (int k = 0; k < unroll_num; k++)
+                            {
+                                lu << "r = " << reduce_op << "(r , input0[reduce_idx]);\n";
+                                lu << "reduce_idx += step;\n";
+                            }
+                        }
+                        lu.block_end();
+                        lu << "idx" << last_r_idx << " <<= " << unroll_shift << ";\n"; */
+                        lu << "for(; idx" << last_r_idx << " < idx_end; idx" << last_r_idx
+                           << "++)\n";
+                        lu.block_begin();
                         {
                             lu << "r = " << reduce_op << "(r , input0[reduce_idx]);\n";
                             lu << "reduce_idx += step;\n";
                         }
+                        lu.block_end();
                     }
-                    lu.block_end();
-                    lu << "idx" << last_r_idx << " <<= " << unroll_shift << ";\n";
-                    lu << "for(; idx" << last_r_idx << " < reduce_shape" << last_r_idx
-                       << "; idx" << last_r_idx << "++)\n";
-                    lu.block_begin();
+                    for (int64_t j = 0; j < last_r_idx; j++)
                     {
-                        lu << "r = " << reduce_op << "(r , input0[reduce_idx]);\n";
-                        lu << "reduce_idx += step;\n";
+                        lu.block_end();
                     }
-                    lu.block_end();
+                    lu << "atomicAdd(output0 + tid, r);\n";
                 }
-                for (int64_t j = 0; j < last_r_idx; j++)
+                else
                 {
-                    lu.block_end();
+                    int64_t last_r_idx = static_cast<int64_t>(reduce_rank) - 1;
+                    for (int64_t j = 0; j < last_r_idx; j++)
+                    {
+                        lu << "for(int idx" << j << " = 0; idx" << j << "< reduce_shape"
+                           << j << "; idx" << j << "++)\n";
+                        lu.block_begin();
+                    }
+                    {
+                        lu << "uint32_t reduce_idx = in_idx;\n";
+                        for (int64_t j = 0; j < last_r_idx; j++)
+                        {
+                            lu << "reduce_idx += idx" << j << " * reduce_strides" << j
+                               << ";\n";
+                        }
+                        lu << "int idx" << last_r_idx << " = 0;\n";
+                        lu << "uint32_t step = reduce_strides" << last_r_idx << ";\n";
+                        // Unroll last reduction axis.
+                        uint32_t unroll_num = 8;
+                        uint32_t unroll_shift = 3;
+                        lu << "for(; idx" << last_r_idx << " < (reduce_shape" << last_r_idx
+                           << " >> " << unroll_shift << "); idx" << last_r_idx << "++)\n";
+                        lu.block_begin();
+                        {
+                            for (int k = 0; k < unroll_num; k++)
+                            {
+                                lu << "r = " << reduce_op << "(r , input0[reduce_idx]);\n";
+                                lu << "reduce_idx += step;\n";
+                            }
+                        }
+                        lu.block_end();
+                        lu << "idx" << last_r_idx << " <<= " << unroll_shift << ";\n";
+                        lu << "for(; idx" << last_r_idx << " < reduce_shape" << last_r_idx
+                           << "; idx" << last_r_idx << "++)\n";
+                        lu.block_begin();
+                        {
+                            lu << "r = " << reduce_op << "(r , input0[reduce_idx]);\n";
+                            lu << "reduce_idx += step;\n";
+                        }
+                        lu.block_end();
+                    }
+                    for (int64_t j = 0; j < last_r_idx; j++)
+                    {
+                        lu.block_end();
+                    }
+                    lu << "output0[tid] = r;\n";
                 }
-                lu << "output0[tid] = r;\n";
             }
             lu.block_end();
@@ -453,7 +550,7 @@ if (thread_idx == 0) output0[block_idx] = val;
                 uint32_t block_size_x = 64;
                 uint32_t aligned_grid_size_x = align_to_block_size(nthreads, block_size_x);
-                m_gridDim = dim3(aligned_grid_size_x, 1, 1);
+                m_gridDim = dim3(aligned_grid_size_x, reduction_split_number, 1);
                 m_blockDim = dim3(block_size_x, 1, 1);
             }
             else
@@ -494,6 +591,8 @@ if (thread_idx == 0) output0[block_idx] = val;
            string reduce_op, input_type, output_type, init_value;
            size_t height, width, expected_block_size;
            bool is_row_reduction;
+           size_t reduction_split_factor,
+                  reduction_split_number; // split reduction axis for column reduction
        };

        template <class T>
diff --git a/src/python/example/bert.py b/src/python/example/bert.py
index 4dbe42b46..f8ba67ea0 100644
--- a/src/python/example/bert.py
+++ b/src/python/example/bert.py
@@ -183,6 +183,8 @@ def train_bert():
         }) + '\'',  # training optimizer configs
         "blockfusion_level":
         0,  # TODO: fix blockfusion problem in bert training
+        "enable_all_bert_fusion":
+        True,  # enable all bert fusion optimizations
     }
 
     trainer = Trainer(wrapper, device=device, codegen_flags=codegen_flags)
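
Note on the generated code: below is a minimal, hand-written CUDA sketch of the
split column-reduction scheme this patch emits when reduction_split_number > 1.
It is an illustration, not the emitter's verbatim output; the kernel name and
the width / reduce_shape0 parameters are assumptions for a simple 2D case,
while reduction_split_factor is fixed at 32 as in the patch.

#include <cstdint>

// Sketch: column reduce of a row-major [reduce_shape0 x width] input over
// axis 0, split along the reduction axis in slices of 32 elements.
// Launch (mirroring the m_gridDim change above):
//   grid  = dim3(ceil(width / 64.0), reduction_split_number, 1)
//   block = dim3(64, 1, 1)
// where reduction_split_number = (reduce_shape0 + 31) / 32, the ceil division
// computed in the constructor above.
__global__ void column_reduce_split(const float* input0,
                                    float* output0,
                                    int width,
                                    int reduce_shape0)
{
    int tid = blockIdx.x * blockDim.x + threadIdx.x; // one output column per thread
    if (tid >= width)
        return;
    float r = 0.0f;                              // init_value for the "add" op
    int idx0 = 32 * blockIdx.y;                  // start of this block's slice
    int idx_end = min(idx0 + 32, reduce_shape0); // clamp the last, partial slice
    uint32_t reduce_idx = tid + idx0 * width;    // walk down the column
    for (; idx0 < idx_end; idx0++)
    {
        r = r + input0[reduce_idx];
        reduce_idx += width;                     // step = reduce_strides
    }
    // output0 is zero-initialized up front (set_memset(true, 0) /
    // set_persistent(true) above), so the gridDim.y partial sums can be
    // combined with atomicAdd in any order.
    atomicAdd(output0 + tid, r);
}

For a hypothetical reduce_shape0 = 4096, each thread's serial reduction loop
shrinks from 4096 iterations to 32, at the cost of 128 atomicAdd calls per
output element; this reordering of partial sums is only valid for a commutative
combiner, which is why the split is restricted to reduce_op == "add".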