Fix translate pad op && batchnormal layer fp16 codegen #435

Merged
5 commits merged on Jul 19, 2022
Changes from 4 commits
23 changes: 16 additions & 7 deletions src/nnfusion/core/kernels/cuda_gpu/kernels/batch_norm.cpp
@@ -168,13 +168,22 @@ LanguageUnit_p cuda::BatchNormNCHW::emit_function_body()
{
lu << alpha << " * ";
}
lu << "(input1[c_id] + (input0[c_id] * "
"(input2[st + i] - input3[c_id]) / sqrtf("
<< epsilon << " + input4[c_id])));\n";
// lu << "output0[st + i] = " << beta << " * output0[st + i] + " << alpha
// << " * (input1[c_id] + (input0[c_id] * "
// "(input2[st + i] - input3[c_id]) / sqrtf("
// << epsilon << " + input4[c_id])));\n";
/*
 * TODO: there may be a better solution; details in https://github.com/microsoft/nnfusion/issues/434
 */
if (dtype == nnfusion::element::f16) {
lu << "output0[st + i] = __hadd(input1[c_id] , __hdiv(__hmul(input0[c_id], __hsub(input2[st + i], input3[c_id])), sqrtf(__hadd(__float2half("
<< epsilon << "), input4[c_id]))));\n";
} else {
lu << "(input1[c_id] + (input0[c_id] * "
"(input2[st + i] - input3[c_id]) / sqrtf("
<< epsilon << " + input4[c_id])));\n";
// lu << "output0[st + i] = " << beta << " * output0[st + i] + " << alpha
// << " * (input1[c_id] + (input0[c_id] * "
// "(input2[st + i] - input3[c_id]) / sqrtf("
// << epsilon << " + input4[c_id])));\n";
}

lu.block_end();

return _lu;
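For reference, a hedged reconstruction of the single statement the fp16 branch above emits into the generated CUDA kernel, assuming epsilon prints as 1e-05 (the __h* intrinsics come from cuda_fp16.h, and any alpha/beta prefixes emitted by the surrounding code are left out):

// out = offset + scale * (x - mean) / sqrt(eps + var), in half precision.
// The sqrtf round-trips through float exactly as the emitter writes it;
// the TODO above points to issue #434 for a possibly better formulation.
output0[st + i] = __hadd(input1[c_id],
                         __hdiv(__hmul(input0[c_id], __hsub(input2[st + i], input3[c_id])),
                                sqrtf(__hadd(__float2half(1e-05), input4[c_id]))));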
146 changes: 73 additions & 73 deletions src/nnfusion/core/kernels/cuda_gpu/kernels/dot.cpp
@@ -207,83 +207,83 @@ LanguageUnit_p cuda::Dot::emit_function_body()
else if (dtype == element::f16)
{
// case 1: Scalar * Tensor
// if (arg0_shape.empty() || arg1_shape.empty())
// {
// auto& second = (arg0_shape.empty() ? arg1_shape : arg0_shape);
// size_t count = nnfusion::shape_size(second);
if (arg0_shape.empty() || arg1_shape.empty())
{
auto& second = (arg0_shape.empty() ? arg1_shape : arg0_shape);
size_t count = nnfusion::shape_size(second);

// string firstarg = (arg0_shape.empty() ? "input1" : "input0");
// string secondarg = (arg0_shape.empty() ? "input0" : "input1");
string firstarg = (arg0_shape.empty() ? "input1" : "input0");
string secondarg = (arg0_shape.empty() ? "input0" : "input1");

// lu << "cublasSetPointerMode(cublas_handle, CUBLAS_POINTER_MODE_DEVICE);\n";
lu << "cublasSetPointerMode(cublas_handle, CUBLAS_POINTER_MODE_DEVICE);\n";

// lu << "CUDA_SAFE_CALL(cudaMemcpy(outupt0, " << firstarg << ", " << count << ", cudaMemcpyDeviceToDevice));\n"; // copy `firstarg` to `output0`
// lu << "CUBLAS_SAFE_CALL(nnfusionHalfScale(" << secondarg << ", output0, " << count << "));\n";
// }
lu << "CUDA_SAFE_CALL(cudaMemcpy(outupt0, " << firstarg << ", " << count << ", cudaMemcpyDeviceToDevice));\n"; // copy `firstarg` to `output0`
lu << "CUBLAS_SAFE_CALL(nnfusionHalfScale(" << secondarg << ", output0, " << count << "));\n";
}
// // case 2: 1d Dot
// else if ((arg0_shape.size() == arg1_shape.size()) && (arg0_shape.size() == reduction_axes))
// {
// for (int i = 0; i < arg0_shape.size(); i++)
// {
// if (arg0_shape[i] != arg1_shape[i])
// {
// std::vector<std::string> arg_vec{"arg0", "arg1"};
// std::vector<nnfusion::Shape> shape_vec{arg0_shape, arg1_shape};

// NNFUSION_CHECK_FAIL() << nnfusion::join(arg_vec) << " with "
// << nnfusion::join(shape_vec) << " respectively, at Node "
// << m_context->gnode->get_name()
// << ", do not match for dot op";
// }
// }

// size_t count = nnfusion::shape_size(arg0_shape);
// lu << "cublasSetPointerMode(cublas_handle, CUBLAS_POINTER_MODE_DEVICE);\n";

// lu << "CUBLAS_SAFE_CALL(cublasSdot(cublas_handle, " << count
// << ", static_cast<const float*>(input0), 1, static_cast<const float*>(input1), 1, "
// "static_cast<float*>(output0)));\n";
// }
else if ((arg0_shape.size() == arg1_shape.size()) && (arg0_shape.size() == reduction_axes))
{
for (int i = 0; i < arg0_shape.size(); i++)
{
if (arg0_shape[i] != arg1_shape[i])
{
std::vector<std::string> arg_vec{"arg0", "arg1"};
std::vector<nnfusion::Shape> shape_vec{arg0_shape, arg1_shape};

NNFUSION_CHECK_FAIL() << nnfusion::join(arg_vec) << " with "
<< nnfusion::join(shape_vec) << " respectively, at Node "
<< m_context->gnode->get_name()
<< ", do not match for dot op";
}
}

size_t count = nnfusion::shape_size(arg0_shape);
lu << "cublasSetPointerMode(cublas_handle, CUBLAS_POINTER_MODE_DEVICE);\n";

lu << "CUBLAS_SAFE_CALL(cublasSdot(cublas_handle, " << count
<< ", static_cast<const float*>(input0), 1, static_cast<const float*>(input1), 1, "
"static_cast<float*>(output0)));\n";
}
// // matrix * vector
// else if ((arg0_shape.size() == 2) && (arg1_shape.size() == 1) && (reduction_axes == 1))
// {
// lu << "const float alpha = 1.0;\n const float beta = 0;\n";
// lu << "CUBLAS_SAFE_CALL(cublasSgemv(cublas_handle, ";
// if (trans_A)
// lu << "CUBLAS_OP_N, " << arg0_shape[0] << ", " << arg0_shape[1] << ", ";
// else
// lu << "CUBLAS_OP_T, " << arg0_shape[1] << ", " << arg0_shape[0] << ", ";
// lu << " &alpha,"
// << " static_cast<const float*>(input0)," << arg0_shape[1] << ", "
// << " static_cast<const float*>(input1),"
// << " 1,"
// << " &beta,"
// << " static_cast<float*>(output0),"
// << " 1));\n";
// }
// else if ((arg0_shape.size() == 2) && (arg1_shape.size() == 2) && (reduction_axes == 1) &&
// (trans_A || trans_B))
// {
// int m = trans_B ? arg1_shape[0] : arg1_shape[1];
// int n = trans_A ? arg0_shape[1] : arg0_shape[0];
// int k = trans_A ? arg0_shape[0] : arg0_shape[1];

// lu << "const half alpha = 1.0;\nconst half beta = 0;\n";

// lu << "CUBLAS_SAFE_CALL(cublasHgemm(cublas_handle,"
// << (trans_B ? " CUBLAS_OP_T," : " CUBLAS_OP_N,")
// << (trans_A ? " CUBLAS_OP_T," : " CUBLAS_OP_N,") << " " << m << ","
// << " " << n << ","
// << " " << k << ","
// << " &alpha,"
// << " static_cast<const half*>(input1),"
// << " " << arg1_shape[1] << ","
// << " static_cast<const half*>(input0),"
// << " " << arg0_shape[1] << ","
// << " &beta,"
// << " static_cast<half*>(output0),"
// << " " << m << "));\n";
// } else {
else if ((arg0_shape.size() == 2) && (arg1_shape.size() == 1) && (reduction_axes == 1))
{
lu << "const float alpha = 1.0;\n const float beta = 0;\n";
lu << "CUBLAS_SAFE_CALL(cublasSgemv(cublas_handle, ";
if (trans_A)
lu << "CUBLAS_OP_N, " << arg0_shape[0] << ", " << arg0_shape[1] << ", ";
else
lu << "CUBLAS_OP_T, " << arg0_shape[1] << ", " << arg0_shape[0] << ", ";
lu << " &alpha,"
<< " static_cast<const float*>(input0)," << arg0_shape[1] << ", "
<< " static_cast<const float*>(input1),"
<< " 1,"
<< " &beta,"
<< " static_cast<float*>(output0),"
<< " 1));\n";
}
else if ((arg0_shape.size() == 2) && (arg1_shape.size() == 2) && (reduction_axes == 1) &&
(trans_A || trans_B))
{
int m = trans_B ? arg1_shape[0] : arg1_shape[1];
int n = trans_A ? arg0_shape[1] : arg0_shape[0];
int k = trans_A ? arg0_shape[0] : arg0_shape[1];

lu << "const half alpha = 1.0;\nconst half beta = 0;\n";

lu << "CUBLAS_SAFE_CALL(cublasHgemm(cublas_handle,"
<< (trans_B ? " CUBLAS_OP_T," : " CUBLAS_OP_N,")
<< (trans_A ? " CUBLAS_OP_T," : " CUBLAS_OP_N,") << " " << m << ","
<< " " << n << ","
<< " " << k << ","
<< " &alpha,"
<< " static_cast<const half*>(input1),"
<< " " << arg1_shape[1] << ","
<< " static_cast<const half*>(input0),"
<< " " << arg0_shape[1] << ","
<< " &beta,"
<< " static_cast<half*>(output0),"
<< " " << m << "));\n";
} else {
size_t axes_for_m_count = arg0_shape.size() - reduction_axes;
size_t axes_for_n_count = arg1_shape.size() - reduction_axes;
size_t axes_for_k_count = reduction_axes;
@@ -361,7 +361,7 @@ LanguageUnit_p cuda::Dot::emit_function_body()
<< " &beta,"
<< " static_cast<half*>(output0),"
<< " " << n << "));\n";
// }
}
}
else
{
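The fp16 GEMM path restored above leans on the usual trick for driving column-major cuBLAS from row-major buffers: pass input1 (B) before input0 (A) with swapped dimensions, so the library computes C^T = B^T * A^T, which laid out in memory is exactly row-major C = A * B. A minimal host-side sketch of that call pattern for the no-transpose case (hgemm_row_major is an illustrative helper, not an nnfusion API):

#include <cublas_v2.h>
#include <cuda_fp16.h>

// Row-major C(MxN) = A(MxK) * B(KxN) in fp16. cuBLAS is column-major, so a
// row-major array looks transposed to it; passing B first yields C^T = B^T * A^T
// without any explicit transpose, matching the emitted code above.
void hgemm_row_major(cublasHandle_t handle, const __half* A, const __half* B,
                     __half* C, int M, int N, int K)
{
    const __half alpha = __float2half(1.0f);
    const __half beta = __float2half(0.0f);
    cublasHgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                N, M, K,   // cuBLAS-side m, n, k
                &alpha,
                B, N,      // B^T is N x K, leading dimension N
                A, K,      // A^T is K x M, leading dimension K
                &beta,
                C, N);     // C^T is N x M, leading dimension N
}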
109 changes: 81 additions & 28 deletions src/nnfusion/frontend/onnx_import/op/pad.hpp
@@ -19,42 +19,95 @@ namespace nnfusion
const NodeMap& all_ng_nodes,
std::shared_ptr<nnfusion::graph::Graph> m_graph)
{
auto input_gnode = GetInputNode(all_ng_nodes, node_proto, 0);
auto padding_gnode = GetInputNode(all_ng_nodes, node_proto, 1);
/*
 * Since opset 11, 'pads' and 'value' have been moved from attributes to inputs.
 */
if (node_proto.attribute_size() == 1)
{

std::vector<int64> paddings;
bool status = GetValueFromNGraphOp<int64>(padding_gnode, &paddings);
NNFUSION_CHECK(status);
auto input_gnode = GetInputNode(all_ng_nodes, node_proto, 0);
auto padding_gnode = GetInputNode(all_ng_nodes, node_proto, 1);
std::vector<int64> paddings;
bool status = GetValueFromNGraphOp<int64>(padding_gnode, &paddings);
NNFUSION_CHECK(status);

NNFUSION_CHECK(paddings.size() % 2 == 0)
<< "Constant node for paddings does not have an even number of elements";
NNFUSION_CHECK(paddings.size() % 2 == 0)
<< "Constant node for paddings does not have an even number of elements";

nnfusion::Shape padding_below(paddings.size() / 2);
nnfusion::Shape padding_above(paddings.size() / 2);
nnfusion::Shape padding_interior(paddings.size() / 2);
nnfusion::Shape padding_below(paddings.size() / 2);
nnfusion::Shape padding_above(paddings.size() / 2);
nnfusion::Shape padding_interior(paddings.size() / 2);

for (size_t i = 0; i < paddings.size() / 2; i++)
{
padding_below[i] = paddings[i];
padding_above[i] = paddings[i + paddings.size() / 2];
padding_interior[i] = 0;
}
for (size_t i = 0; i < paddings.size() / 2; i++)
{
padding_below[i] = paddings[i];
padding_above[i] = paddings[i + paddings.size() / 2];
padding_interior[i] = 0;
}

auto pad_val_op =
std::make_shared<op::Constant>(input_gnode->get_element_type(),
nnfusion::Shape{},
std::vector<std::string>{"0"});
auto pad_val_gnode = m_graph->add_node_and_edge(pad_val_op, GNodeVector({}));

auto pad_op =
std::make_shared<op::Pad>(padding_below, padding_above, padding_interior);
pad_op->set_name(node_proto.output(0));

auto pad_val_op =
std::make_shared<op::Constant>(input_gnode->get_element_type(),
nnfusion::Shape{},
std::vector<std::string>{"0"});
auto pad_val_gnode = m_graph->add_node_and_edge(pad_val_op, GNodeVector({}));
auto pad_gnode =
m_graph->add_node_and_edge(pad_op, {input_gnode, pad_val_gnode});

auto pad_op =
std::make_shared<op::Pad>(padding_below, padding_above, padding_interior);
pad_op->set_name(node_proto.output(0));
NamedNodeVector ret{{node_proto.output(0), pad_gnode}};
return ret;
} else {
cout << "meet pad op" << endl;
/* For the pad op, attribute 0 is 'mode', 1 is 'pads', 2 is 'value'.
 * attr.name() returns the attribute's name;
 * for 'mode', attr.s() holds the mode string,
 * for 'pads', attr.ints(i) holds the i-th padding value,
 * for 'value', attr.f() holds the fill value.
 */
auto input_gnode = GetInputNode(all_ng_nodes, node_proto, 0);
const onnx::AttributeProto& modeAttr = node_proto.attribute(0);
cout << modeAttr.name() << endl;
if (modeAttr.s() != "constant") NNFUSION_CHECK_FAIL() << "unsupported padding type: " << modeAttr.s();
const onnx::AttributeProto& padAttr = node_proto.attribute(1);
cout << padAttr.name() << endl;
for (int i = 0; i < 8; ++i)
{
cout << padAttr.ints(i) << ' ';
}
cout << endl;
const onnx::AttributeProto& valueAttr = node_proto.attribute(2);
cout << valueAttr.name() << endl;
cout << valueAttr.f() << endl;
auto pad_val_op =
std::make_shared<op::Constant>(input_gnode->get_element_type(),
nnfusion::Shape{},
std::vector<std::string>{"0"});
auto pad_val_gnode = m_graph->add_node_and_edge(pad_val_op, GNodeVector({}));
nnfusion::Shape padding_below(4);
nnfusion::Shape padding_above(4);
nnfusion::Shape padding_interior(4);
/*
 * Fix: the previous correspondence of 'padding' to 'pads' was wrong; entry i
 * is the leading pad of axis i and entry i + 4 the trailing pad.
 */
for (int i = 0; i < 4; ++i)
{
padding_below[i] = padAttr.ints(i);
padding_above[i] = padAttr.ints(i+4);
}

auto pad_gnode =
m_graph->add_node_and_edge(pad_op, {input_gnode, pad_val_gnode});
auto pad_op =
std::make_shared<op::Pad>(padding_below, padding_above, padding_interior);
pad_op->set_name(node_proto.output(0));

NamedNodeVector ret{{node_proto.output(0), pad_gnode}};
return ret;
auto pad_gnode =
m_graph->add_node_and_edge(pad_op, {input_gnode, pad_val_gnode});

NamedNodeVector ret{{node_proto.output(0), pad_gnode}};
return ret;
}
}
} // namespace set_1
} //namespace onnx_import
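The rank-4 mapping fixed above follows the ONNX layout of 'pads', which for an n-dimensional input is [x1_begin, ..., xn_begin, x1_end, ..., xn_end]: entry i pads the start of axis i and entry i + n its end. A hedged sketch of the same split for arbitrary rank (split_onnx_pads is an illustrative helper, not part of nnfusion):

#include <cstddef>
#include <cstdint>
#include <vector>

struct PadSpec
{
    std::vector<int64_t> below; // leading pad per axis
    std::vector<int64_t> above; // trailing pad per axis
};

// Split an ONNX 'pads' list of 2 * rank entries into below/above vectors.
PadSpec split_onnx_pads(const std::vector<int64_t>& pads)
{
    const std::size_t rank = pads.size() / 2;
    PadSpec spec;
    spec.below.assign(pads.begin(), pads.begin() + rank);
    spec.above.assign(pads.begin() + rank, pads.end());
    return spec;
}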
2 changes: 1 addition & 1 deletion src/nnfusion/frontend/onnx_import/util/util.cpp
@@ -95,7 +95,7 @@ namespace nnfusion
case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT:
return make_constant_op<float>(element::f32, shape, tensor);
case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT16:
return make_constant_op<float>(element::f16, shape, tensor);
return make_constant_op<half_float::half>(element::f16, shape, tensor);
case onnx::TensorProto_DataType::TensorProto_DataType_DOUBLE:
return make_constant_op<double>(element::f64, shape, tensor);
case onnx::TensorProto_DataType::TensorProto_DataType_INT8:
10 changes: 10 additions & 0 deletions src/nnfusion/frontend/onnx_import/util/util.hpp
@@ -140,6 +140,16 @@ namespace nnfusion
return std::vector<float>();
}

template <>
inline std::vector<half_float::half> get_data(const onnx::TensorProto& tensor)
{
if (tensor.has_raw_data())
{
return __get_raw_data<half_float::half>(tensor.raw_data());
}
return std::vector<half_float::half>();
}

template <>
inline std::vector<int32_t> get_data(const onnx::TensorProto& tensor)
{
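Both fp16 fixes in util.cpp and util.hpp come down to decoding raw_data with a 2-byte element type rather than a 4-byte float, which would halve the element count and scramble every value. A hedged sketch of what the decode amounts to (decode_f16_raw is illustrative; nnfusion's __get_raw_data presumably performs the equivalent reinterpretation):

#include <cstring>
#include <string>
#include <vector>
#include "half.hpp" // half_float::half, the 2-byte IEEE half type nnfusion uses

// Reinterpret the little-endian 2-byte halves packed back to back in raw_data.
std::vector<half_float::half> decode_f16_raw(const std::string& raw)
{
    std::vector<half_float::half> out(raw.size() / sizeof(half_float::half));
    std::memcpy(out.data(), raw.data(), out.size() * sizeof(half_float::half));
    return out;
}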