Skip to content

Commit

Permalink
Optimize bert training (#253)
Browse files Browse the repository at this point in the history
* reduction axis split optimization for column reduce kernel emitter

* enable bert fusion optimizations in bert training example

* fix bug in reduce kernel emitter

* format code

* enable bert fusion optimizations in bert training example
  • Loading branch information
xysmlx authored May 20, 2021
1 parent 7a81160 commit 12130ab
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 30 deletions.
159 changes: 129 additions & 30 deletions src/nnfusion/core/kernels/cuda_gpu/kernels/reduce.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,39 @@ namespace nnfusion
rank = input_shape.size();
out_rank = rank - reduce_rank;
reduce_op = CudaOpMap<T>::op;
reduction_split_factor = 32;

{
    // Calculate reduction_split_number: the number of chunks, each holding
    // reduction_split_factor elements, that the split reduction axis is cut
    // into. The value is used as the y dimension of the launch grid for the
    // column reduction kernel, and any value > 1 switches the emitted kernel
    // to the split path that accumulates partial sums with atomicAdd.
    reduction_split_number = 1;
    if (reduce_op == "add")
    {
        // currently, reduction_split only supports add reduce_op
        uint32_t reduction_loop_size = 1;
        Shape reduce_flag(rank, 0);
        for (auto a : reduce_axis)
        {
            reduce_flag[a] = 1;
        }
        // Every flagged axis overwrites the previous value, so this ends up
        // holding the extent of the highest-numbered reduced axis.
        for (int i = 0; i < rank; i++)
        {
            if (reduce_flag[i] != 0)
            {
                reduction_loop_size = input_shape[i];
            }
        }
        // Ceiling division: one extra chunk covers a partial tail.
        reduction_split_number = reduction_loop_size / reduction_split_factor;
        if ((reduction_loop_size % reduction_split_factor) != 0)
        {
            reduction_split_number += 1;
        }
        // A zero-extent reduction axis would leave 0 chunks here, which would
        // later be used as a zero-sized grid dimension. Always keep at least
        // one chunk so the non-split path still launches.
        if (reduction_split_number == 0)
        {
            reduction_split_number = 1;
        }
    }
}

// used to determine if it is RowReduction

Expand Down Expand Up @@ -230,6 +263,13 @@ if (thread_idx == 0) output0[block_idx] = val;
uint32_t block_size_x = 64;
uint32_t aligned_grid_size_x = align_to_block_size(nthreads, block_size_x);

// memset for grid reduction atomicAdd
if (reduction_split_number > 1)
{
m_context->outputs[0]->set_memset(true, 0);
m_context->outputs[0]->set_persistent(true);
}

auto expand_vector_uint32 = [](string name, vector<uint32_t>& d) {
stringstream ss;
for (int i = 0; i < d.size(); i++)
Expand Down Expand Up @@ -262,50 +302,107 @@ if (thread_idx == 0) output0[block_idx] = val;

lu << output_type << " r" << init_value << ";\n";

int64_t last_r_idx = static_cast<int64_t>(reduce_rank) - 1;
for (int64_t j = 0; j < last_r_idx; j++)
if (reduction_split_number > 1)
{
lu << "for(int idx" << j << " = 0; idx" << j << "< reduce_shape" << j
<< "; idx" << j << "++)\n";
lu.block_begin();
}
{
lu << "uint32_t reduce_idx = in_idx;\n";
int64_t last_r_idx = static_cast<int64_t>(reduce_rank) - 1;
for (int64_t j = 0; j < last_r_idx; j++)
{
lu << "reduce_idx += idx" << j << " * reduce_strides" << j << ";\n";
lu << "for(int idx" << j << " = 0; idx" << j << "< reduce_shape"
<< j << "; idx" << j << "++)\n";
lu.block_begin();
}
lu << "int idx" << last_r_idx << " = 0;\n";
lu << "uint32_t step = reduce_strides" << last_r_idx << ";\n";
// Unroll last reduction axis.
uint32_t unroll_num = 8;
uint32_t unroll_shift = 3;
lu << "for(; idx" << last_r_idx << " < (reduce_shape" << last_r_idx
<< " >> " << unroll_shift << "); idx" << last_r_idx << "++)\n";
lu.block_begin();
{
for (int k = 0; k < unroll_num; k++)
lu << "uint32_t reduce_idx = in_idx;\n";
for (int64_t j = 0; j < last_r_idx; j++)
{
lu << "reduce_idx += idx" << j << " * reduce_strides" << j
<< ";\n";
}
lu << "int idx" << last_r_idx << " = 0;\n";
lu << "uint32_t step = reduce_strides" << last_r_idx << ";\n";
lu << "idx" << last_r_idx << " += " << reduction_split_factor
<< " * blockIdx.y;\n";
lu << "int idx_end = min(idx" << last_r_idx << " + "
<< reduction_split_factor << ", (int)reduce_shape" << last_r_idx
<< ");\n";
/* // Unroll last reduction axis.
uint32_t unroll_num = 8;
uint32_t unroll_shift = 3;
lu << "for(; idx" << last_r_idx << " < (reduce_shape" << last_r_idx
<< " >> " << unroll_shift << "); idx" << last_r_idx << "++)\n";
lu.block_begin();
{
for (int k = 0; k < unroll_num; k++)
{
lu << "r = " << reduce_op << "(r , input0[reduce_idx]);\n";
lu << "reduce_idx += step;\n";
}
}
lu.block_end();
lu << "idx" << last_r_idx << " <<= " << unroll_shift << ";\n"; */
lu << "for(; idx" << last_r_idx << " < idx_end; idx" << last_r_idx
<< "++)\n";
lu.block_begin();
{
lu << "r = " << reduce_op << "(r , input0[reduce_idx]);\n";
lu << "reduce_idx += step;\n";
}
lu.block_end();
}
lu.block_end();
lu << "idx" << last_r_idx << " <<= " << unroll_shift << ";\n";
lu << "for(; idx" << last_r_idx << " < reduce_shape" << last_r_idx
<< "; idx" << last_r_idx << "++)\n";
lu.block_begin();
for (int64_t j = 0; j < last_r_idx; j++)
{
lu << "r = " << reduce_op << "(r , input0[reduce_idx]);\n";
lu << "reduce_idx += step;\n";
lu.block_end();
}
lu.block_end();
lu << "atomicAdd(output0 + tid, r);\n";
}
for (int64_t j = 0; j < last_r_idx; j++)
else
{
lu.block_end();
int64_t last_r_idx = static_cast<int64_t>(reduce_rank) - 1;
for (int64_t j = 0; j < last_r_idx; j++)
{
lu << "for(int idx" << j << " = 0; idx" << j << "< reduce_shape"
<< j << "; idx" << j << "++)\n";
lu.block_begin();
}
{
lu << "uint32_t reduce_idx = in_idx;\n";
for (int64_t j = 0; j < last_r_idx; j++)
{
lu << "reduce_idx += idx" << j << " * reduce_strides" << j
<< ";\n";
}
lu << "int idx" << last_r_idx << " = 0;\n";
lu << "uint32_t step = reduce_strides" << last_r_idx << ";\n";
// Unroll last reduction axis.
uint32_t unroll_num = 8;
uint32_t unroll_shift = 3;
lu << "for(; idx" << last_r_idx << " < (reduce_shape" << last_r_idx
<< " >> " << unroll_shift << "); idx" << last_r_idx << "++)\n";
lu.block_begin();
{
for (int k = 0; k < unroll_num; k++)
{
lu << "r = " << reduce_op << "(r , input0[reduce_idx]);\n";
lu << "reduce_idx += step;\n";
}
}
lu.block_end();
lu << "idx" << last_r_idx << " <<= " << unroll_shift << ";\n";
lu << "for(; idx" << last_r_idx << " < reduce_shape" << last_r_idx
<< "; idx" << last_r_idx << "++)\n";
lu.block_begin();
{
lu << "r = " << reduce_op << "(r , input0[reduce_idx]);\n";
lu << "reduce_idx += step;\n";
}
lu.block_end();
}
for (int64_t j = 0; j < last_r_idx; j++)
{
lu.block_end();
}
lu << "output0[tid] = r;\n";
}
lu << "output0[tid] = r;\n";
}
lu.block_end();

Expand Down Expand Up @@ -453,7 +550,7 @@ if (thread_idx == 0) output0[block_idx] = val;
uint32_t block_size_x = 64;
uint32_t aligned_grid_size_x = align_to_block_size(nthreads, block_size_x);

m_gridDim = dim3(aligned_grid_size_x, 1, 1);
m_gridDim = dim3(aligned_grid_size_x, reduction_split_number, 1);
m_blockDim = dim3(block_size_x, 1, 1);
}
else
Expand Down Expand Up @@ -494,6 +591,8 @@ if (thread_idx == 0) output0[block_idx] = val;
string reduce_op, input_type, output_type, init_value;
size_t height, width, expected_block_size;
bool is_row_reduction;
size_t reduction_split_factor,
reduction_split_number; // split reduction axis for column reduction
};

template <class T>
Expand Down
2 changes: 2 additions & 0 deletions src/python/example/bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,8 @@ def train_bert():
}) + '\'', # training optimizer configs
"blockfusion_level":
0, # TODO: fix blockfusion problem in bert training
"enable_all_bert_fusion":
True, # enable all bert fusion optimizations
}
trainer = Trainer(wrapper, device=device, codegen_flags=codegen_flags)

Expand Down

0 comments on commit 12130ab

Please sign in to comment.