From c86297a1b2c0a933723b847b7b1b7bc17c9a8821 Mon Sep 17 00:00:00 2001 From: Jilong Xue Date: Tue, 19 Jul 2022 13:01:56 +0000 Subject: [PATCH] fix Where and ScatterND bugs; code style --- src/nnfusion/common/type/element_type.cpp | 2 +- .../core/kernels/cuda_gpu/cuda_cudnn.cpp | 7 +- .../kernels/cuda_gpu/kernels/batch_matmul.cpp | 9 +- .../kernels/cuda_gpu/kernels/batch_norm.cpp | 10 +- .../core/kernels/cuda_gpu/kernels/dot.cpp | 280 +++++++++--------- .../core/kernels/cuda_gpu/kernels/reduce.hpp | 6 +- .../generic_op_define/ScatterND.cpp | 2 +- .../generic_op/generic_op_define/Slice.cpp | 10 +- .../engine/pass/extract_graph_signature.cpp | 6 +- src/nnfusion/frontend/onnx_import/op/pad.hpp | 32 +- .../frontend/onnx_import/op/where.cpp | 5 + .../frontend/onnx_import/util/util.hpp | 15 +- 12 files changed, 207 insertions(+), 177 deletions(-) diff --git a/src/nnfusion/common/type/element_type.cpp b/src/nnfusion/common/type/element_type.cpp index 590fdd2b6..74420d28c 100644 --- a/src/nnfusion/common/type/element_type.cpp +++ b/src/nnfusion/common/type/element_type.cpp @@ -59,7 +59,7 @@ bool element::Type::nnfusion_element_type_to_dtype_string(const element::Type& n std::string& dtype) { if (ng_et == element::boolean) - dtype = "char"; + dtype = "int16"; else if (ng_et == element::character) dtype = "char"; else if (ng_et == element::f16) diff --git a/src/nnfusion/core/kernels/cuda_gpu/cuda_cudnn.cpp b/src/nnfusion/core/kernels/cuda_gpu/cuda_cudnn.cpp index 040349588..5fca1dd00 100644 --- a/src/nnfusion/core/kernels/cuda_gpu/cuda_cudnn.cpp +++ b/src/nnfusion/core/kernels/cuda_gpu/cuda_cudnn.cpp @@ -201,11 +201,12 @@ LanguageUnit_p cuda::get_cudnn_convolution_descriptor(const Shape& padding, << "window_dilation_strides_int, CUDNN_CROSS_CORRELATION, " << data_type << "));\n"; } - if(type == nnfusion::element::f16){ + if (type == nnfusion::element::f16) + { // half precision, use tensor core lu << "CUDNN_SAFE_CALL(cudnnSetConvolutionMathType(" << desc << ", " - << "CUDNN_TENSOR_OP_MATH" - << "));\n"; + << "CUDNN_TENSOR_OP_MATH" + << "));\n"; } return _lu; diff --git a/src/nnfusion/core/kernels/cuda_gpu/kernels/batch_matmul.cpp b/src/nnfusion/core/kernels/cuda_gpu/kernels/batch_matmul.cpp index dff602dc3..1741dbf40 100644 --- a/src/nnfusion/core/kernels/cuda_gpu/kernels/batch_matmul.cpp +++ b/src/nnfusion/core/kernels/cuda_gpu/kernels/batch_matmul.cpp @@ -103,8 +103,9 @@ namespace nnfusion @hCublas@, @transA@, @transB@, @m@, @n@, @k@, &alpha, input1, @lda@, @stride_a@, input0, @ldb@, @stride_b@, &beta, output0, @ldc@, @stride_c@, @batch@)); - )" : - R"( + )" + : + R"( static const float alpha = @alpha@F, beta = @beta@F; // if (!@hCublas@) // CUBLAS_SAFE_CALL(@api_create@(&@hCublas@)); @@ -116,7 +117,9 @@ namespace nnfusion { {"hCublas", "cublas_handle"}, {"api_create", "cublasCreate"}, - {"api_exec", dtype == nnfusion::element::f16 ? "cublasHgemmStridedBatched" : "cublasSgemmStridedBatched"}, + {"api_exec", + dtype == nnfusion::element::f16 ? "cublasHgemmStridedBatched" + : "cublasSgemmStridedBatched"}, {"transA", transB ? "CUBLAS_OP_T" : "CUBLAS_OP_N"}, {"transB", transA ? "CUBLAS_OP_T" : "CUBLAS_OP_N"}, {"alpha", alpha}, diff --git a/src/nnfusion/core/kernels/cuda_gpu/kernels/batch_norm.cpp b/src/nnfusion/core/kernels/cuda_gpu/kernels/batch_norm.cpp index 42404ac7b..a6c66dfdc 100644 --- a/src/nnfusion/core/kernels/cuda_gpu/kernels/batch_norm.cpp +++ b/src/nnfusion/core/kernels/cuda_gpu/kernels/batch_norm.cpp @@ -171,10 +171,14 @@ LanguageUnit_p cuda::BatchNormNCHW::emit_function_body() /* * todo: may have better solution, details in https://github.com/microsoft/nnfusion/issues/434 * */ - if(dtype == nnfusion::element::f16){ - lu << "output0[st + i] = __hadd(input1[c_id] , __hdiv(__hmul(input0[c_id], __hsub(input2[st + i], input3[c_id])), sqrtf(__hadd(__float2half(" + if (dtype == nnfusion::element::f16) + { + lu << "output0[st + i] = __hadd(input1[c_id] , __hdiv(__hmul(input0[c_id], " + "__hsub(input2[st + i], input3[c_id])), sqrtf(__hadd(__float2half(" << epsilon << "), input4[c_id]))));\n"; - }else{ + } + else + { lu << "(input1[c_id] + (input0[c_id] * " "(input2[st + i] - input3[c_id]) / sqrtf(" << epsilon << " + input4[c_id])));\n"; diff --git a/src/nnfusion/core/kernels/cuda_gpu/kernels/dot.cpp b/src/nnfusion/core/kernels/cuda_gpu/kernels/dot.cpp index 8e7b7a735..a78071f18 100644 --- a/src/nnfusion/core/kernels/cuda_gpu/kernels/dot.cpp +++ b/src/nnfusion/core/kernels/cuda_gpu/kernels/dot.cpp @@ -207,161 +207,165 @@ LanguageUnit_p cuda::Dot::emit_function_body() else if (dtype == element::f16) { // case 1: Scalar * Tensor - if (arg0_shape.empty() || arg1_shape.empty()) - { - auto& second = (arg0_shape.empty() ? arg1_shape : arg0_shape); - size_t count = nnfusion::shape_size(second); + if (arg0_shape.empty() || arg1_shape.empty()) + { + auto& second = (arg0_shape.empty() ? arg1_shape : arg0_shape); + size_t count = nnfusion::shape_size(second); - string firstarg = (arg0_shape.empty() ? "input1" : "input0"); - string secondarg = (arg0_shape.empty() ? "input0" : "input1"); + string firstarg = (arg0_shape.empty() ? "input1" : "input0"); + string secondarg = (arg0_shape.empty() ? "input0" : "input1"); - lu << "cublasSetPointerMode(cublas_handle, CUBLAS_POINTER_MODE_DEVICE);\n"; + lu << "cublasSetPointerMode(cublas_handle, CUBLAS_POINTER_MODE_DEVICE);\n"; - lu << "CUDA_SAFE_CALL(cudaMemcpy(outupt0, " << firstarg << ", " << count << ", cudaMemcpyDeviceToDevice));\n"; // copy `firstarg` to `output0` - lu << "CUBLAS_SAFE_CALL(nnfusionHalfScale(" << secondarg << ", output0, " << count << "));\n"; - } + lu << "CUDA_SAFE_CALL(cudaMemcpy(outupt0, " << firstarg << ", " << count + << ", cudaMemcpyDeviceToDevice));\n"; // copy `firstarg` to `output0` + lu << "CUBLAS_SAFE_CALL(nnfusionHalfScale(" << secondarg << ", output0, " << count + << "));\n"; + } // // case 2: 1d Dot - else if ((arg0_shape.size() == arg1_shape.size()) && (arg0_shape.size() == reduction_axes)) - { - for (int i = 0; i < arg0_shape.size(); i++) - { - if (arg0_shape[i] != arg1_shape[i]) - { - std::vector arg_vec{"arg0", "arg1"}; - std::vector shape_vec{arg0_shape, arg1_shape}; - - NNFUSION_CHECK_FAIL() << nnfusion::join(arg_vec) << " with " - << nnfusion::join(shape_vec) << " respectively, at Node " - << m_context->gnode->get_name() - << ", do not match for dot op"; - } - } - - size_t count = nnfusion::shape_size(arg0_shape); - lu << "cublasSetPointerMode(cublas_handle, CUBLAS_POINTER_MODE_DEVICE);\n"; - - lu << "CUBLAS_SAFE_CALL(cublasSdot(cublas_handle, " << count - << ", static_cast(input0), 1, static_cast(input1), 1, " - "static_cast(output0)));\n"; - } - // // matrix * vector - else if ((arg0_shape.size() == 2) && (arg1_shape.size() == 1) && (reduction_axes == 1)) - { - lu << "const float alpha = 1.0;\n const float beta = 0;\n"; - lu << "CUBLAS_SAFE_CALL(cublasSgemv(cublas_handle, "; - if (trans_A) - lu << "CUBLAS_OP_N, " << arg0_shape[0] << ", " << arg0_shape[1] << ", "; - else - lu << "CUBLAS_OP_T, " << arg0_shape[1] << ", " << arg0_shape[0] << ", "; - lu << " &alpha," - << " static_cast(input0)," << arg0_shape[1] << ", " - << " static_cast(input1)," - << " 1," - << " &beta," - << " static_cast(output0)," - << " 1));\n"; - } - else if ((arg0_shape.size() == 2) && (arg1_shape.size() == 2) && (reduction_axes == 1) && - (trans_A || trans_B)) - { - int m = trans_B ? arg1_shape[0] : arg1_shape[1]; - int n = trans_A ? arg0_shape[1] : arg0_shape[0]; - int k = trans_A ? arg0_shape[0] : arg0_shape[1]; - - lu << "const half alpha = 1.0;\nconst half beta = 0;\n"; - - lu << "CUBLAS_SAFE_CALL(cublasHgemm(cublas_handle," - << (trans_B ? " CUBLAS_OP_T," : " CUBLAS_OP_N,") - << (trans_A ? " CUBLAS_OP_T," : " CUBLAS_OP_N,") << " " << m << "," - << " " << n << "," - << " " << k << "," - << " &alpha," - << " static_cast(input1)," - << " " << arg1_shape[1] << "," - << " static_cast(input0)," - << " " << arg0_shape[1] << "," - << " &beta," - << " static_cast(output0)," - << " " << m << "));\n"; - } else { - size_t axes_for_m_count = arg0_shape.size() - reduction_axes; - size_t axes_for_n_count = arg1_shape.size() - reduction_axes; - size_t axes_for_k_count = reduction_axes; - size_t m = 1; - size_t n = 1; - size_t k = 1; - - // check if input and output size correct - // check and calculate k for arg0 and arg1 - size_t arg0_k_idx = axes_for_m_count; // first axe in arg0 for k - size_t arg1_k_idx = 0; // first axe in arg1 for k - - for (size_t i = 0; i < axes_for_k_count; i++) + else if ((arg0_shape.size() == arg1_shape.size()) && (arg0_shape.size() == reduction_axes)) { - k *= arg0_shape[arg0_k_idx]; - if (arg0_shape[arg0_k_idx++] != arg1_shape[arg1_k_idx++]) + for (int i = 0; i < arg0_shape.size(); i++) { - std::vector arg_vec{"arg0", "arg1"}; - std::vector shape_vec{arg0_shape, arg1_shape}; + if (arg0_shape[i] != arg1_shape[i]) + { + std::vector arg_vec{"arg0", "arg1"}; + std::vector shape_vec{arg0_shape, arg1_shape}; - NNFUSION_CHECK_FAIL() << nnfusion::join(arg_vec) << " with " - << nnfusion::join(shape_vec) << " respectively, at Node " - << m_context->gnode->get_name() - << ", do not match for dot op"; + NNFUSION_CHECK_FAIL() << nnfusion::join(arg_vec) << " with " + << nnfusion::join(shape_vec) << " respectively, at Node " + << m_context->gnode->get_name() + << ", do not match for dot op"; + } } + + size_t count = nnfusion::shape_size(arg0_shape); + lu << "cublasSetPointerMode(cublas_handle, CUBLAS_POINTER_MODE_DEVICE);\n"; + + lu << "CUBLAS_SAFE_CALL(cublasSdot(cublas_handle, " << count + << ", static_cast(input0), 1, static_cast(input1), 1, " + "static_cast(output0)));\n"; } - // check and calculate m for arg0 and out - size_t arg0_m_idx = 0; // first axe in arg0 for m - size_t out_m_idx = 0; // first axe in out for m - for (size_t i = 0; i < axes_for_m_count; i++) + // // matrix * vector + else if ((arg0_shape.size() == 2) && (arg1_shape.size() == 1) && (reduction_axes == 1)) { - m *= arg0_shape[arg0_m_idx]; - if (arg0_shape[arg0_m_idx++] != out_shape[out_m_idx++]) - { - std::vector arg_vec{"arg0", "output"}; - std::vector shape_vec{arg0_shape, out_shape}; + lu << "const float alpha = 1.0;\n const float beta = 0;\n"; + lu << "CUBLAS_SAFE_CALL(cublasSgemv(cublas_handle, "; + if (trans_A) + lu << "CUBLAS_OP_N, " << arg0_shape[0] << ", " << arg0_shape[1] << ", "; + else + lu << "CUBLAS_OP_T, " << arg0_shape[1] << ", " << arg0_shape[0] << ", "; + lu << " &alpha," + << " static_cast(input0)," << arg0_shape[1] << ", " + << " static_cast(input1)," + << " 1," + << " &beta," + << " static_cast(output0)," + << " 1));\n"; + } + else if ((arg0_shape.size() == 2) && (arg1_shape.size() == 2) && (reduction_axes == 1) && + (trans_A || trans_B)) + { + int m = trans_B ? arg1_shape[0] : arg1_shape[1]; + int n = trans_A ? arg0_shape[1] : arg0_shape[0]; + int k = trans_A ? arg0_shape[0] : arg0_shape[1]; - NNFUSION_CHECK_FAIL() << nnfusion::join(arg_vec) << " with " - << nnfusion::join(shape_vec) << " respectively, at Node " - << m_context->gnode->get_name() - << ", do not match for dot op"; - } + lu << "const half alpha = 1.0;\nconst half beta = 0;\n"; + + lu << "CUBLAS_SAFE_CALL(cublasHgemm(cublas_handle," + << (trans_B ? " CUBLAS_OP_T," : " CUBLAS_OP_N,") + << (trans_A ? " CUBLAS_OP_T," : " CUBLAS_OP_N,") << " " << m << "," + << " " << n << "," + << " " << k << "," + << " &alpha," + << " static_cast(input1)," + << " " << arg1_shape[1] << "," + << " static_cast(input0)," + << " " << arg0_shape[1] << "," + << " &beta," + << " static_cast(output0)," + << " " << m << "));\n"; } - // check and calculate n for arg1 and out - size_t arg1_n_idx = axes_for_k_count; // first axe in arg1 for n - size_t out_n_idx = axes_for_m_count; // first axe in arg1 for n - for (size_t i = 0; i < axes_for_n_count; i++) + else { - n *= arg1_shape[arg1_n_idx]; - if (arg1_shape[arg1_n_idx++] != out_shape[out_n_idx++]) + size_t axes_for_m_count = arg0_shape.size() - reduction_axes; + size_t axes_for_n_count = arg1_shape.size() - reduction_axes; + size_t axes_for_k_count = reduction_axes; + size_t m = 1; + size_t n = 1; + size_t k = 1; + + // check if input and output size correct + // check and calculate k for arg0 and arg1 + size_t arg0_k_idx = axes_for_m_count; // first axe in arg0 for k + size_t arg1_k_idx = 0; // first axe in arg1 for k + + for (size_t i = 0; i < axes_for_k_count; i++) { - std::vector arg_vec{"arg1", "output"}; - std::vector shape_vec{arg1_shape, out_shape}; + k *= arg0_shape[arg0_k_idx]; + if (arg0_shape[arg0_k_idx++] != arg1_shape[arg1_k_idx++]) + { + std::vector arg_vec{"arg0", "arg1"}; + std::vector shape_vec{arg0_shape, arg1_shape}; - NNFUSION_CHECK_FAIL() << nnfusion::join(arg_vec) << " with " - << nnfusion::join(shape_vec) << " respectively, at Node " - << m_context->gnode->get_name() - << ", do not match for dot op"; + NNFUSION_CHECK_FAIL() << nnfusion::join(arg_vec) << " with " + << nnfusion::join(shape_vec) << " respectively, at Node " + << m_context->gnode->get_name() + << ", do not match for dot op"; + } } - } + // check and calculate m for arg0 and out + size_t arg0_m_idx = 0; // first axe in arg0 for m + size_t out_m_idx = 0; // first axe in out for m + for (size_t i = 0; i < axes_for_m_count; i++) + { + m *= arg0_shape[arg0_m_idx]; + if (arg0_shape[arg0_m_idx++] != out_shape[out_m_idx++]) + { + std::vector arg_vec{"arg0", "output"}; + std::vector shape_vec{arg0_shape, out_shape}; + + NNFUSION_CHECK_FAIL() << nnfusion::join(arg_vec) << " with " + << nnfusion::join(shape_vec) << " respectively, at Node " + << m_context->gnode->get_name() + << ", do not match for dot op"; + } + } + // check and calculate n for arg1 and out + size_t arg1_n_idx = axes_for_k_count; // first axe in arg1 for n + size_t out_n_idx = axes_for_m_count; // first axe in arg1 for n + for (size_t i = 0; i < axes_for_n_count; i++) + { + n *= arg1_shape[arg1_n_idx]; + if (arg1_shape[arg1_n_idx++] != out_shape[out_n_idx++]) + { + std::vector arg_vec{"arg1", "output"}; + std::vector shape_vec{arg1_shape, out_shape}; - lu << "const half alpha = 1.0f;\nconst half beta = 0.f;\n"; - - lu << "CUBLAS_SAFE_CALL(cublasHgemm(cublas_handle," - << " CUBLAS_OP_N," - << " CUBLAS_OP_N," - << " " << n << "," - << " " << m << "," - << " " << k << "," - << " &alpha," - << " static_cast(input1)," - << " " << n << "," - << " static_cast(input0)," - << " " << k << "," - << " &beta," - << " static_cast(output0)," - << " " << n << "));\n"; - } + NNFUSION_CHECK_FAIL() << nnfusion::join(arg_vec) << " with " + << nnfusion::join(shape_vec) << " respectively, at Node " + << m_context->gnode->get_name() + << ", do not match for dot op"; + } + } + + lu << "const half alpha = 1.0f;\nconst half beta = 0.f;\n"; + + lu << "CUBLAS_SAFE_CALL(cublasHgemm(cublas_handle," + << " CUBLAS_OP_N," + << " CUBLAS_OP_N," + << " " << n << "," + << " " << m << "," + << " " << k << "," + << " &alpha," + << " static_cast(input1)," + << " " << n << "," + << " static_cast(input0)," + << " " << k << "," + << " &beta," + << " static_cast(output0)," + << " " << n << "));\n"; + } } else { diff --git a/src/nnfusion/core/kernels/cuda_gpu/kernels/reduce.hpp b/src/nnfusion/core/kernels/cuda_gpu/kernels/reduce.hpp index 5c9146afb..4c47ba346 100644 --- a/src/nnfusion/core/kernels/cuda_gpu/kernels/reduce.hpp +++ b/src/nnfusion/core/kernels/cuda_gpu/kernels/reduce.hpp @@ -230,7 +230,10 @@ for (int tidx = thread_idx; tidx < width; tidx += block_size) { val = reduceSum(val, thread_idx, block_size, shm); if (thread_idx == 0) output0[block_idx] = val; )", - {{"width", width}, {"block_size", expected_block_size}, {"warp_size", 32},{"dataType", dtype==nnfusion::element::f16? "half" : "float"}}); + {{"width", width}, + {"block_size", expected_block_size}, + {"warp_size", 32}, + {"dataType", dtype == nnfusion::element::f16 ? "half" : "float"}}); lu << code << "\n"; return _lu; @@ -582,7 +585,6 @@ if (thread_idx == 0) output0[block_idx] = val; m_gridDim = dim3(1, 1, 1); m_blockDim = dim3(block_size_x, 1, 1); } - } } diff --git a/src/nnfusion/core/operators/generic_op/generic_op_define/ScatterND.cpp b/src/nnfusion/core/operators/generic_op/generic_op_define/ScatterND.cpp index 6174bc824..8a2fcc44e 100644 --- a/src/nnfusion/core/operators/generic_op/generic_op_define/ScatterND.cpp +++ b/src/nnfusion/core/operators/generic_op/generic_op_define/ScatterND.cpp @@ -34,7 +34,7 @@ REGISTER_OP(ScatterND) { auto temp = batch_dims; temp.push_back(to_string(i)); - output_layout.push_back("input1" + + output_layout.push_back("@input1@" + vector_to_string>(temp)); } else diff --git a/src/nnfusion/core/operators/generic_op/generic_op_define/Slice.cpp b/src/nnfusion/core/operators/generic_op/generic_op_define/Slice.cpp index 75f02a837..4bfc4fa03 100644 --- a/src/nnfusion/core/operators/generic_op/generic_op_define/Slice.cpp +++ b/src/nnfusion/core/operators/generic_op/generic_op_define/Slice.cpp @@ -36,10 +36,12 @@ REGISTER_OP(Slice) auto step = steps[d]; auto start = starts[d]; auto end = ends[d]; - auto range = (u_int64_t)ceil((double)(end-start)/(double)step); - input_layout.push_back((step == 1? output_layout[d] : output_layout[d] + " * " + to_string(step)) + " + " + to_string(start)); - slice_dims += (slice_dims.empty() ? "" : " , ") + output_layout[d] + - " in " + to_string(range); + auto range = (u_int64_t)ceil((double)(end - start) / (double)step); + input_layout.push_back( + (step == 1 ? output_layout[d] : output_layout[d] + " * " + to_string(step)) + + " + " + to_string(start)); + slice_dims += + (slice_dims.empty() ? "" : " , ") + output_layout[d] + " in " + to_string(range); } auto expression_code = op::create_code_from_template( diff --git a/src/nnfusion/engine/pass/extract_graph_signature.cpp b/src/nnfusion/engine/pass/extract_graph_signature.cpp index ce0e200bc..f537121f8 100644 --- a/src/nnfusion/engine/pass/extract_graph_signature.cpp +++ b/src/nnfusion/engine/pass/extract_graph_signature.cpp @@ -142,7 +142,8 @@ bool ExtractGraphSignature::extract_args(std::shared_ptr ctx const element::Type& et = tv->get_element_type(); string type; - if(!element::Type::nnfusion_element_type_to_dtype_string(tv->get_element_type(), type)){ + if (!element::Type::nnfusion_element_type_to_dtype_string(tv->get_element_type(), type)) + { NNFUSION_LOG(ERROR) << "Get element type failed"; return false; } @@ -188,7 +189,8 @@ bool ExtractGraphSignature::extract_output(std::shared_ptr c tu->out.push_back(tv); string type; - if(!element::Type::nnfusion_element_type_to_dtype_string(tv->get_element_type(), type)){ + if (!element::Type::nnfusion_element_type_to_dtype_string(tv->get_element_type(), type)) + { NNFUSION_LOG(ERROR) << "Get element type failed"; return false; } diff --git a/src/nnfusion/frontend/onnx_import/op/pad.hpp b/src/nnfusion/frontend/onnx_import/op/pad.hpp index a89a09133..7c65adb38 100644 --- a/src/nnfusion/frontend/onnx_import/op/pad.hpp +++ b/src/nnfusion/frontend/onnx_import/op/pad.hpp @@ -22,16 +22,17 @@ namespace nnfusion /* * since opset 11, 'pads' and 'value' have been moved from attributes to inputs * */ - if (node_proto.attribute_size() == 1){ - + if (node_proto.attribute_size() == 1) + { auto input_gnode = GetInputNode(all_ng_nodes, node_proto, 0); auto padding_gnode = GetInputNode(all_ng_nodes, node_proto, 1); std::vector paddings; bool status = GetValueFromNGraphOp(padding_gnode, &paddings); NNFUSION_CHECK(status); - NNFUSION_CHECK(paddings.size() % 2 == 0) - << "Constant node for paddings does not have an even number of elements"; + NNFUSION_CHECK(paddings.size() % 2 == 0) << "Constant node for paddings " + "does not have an even number " + "of elements"; nnfusion::Shape padding_below(paddings.size() / 2); nnfusion::Shape padding_above(paddings.size() / 2); @@ -48,10 +49,11 @@ namespace nnfusion std::make_shared(input_gnode->get_element_type(), nnfusion::Shape{}, std::vector{"0"}); - auto pad_val_gnode = m_graph->add_node_and_edge(pad_val_op, GNodeVector({})); + auto pad_val_gnode = + m_graph->add_node_and_edge(pad_val_op, GNodeVector({})); - auto pad_op = - std::make_shared(padding_below, padding_above, padding_interior); + auto pad_op = std::make_shared( + padding_below, padding_above, padding_interior); pad_op->set_name(node_proto.output(0)); auto pad_gnode = @@ -59,7 +61,9 @@ namespace nnfusion NamedNodeVector ret{{node_proto.output(0), pad_gnode}}; return ret; - }else{ + } + else + { cout << "meet pad op" << endl; /* for pad op, 0: mode, 1: pads, 2: constant * we can use attr.name() to get the name of the attr @@ -70,7 +74,8 @@ namespace nnfusion auto input_gnode = GetInputNode(all_ng_nodes, node_proto, 0); const onnx::AttributeProto& modeAttr = node_proto.attribute(0); cout << modeAttr.name() << endl; - if(modeAttr.s() != "constant") NNFUSION_CHECK_FAIL() << "unsupported padding type: " << modeAttr.s(); + if (modeAttr.s() != "constant") + NNFUSION_CHECK_FAIL() << "unsupported padding type: " << modeAttr.s(); const onnx::AttributeProto& padAttr = node_proto.attribute(1); cout << padAttr.name() << endl; for (int i = 0; i < 8; ++i) @@ -85,7 +90,8 @@ namespace nnfusion std::make_shared(input_gnode->get_element_type(), nnfusion::Shape{}, std::vector{"0"}); - auto pad_val_gnode = m_graph->add_node_and_edge(pad_val_op, GNodeVector({})); + auto pad_val_gnode = + m_graph->add_node_and_edge(pad_val_op, GNodeVector({})); nnfusion::Shape padding_below(4); nnfusion::Shape padding_above(4); nnfusion::Shape padding_interior(4); @@ -95,11 +101,11 @@ namespace nnfusion for (int i = 0; i < 4; ++i) { padding_below[i] = padAttr.ints(i); - padding_above[i] = padAttr.ints(i+4); + padding_above[i] = padAttr.ints(i + 4); } - auto pad_op = - std::make_shared(padding_below, padding_above, padding_interior); + auto pad_op = std::make_shared( + padding_below, padding_above, padding_interior); pad_op->set_name(node_proto.output(0)); auto pad_gnode = diff --git a/src/nnfusion/frontend/onnx_import/op/where.cpp b/src/nnfusion/frontend/onnx_import/op/where.cpp index 2d184858c..d344e1d53 100644 --- a/src/nnfusion/frontend/onnx_import/op/where.cpp +++ b/src/nnfusion/frontend/onnx_import/op/where.cpp @@ -20,6 +20,8 @@ //---------------------------------------------------------------------------------------------- #include "where.hpp" +#include "nnfusion/core/graph/util/autobroadcast.hpp" +#include "nnfusion/core/graph/util/numpy_transpose.hpp" #include "nnfusion/core/operators/generic_op/generic_op.hpp" namespace nnfusion @@ -39,6 +41,9 @@ namespace nnfusion auto x_gnode = input_indices[1]; auto y_gnode = input_indices[2]; + std::tie(x_gnode, y_gnode) = + graph::numpy_broadcast(std::make_pair(x_gnode, y_gnode), m_graph); + auto node_name = node_proto.output(0); nnfusion::op::OpConfig::any op_config; diff --git a/src/nnfusion/frontend/onnx_import/util/util.hpp b/src/nnfusion/frontend/onnx_import/util/util.hpp index 21c82e078..fb833c146 100644 --- a/src/nnfusion/frontend/onnx_import/util/util.hpp +++ b/src/nnfusion/frontend/onnx_import/util/util.hpp @@ -143,13 +143,15 @@ namespace nnfusion } template <> - inline std::vector get_data(const onnx::TensorProto& tensor){ - + inline std::vector get_data(const onnx::TensorProto& tensor) + { if (tensor.has_raw_data()) { - return __get_raw_data(tensor.raw_data()); - }else{ - NNFUSION_LOG(NNFUSION_WARNING) << "Have no raw data" << endl ; + return __get_raw_data(tensor.raw_data()); + } + else + { + NNFUSION_LOG(NNFUSION_WARNING) << "Have no raw data" << endl; } if (tensor.data_type() == onnx::TensorProto_DataType_FLOAT16) @@ -191,9 +193,8 @@ namespace nnfusion NNFUSION_CHECK_FAIL() << "invalid data type: " << onnx::TensorProto_DataType_Name( - static_cast(tensor.data_type())); + static_cast(tensor.data_type())); return std::vector(); - } template <>