Skip to content

Commit

Permalink
Update trt_bert_creator.h
Browse files Browse the repository at this point in the history
  • Loading branch information
yuanzexi committed Aug 17, 2021
1 parent d22837e commit f247706
Showing 1 changed file with 2 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ class TLayerCreator<TrtBertDesc> : public ILayerCreator {
if (bert_desc->hidden_size != 768 && bert_desc->hidden_size != 1024) use_int8 = false;

const auto dtype = TrtCommon::GetDataType(bert_desc->use_fp16, use_int8, bert_desc->calib_mode);
const int has_skip = 1;

nvinfer1::IPluginCreator* creator = getPluginRegistry()->getPluginCreator(
fwd::bert::FWD_SKIP_LAYER_NORM_NAME, fwd::bert::FWD_SKIP_LAYER_NORM_VERSION);
Expand All @@ -223,6 +224,7 @@ class TLayerCreator<TrtBertDesc> : public ILayerCreator {
gamma.Count());
field_data.emplace_back("ld", &bert_desc->hidden_size, nvinfer1::PluginFieldType::kINT32, 1);
field_data.emplace_back("type_id", &dtype, nvinfer1::PluginFieldType::kINT32, 1);
field_data.emplace_back("has_skip", &has_skip, nvinfer1::PluginFieldType::kINT32, 1);

const nvinfer1::PluginFieldCollection plugin_data{static_cast<int>(field_data.size()),
field_data.data()};
Expand Down

0 comments on commit f247706

Please sign in to comment.