From 556158b6c230d6b402cb44c2e285db4a239493ae Mon Sep 17 00:00:00 2001
From: LeiWang1999
Date: Tue, 11 Apr 2023 04:23:35 -0800
Subject: [PATCH 1/7] add support for int16_t load (bloom fp16 model)

---
 src/nnfusion/core/kernels/cuda_gpu/cuda_langunit.cpp | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/src/nnfusion/core/kernels/cuda_gpu/cuda_langunit.cpp b/src/nnfusion/core/kernels/cuda_gpu/cuda_langunit.cpp
index e9711661a..9eb570935 100644
--- a/src/nnfusion/core/kernels/cuda_gpu/cuda_langunit.cpp
+++ b/src/nnfusion/core/kernels/cuda_gpu/cuda_langunit.cpp
@@ -295,6 +295,15 @@ __device__ __forceinline__ half load(const half* __restrict__ in, int i=0, bool b=true)
     }
     return v;
 }
+__device__ __forceinline__ int16_t load(const int16_t* __restrict__ in, int i=0, bool b=true)
+{
+    int16_t v = 0;
+    if (b)
+    {
+        v = __ldg(in + i);
+    }
+    return v;
+}
 __device__ __forceinline__ int32_t load(const int32_t* __restrict__ in, int i=0, bool b=true)
 {
     int32_t v = 0;

From 4cb6ce3f070fe4b114126daf8a48e70841b10a97 Mon Sep 17 00:00:00 2001
From: LeiWang1999
Date: Tue, 11 Apr 2023 04:23:57 -0800
Subject: [PATCH 2/7] fix output edge index bug in register fusion pass

---
 src/nnfusion/engine/pass/graph/register_fusion_pass.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/nnfusion/engine/pass/graph/register_fusion_pass.cpp b/src/nnfusion/engine/pass/graph/register_fusion_pass.cpp
index 90399dd96..6781d5702 100644
--- a/src/nnfusion/engine/pass/graph/register_fusion_pass.cpp
+++ b/src/nnfusion/engine/pass/graph/register_fusion_pass.cpp
@@ -417,7 +417,7 @@ class ApplyFusionResult
             auto out_node = out_edge->get_dst();
             if (node_set.count(out_node)) continue;
-            m_graph->add_edge(fused_node, out_id, out_node, out_edge->get_dst_input());
+            m_graph->add_edge(fused_node, i, out_node, out_edge->get_dst_input());
         }
     }
     // cleanup

From b62ce8400acc5597bc5816ebf5bf9b94979d2bf3 Mon Sep 17 00:00:00 2001
From: LeiWang1999
Date: Tue, 11 Apr 2023 22:36:29 -0800
Subject: [PATCH 3/7] retarget CUDA_ARCH gencode flags to PTX (code=compute_XX)

---
 src/nnfusion/engine/pass/codegen/cuda_codegen_pass.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/nnfusion/engine/pass/codegen/cuda_codegen_pass.cpp b/src/nnfusion/engine/pass/codegen/cuda_codegen_pass.cpp
index 8b5539f44..51f95c020 100644
--- a/src/nnfusion/engine/pass/codegen/cuda_codegen_pass.cpp
+++ b/src/nnfusion/engine/pass/codegen/cuda_codegen_pass.cpp
@@ -1480,7 +1480,7 @@ void CudaCodegenPass::create_cmake_file(std::shared_ptr ctx,
 cmake_minimum_required(VERSION 3.5)
 SET(SRC "nnfusion_rt.cu" CACHE STRING "codegen source file")
 SET(TARGET_NAME "nnfusion_naive_rt" CACHE STRING "codegen target name")
-SET(CUDA_ARCH "-gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_70,code=sm_70 -gencode arch=compute_75,code=sm_75 -gencode arch=compute_80,code=sm_80 -gencode arch=compute_86,code=sm_86" CACHE STRING "target architecture")
+SET(CUDA_ARCH "-gencode=arch=compute_60,code=compute_60 -gencode=arch=compute_61,code=compute_61 -gencode=arch=compute_70,code=compute_70 -gencode=arch=compute_75,code=compute_75 -gencode=arch=compute_80,code=compute_80 -gencode=arch=compute_86,code=compute_86" CACHE STRING "target architecture")
 if(NOT CMAKE_BUILD_TYPE)
     set(CMAKE_BUILD_TYPE Release)
 endif()

From 6b681e3a4df938b08e5e41ac00bbc107f0532243 Mon Sep 17 00:00:00 2001
From: LeiWang1999
Date: Tue, 11 Apr 2023 23:52:19 -0800
Subject: [PATCH 4/7] fix gridDim.z parsing in FusionCudaEmitter launch config
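
FusionCudaEmitter::set_launch_config() read grid[1] from the launch-config
JSON into both m_gridDim.y and m_gridDim.z, so the z-dimension of the grid
was silently overwritten by the y-dimension. For example, a fused kernel
whose grid array is [256, 4, 2] previously launched as (256, 4, 4); with
m_gridDim.z now read from grid[2], it launches as (256, 4, 2) as intended.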
--- src/nnfusion/core/kernels/cuda_gpu/cuda_emitter.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nnfusion/core/kernels/cuda_gpu/cuda_emitter.cpp b/src/nnfusion/core/kernels/cuda_gpu/cuda_emitter.cpp index 6a6b02ae8..fc974ffe5 100644 --- a/src/nnfusion/core/kernels/cuda_gpu/cuda_emitter.cpp +++ b/src/nnfusion/core/kernels/cuda_gpu/cuda_emitter.cpp @@ -870,7 +870,7 @@ void cuda::FusionCudaEmitter::set_launch_config() block[2].get_to(m_blockDim.z); grid[0].get_to(m_gridDim.x); grid[1].get_to(m_gridDim.y); - grid[1].get_to(m_gridDim.z); + grid[2].get_to(m_gridDim.z); } LanguageUnit_p cuda::FusionCudaEmitter::emit_function_signature() From 7b605e3726a1faaa921e2387bb14d6ffb9f4dd8b Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sat, 15 Apr 2023 07:22:11 -0800 Subject: [PATCH 5/7] add dot permutation pass --- .../generic_op_define/memfusion_new_ops.cpp | 346 ++++++++++++++++++ src/nnfusion/engine/device/cuda.cpp | 2 + src/nnfusion/engine/device/rocm.cpp | 2 + src/nnfusion/engine/pass/graph/CMakeLists.txt | 1 + .../pass/graph/dot_permutation_pass.cpp | 246 +++++++++++++ .../pass/graph/dot_permutation_pass.hpp | 21 ++ 6 files changed, 618 insertions(+) create mode 100644 src/nnfusion/engine/pass/graph/dot_permutation_pass.cpp create mode 100644 src/nnfusion/engine/pass/graph/dot_permutation_pass.hpp diff --git a/src/nnfusion/core/operators/generic_op/generic_op_define/memfusion_new_ops.cpp b/src/nnfusion/core/operators/generic_op/generic_op_define/memfusion_new_ops.cpp index e01c325f0..fbeffa799 100644 --- a/src/nnfusion/core/operators/generic_op/generic_op_define/memfusion_new_ops.cpp +++ b/src/nnfusion/core/operators/generic_op/generic_op_define/memfusion_new_ops.cpp @@ -353,3 +353,349 @@ REGISTER_OP(HardSigmoid) return op::create_code_from_template(ir_template, op_config); }); + +REGISTER_OP(Permutate) + .attr("type", 0) + .attr("inner_i", 16) + .attr("inner_j", 16) + .infershape( + [](std::shared_ptr gnode) -> void + { + NNFUSION_CHECK(1 == gnode->get_input_size()); + gnode->set_output_type_and_shape( + 0, gnode->get_input_element_type(0), gnode->get_input_shape(0)); + }) + .translate_v2( + [](std::shared_ptr curr) -> std::string + { + // create expression `mediate0[N0, N1] = input0[N0 // 512 , N0 % 512, N1] where N0 in 16384;output0[N0, N1, N2, N3] = mediate0[(N0 * 16 + N2) // 16 * 16 + (N0 * 16 + N2) % 8 * 2 + (N1 * 16 + N3) % 16 // 8, (N1 * 16 + N3) // 16 * 16 + (N0 * 16 + N2) % 16 // 8 * 8 + (N1 * 16 + N3) % 8] where N0 in 1024, N1 in 256, N2 in 16, N3 in 16;` + auto generic_op = + std::dynamic_pointer_cast(curr->get_op_ptr()); + auto input0_shape = nnfusion::Shape(curr->get_input_shape(0)); + auto input0_type = curr->get_input_element_type(0); + NNFUSION_CHECK(input0_shape.size() == 2 || input0_shape.size() == 3) + << "Currently only support 2D or 3D input"; + int type = generic_op->localOpConfig.getRoot()["type"]; + string expression_template; + string expression_code; + if (input0_shape.size() == 2) + { + if (type == 0) + { + expression_template = + R"(@output0@[N0, N1, N2, N3] = @input0@[(N0 * 16 + N2) // 16 * 16 + (N0 * 16 + N2) % 8 * 2 + (N1 * 16 + N3) % 16 // 8, (N1 * 16 + N3) // 16 * 16 + (N0 * 16 + N2) % 16 // 8 * 8 + (N1 * 16 + N3) % 8] where N0 in @N0@, N1 in @N1@, N2 in @N2@, N3 in @N3@;)"; + } + else + { + NNFUSION_CHECK_FAIL() << "Permutate type not supported"; + } + nnfusion::json config; + config["N0"] = input0_shape[0] / + static_cast(generic_op->localOpConfig.getRoot()["inner_i"]); + config["N1"] = input0_shape[1] / + 
static_cast(generic_op->localOpConfig.getRoot()["inner_j"]); + config["N2"] = generic_op->localOpConfig.getRoot()["inner_i"]; + config["N3"] = generic_op->localOpConfig.getRoot()["inner_j"]; + expression_code = op::create_code_from_template(expression_template, config); + } + else if (input0_shape.size() == 3) + { + if (type == 0) + { + expression_template = + R"( mediate0[N0, N1] = @input0@[N0 // 512 , N0 % 512, N1] where N0 in @M@;@output0@[N0, N1, N2, N3] = mediate0[(N0 * 16 + N2) // 16 * 16 + (N0 * 16 + N2) % 8 * 2 + (N1 * 16 + N3) % 16 // 8, (N1 * 16 + N3) // 16 * 16 + (N0 * 16 + N2) % 16 // 8 * 8 + (N1 * 16 + N3) % 8] where N0 in @N0@, N1 in @N1@, N2 in @N2@, N3 in @N3@;)"; + } + else + { + NNFUSION_CHECK_FAIL() << "Permutate type not supported"; + } + + nnfusion::json config; + config["M"] = input0_shape[0] * input0_shape[1]; + config["N0"] = (input0_shape[0] * input0_shape[1]) / + static_cast(generic_op->localOpConfig.getRoot()["inner_i"]); + config["N1"] = input0_shape[2] / + static_cast(generic_op->localOpConfig.getRoot()["inner_j"]); + config["N2"] = generic_op->localOpConfig.getRoot()["inner_i"]; + config["N3"] = generic_op->localOpConfig.getRoot()["inner_j"]; + expression_code = op::create_code_from_template(expression_template, config); + } + return expression_code; + }); + +REGISTER_OP(BatchPermutate) + .attr("batch_dims", 2) + .attr("type", 0) + .attr("inner_i", 16) + .attr("inner_j", 16) + .infershape( + [](std::shared_ptr gnode) -> void + { + NNFUSION_CHECK(1 == gnode->get_input_size()); + gnode->set_output_type_and_shape( + 0, gnode->get_input_element_type(0), gnode->get_input_shape(0)); + }) + .translate_v2( + [](std::shared_ptr curr) -> std::string + { + // create expression `mediate0[N0, N1] = input0[N0 // 512 , N0 % 512, N1] where N0 in 16384;output0[N0, N1, N2, N3] = mediate0[(N0 * 16 + N2) // 16 * 16 + (N0 * 16 + N2) % 8 * 2 + (N1 * 16 + N3) % 16 // 8, (N1 * 16 + N3) // 16 * 16 + (N0 * 16 + N2) % 16 // 8 * 8 + (N1 * 16 + N3) % 8] where N0 in 1024, N1 in 256, N2 in 16, N3 in 16;` + auto generic_op = + std::dynamic_pointer_cast(curr->get_op_ptr()); + auto input0_shape = nnfusion::Shape(curr->get_input_shape(0)); + auto input0_type = curr->get_input_element_type(0); + NNFUSION_CHECK(input0_shape.size() == 3 || input0_shape.size() == 4) + << "Currently only support 3D or 4D input"; + int type = generic_op->localOpConfig.getRoot()["type"]; + string expression_template; + string expression_code; + if (input0_shape.size() == 3){ + if (type == 0) + { + expression_template = + R"(@output0@[B0, N0, N1, N2, N3] = @input0@[B0, (N0 * 16 + N2) // 16 * 16 + (N0 * 16 + N2) % 8 * 2 + (N1 * 16 + N3) % 16 // 8, (N1 * 16 + N3) // 16 * 16 + (N0 * 16 + N2) % 16 // 8 * 8 + (N1 * 16 + N3) % 8] where N0 in @N0@, N1 in @N1@, N2 in @N2@, N3 in @N3@;)"; + } + else if (type == 1) + { + // B[vi // 16, vj // 16, vi % 16, vj % 16] = A[vi // 8 * 8 + vi % 4 * 2 + vj % 16 // 8, vj // 16 * 16 + vi % 8 // 4 * 8 + vj % 8] + expression_template = + R"(@output0@[B0, N0, N1, N2, N3] = @input0@[B0, (N0 * 16 + N2) // 8 * 8 + (N0 * 16 + N2) % 4 * 2 + (N1 * 16 + N3) % 16 // 8, (N1 * 16 + N3) // 16 * 16 + (N0 * 16 + N2) % 8 // 4 * 8 + (N1 * 16 + N3) % 8] where N0 in @N0@, N1 in @N1@, N2 in @N2@, N3 in @N3@;)"; + } + else + { + NNFUSION_CHECK_FAIL() << "Permutate type not supported"; + } + nnfusion::json config; + config["N0"] = input0_shape[1] / + static_cast(generic_op->localOpConfig.getRoot()["inner_i"]); + config["N1"] = input0_shape[2] / + static_cast(generic_op->localOpConfig.getRoot()["inner_j"]); + 
config["N2"] = generic_op->localOpConfig.getRoot()["inner_i"]; + config["N3"] = generic_op->localOpConfig.getRoot()["inner_j"]; + expression_code = op::create_code_from_template(expression_template, config); + } + else if (input0_shape.size() == 4) + { + if (type == 0) + { + expression_template = + R"(@output0@[B0, B1, N0, N1, N2, N3] = @input0@[B0, B1, (N0 * 16 + N2) // 16 * 16 + (N0 * 16 + N2) % 8 * 2 + (N1 * 16 + N3) % 16 // 8, (N1 * 16 + N3) // 16 * 16 + (N0 * 16 + N2) % 16 // 8 * 8 + (N1 * 16 + N3) % 8] where N0 in @N0@, N1 in @N1@, N2 in @N2@, N3 in @N3@;)"; + } + else if (type == 1) + { + // B[vi // 16, vj // 16, vi % 16, vj % 16] = A[vi // 8 * 8 + vi % 4 * 2 + vj % 16 // 8, vj // 16 * 16 + vi % 8 // 4 * 8 + vj % 8] + expression_template = + R"(@output0@[B0, B1, N0, N1, N2, N3] = @input0@[B0, B1, (N0 * 16 + N2) // 8 * 8 + (N0 * 16 + N2) % 4 * 2 + (N1 * 16 + N3) % 16 // 8, (N1 * 16 + N3) // 16 * 16 + (N0 * 16 + N2) % 8 // 4 * 8 + (N1 * 16 + N3) % 8] where N0 in @N0@, N1 in @N1@, N2 in @N2@, N3 in @N3@;)"; + } + else + { + NNFUSION_CHECK_FAIL() << "Permutate type not supported"; + } + nnfusion::json config; + config["N0"] = input0_shape[2] / + static_cast(generic_op->localOpConfig.getRoot()["inner_i"]); + config["N1"] = input0_shape[3] / + static_cast(generic_op->localOpConfig.getRoot()["inner_j"]); + config["N2"] = generic_op->localOpConfig.getRoot()["inner_i"]; + config["N3"] = generic_op->localOpConfig.getRoot()["inner_j"]; + expression_code = op::create_code_from_template(expression_template, config); + } + + return expression_code; + }); + +REGISTER_OP(LayoutDot) + .attr("inner_i") + .attr("inner_j") + .attr("output_layout") + .infershape( + [](std::shared_ptr gnode) -> void + { + //TODO(leiwang1999):currently only support for NT Layout + NNFUSION_CHECK(2 == gnode->get_input_size()); + // input 0 shape is B, S, K, input 1 is K, N + // output sahpe is B, S, N + auto input0_shape = nnfusion::Shape(gnode->get_input_shape(0)); + auto input1_shape = nnfusion::Shape(gnode->get_input_shape(1)); + NNFUSION_CHECK(input0_shape.size() == 2 || input0_shape.size() == 3 || + input1_shape.size() == 2); + if (input0_shape.size() == 2) + { + nnfusion::Shape output_shape{input0_shape[0], input1_shape[1]}; + gnode->set_output_type_and_shape(0, gnode->get_input_element_type(0), output_shape); + } + else + { + nnfusion::Shape output_shape{input0_shape[0], input0_shape[1], input1_shape[1]}; + gnode->set_output_type_and_shape(0, gnode->get_input_element_type(0), output_shape); + } + }) + .translate_v2( + [](std::shared_ptr curr) -> std::string + { + auto generic_op = + std::dynamic_pointer_cast(curr->get_op_ptr()); + string fuse_template = + R"( temp0@A_fused_layout@ +=! @input0@@A_layout@ where M in @M@;)"; + string compute_template = + R"( @output0@[M, N] +=! 
temp0@A_fused_layout@ * @input1@@B_layout@; )"; + string ir_template = fuse_template + compute_template; + op::OpConfig::any op_config; + op_config["M"] = 16384; + op_config["A_fused_layout"] = "[M, K]"; + op_config["B_layout"] = "[N, K]"; + int output_layout = generic_op->localOpConfig.getRoot()["output_layout"]; + auto A_shape = curr->get_input_shape(0); + int raxis = A_shape.size() - 1; + string A_layout; + size_t stride = 16384; + for (int i = 0; i < A_shape.size(); i++) + { + if (i > 0) + A_layout += ", "; + if (i == raxis) + A_layout += "K"; + else + { + stride /= A_shape[i]; + A_layout += "M//" + to_string(stride) + "%" + to_string(A_shape[i]); + } + } + op_config["A_layout"] = "[" + A_layout + "]"; + + auto ir = op::create_code_from_template(ir_template, op_config); + + if (curr->get_output_element_type(0) == nnfusion::element::f16) + { + ir += "## @: output_layout=" + to_string(output_layout); + } + return ir; + }); + +REGISTER_OP(LayoutBMM) + .attr("adj_x", {{"b", false}}) + .attr("adj_y", {{"b", false}}) + .attr("inner_i", 16) + .attr("inner_j", 16) + .attr("output_layout", 0) + .constrait( + [](const nnfusion::op::OpConfig::any& config) -> bool + { + if (!config["adj_x"]["b"].is_boolean()) + return false; + if (!config["adj_y"]["b"].is_boolean()) + return false; + return true; + }) + .infershape( + [](std::shared_ptr gnode) -> void + { + NNFUSION_CHECK(gnode->get_input_size() == 2); + const nnfusion::Shape& input_shape_0 = gnode->get_input_shape(0); + const nnfusion::Shape& input_shape_1 = gnode->get_input_shape(1); + nnfusion::Shape output_shape_0; + + NNFUSION_CHECK(input_shape_0.size() == input_shape_1.size()); + NNFUSION_CHECK(gnode->get_input_element_type(0) == gnode->get_input_element_type(1)); + + for (int i = 0; i < input_shape_0.size() - 2; i++) + { + NNFUSION_CHECK(input_shape_0[i] == input_shape_1[i]); + output_shape_0.push_back(input_shape_0[i]); + } + + int m0 = input_shape_0[input_shape_0.size() - 2], + n0 = input_shape_0[input_shape_0.size() - 1]; + int m1 = input_shape_1[input_shape_1.size() - 2], + n1 = input_shape_1[input_shape_1.size() - 1]; + + auto generic_op = + std::dynamic_pointer_cast(gnode->get_op_ptr()); + bool trans_A = generic_op->localOpConfig.getRoot()["adj_x"]["b"]; + bool trans_B = generic_op->localOpConfig.getRoot()["adj_y"]["b"]; + + if (!trans_A && !trans_B) + NNFUSION_CHECK(m1 == n0), output_shape_0.push_back(m0), + output_shape_0.push_back(n1); + else if (!trans_A && trans_B) + NNFUSION_CHECK(n0 == n1), output_shape_0.push_back(m0), + output_shape_0.push_back(m1); + else if (trans_A && !trans_B) + NNFUSION_CHECK(m0 == m1), output_shape_0.push_back(n0), + output_shape_0.push_back(n1); + else // trans_A && trans_B + NNFUSION_CHECK(m0 == n1), output_shape_0.push_back(n0), + output_shape_0.push_back(m1); + gnode->set_output_type_and_shape(0, gnode->get_input_element_type(0), output_shape_0); + }) + .translate_v2( + [](std::shared_ptr curr) -> std::string + { + NNFUSION_CHECK(curr->get_input_size() == 2); + + const nnfusion::Shape& input_shape_0 = curr->get_input_shape(0); + const nnfusion::Shape& input_shape_1 = curr->get_input_shape(1); + nnfusion::Shape output_shape_0 = curr->get_output_shape(0); + + NNFUSION_CHECK(input_shape_0.size() == input_shape_1.size()); + NNFUSION_CHECK(curr->get_input_element_type(0) == curr->get_input_element_type(1)); + + auto generic_op = + std::dynamic_pointer_cast(curr->get_op_ptr()); + bool trans_A = generic_op->localOpConfig.getRoot()["adj_x"]["b"]; + bool trans_B = 
generic_op->localOpConfig.getRoot()["adj_y"]["b"]; + + auto ir_template = + R"( @output0@@output0_layout@ +=! @input0@@input0_layout@ * @input1@@input1_layout@; )"; + + std::vector output0_layout; + std::vector input0_layout; + std::vector input1_layout; + + for (size_t i = 0; i < output_shape_0.size() - 2; ++i) + { + std::string batch_dim = "B" + to_string(i); + output0_layout.push_back(batch_dim); + input0_layout.push_back(batch_dim); + input1_layout.push_back(batch_dim); + } + + output0_layout.push_back("N"); + output0_layout.push_back("M"); + + if (trans_A) + { + input0_layout.push_back("K"); + input0_layout.push_back("N"); + } + else + { + input0_layout.push_back("N"); + input0_layout.push_back("K"); + } + + if (trans_B) + { + input1_layout.push_back("M"); + input1_layout.push_back("K"); + } + else + { + input1_layout.push_back("K"); + input1_layout.push_back("M"); + } + + op::OpConfig::any op_config; + op_config["input0_layout"] = vector_to_string>(input0_layout); + op_config["input1_layout"] = vector_to_string>(input1_layout); + op_config["output0_layout"] = + vector_to_string>(output0_layout); + + auto ir = op::create_code_from_template(ir_template, op_config); + + int output_layout = generic_op->localOpConfig.getRoot()["output_layout"]; + if (curr->get_output_element_type(0) == nnfusion::element::f16) + { + ir += "## @: output_layout=" + to_string(output_layout); + } + return ir; + }); diff --git a/src/nnfusion/engine/device/cuda.cpp b/src/nnfusion/engine/device/cuda.cpp index 72f51f8cb..9783a1551 100644 --- a/src/nnfusion/engine/device/cuda.cpp +++ b/src/nnfusion/engine/device/cuda.cpp @@ -13,6 +13,7 @@ #include "nnfusion/engine/pass/graph/blockfusion_pass.hpp" #include "nnfusion/engine/pass/graph/common_subexpression_elimination_pass.hpp" #include "nnfusion/engine/pass/graph/dot_transpose_pass.hpp" +#include "nnfusion/engine/pass/graph/dot_permutation_pass.hpp" #include "nnfusion/engine/pass/graph/gemm_fusion_pass.hpp" #include "nnfusion/engine/pass/graph/gnode_device_dispatcher.hpp" #include "nnfusion/engine/pass/graph/gradient_weight_mapping_pass.hpp" @@ -72,6 +73,7 @@ CudaEngine::CudaEngine() g_passes->push_back(make_shared()); g_passes->push_back(make_shared()); g_passes->push_back(make_shared()); + g_passes->push_back(make_shared()); g_passes->push_back(make_shared()); // Kernel selection diff --git a/src/nnfusion/engine/device/rocm.cpp b/src/nnfusion/engine/device/rocm.cpp index ed91f83b7..fb0a4fb49 100644 --- a/src/nnfusion/engine/device/rocm.cpp +++ b/src/nnfusion/engine/device/rocm.cpp @@ -10,6 +10,7 @@ #include "nnfusion/engine/pass/graph/batchnorm_inference_folding_pass.hpp" #include "nnfusion/engine/pass/graph/blockfusion_pass.hpp" #include "nnfusion/engine/pass/graph/common_subexpression_elimination_pass.hpp" +#include "nnfusion/engine/pass/graph/dot_permutation_pass.hpp" #include "nnfusion/engine/pass/graph/dot_transpose_pass.hpp" #include "nnfusion/engine/pass/graph/gemm_fusion_pass.hpp" #include "nnfusion/engine/pass/graph/gnode_device_dispatcher.hpp" @@ -57,6 +58,7 @@ ROCmEngine::ROCmEngine() g_passes->push_back(make_shared()); g_passes->push_back(make_shared()); + g_passes->push_back(make_shared()); g_passes->push_back(make_shared()); // Kernel selection g_passes->push_back(make_shared()); diff --git a/src/nnfusion/engine/pass/graph/CMakeLists.txt b/src/nnfusion/engine/pass/graph/CMakeLists.txt index ba096dd34..0abddf1d0 100644 --- a/src/nnfusion/engine/pass/graph/CMakeLists.txt +++ b/src/nnfusion/engine/pass/graph/CMakeLists.txt @@ -25,6 +25,7 @@ set(SRC 
batchnorm_inference_folding_pass.cpp autodiff_pass.cpp dot_transpose_pass.cpp + dot_permutation_pass.cpp reduce_fusion_pass.cpp register_fusion_pass.cpp split_softmax_pass.cpp diff --git a/src/nnfusion/engine/pass/graph/dot_permutation_pass.cpp b/src/nnfusion/engine/pass/graph/dot_permutation_pass.cpp new file mode 100644 index 000000000..8b4d79c3a --- /dev/null +++ b/src/nnfusion/engine/pass/graph/dot_permutation_pass.cpp @@ -0,0 +1,246 @@ +#include "dot_permutation_pass.hpp" +#include "kernel_selection.hpp" +#include "nnfusion/core/graph/gnode.hpp" +#include "nnfusion/core/graph/graph.hpp" +#include "nnfusion/core/graph/graph_util.hpp" +#include "nnfusion/core/operators/op_define/broadcast.hpp" +#include "nnfusion/core/operators/op_define/fused.hpp" +#include "nnfusion/core/operators/op_define/reshape.hpp" +#include "nnfusion/core/operators/util/elementwise_arithmetic.hpp" +#include "nnfusion/core/kernels/cuda_gpu/cuda_emitter.hpp" +#include "nnfusion/util/util.hpp" +#include "gflags/gflags.h" + +#include + +using namespace nnfusion::graph; +using namespace nnfusion::pass::graph; +using namespace nnfusion::kernels; + +DEFINE_bool(fdot_permutation, false, "Enable Dot Permutation Pass"); +DEFINE_string(fpermutate_skiplist, "", "List of op types that skips in permutation"); + +namespace{ + + std::unordered_set skip_ops = {}; + void parse_skip_ops() + { + stringstream ss(FLAGS_fpermutate_skiplist); + while (ss.good()) + { + string substr; + getline(ss, substr, ','); + skip_ops.insert(substr); + } + } +} + +bool DotPermutationPass::run_on_graph(std::shared_ptr& graph) +{ + bool using_pass = FLAGS_fdot_permutation; + if (!using_pass) + return true; + parse_skip_ops(); + + NNFUSION_LOG(INFO) << "DotPermutationPass::run_on_graph start"; + std::vector> nodes = graph->get_nodes(); + size_t kernel_i = 16; + size_t kernel_j = 16; + size_t kernel_k = 16; + for (auto& it : nodes) + { + if (skip_ops.count(it->get_op_type())) + continue; + if (it->get_op_type() != "Dot" && it->get_op_type() != "BatchMatMul") + { + continue; + } + if (it->get_op_type() == "Dot"){ + + // find a dot node + NNFUSION_LOG(INFO) << "Find a dot node: " << it->get_id(); + // if node_shape's == 2, continue + // if (it->get_shape().size() == 2) + // continue; + auto it_op = static_pointer_cast(it->get_op_ptr()); + auto trans_a = it_op->get_transpose_A(); + auto trans_b = it_op->get_transpose_B(); + // get the input nodes + auto input_node = it->get_in_edge(0)->get_src(); + auto weight_node = it->get_in_edge(1)->get_src(); + NNFUSION_LOG(INFO) << "Input node: " << input_node->get_id(); + NNFUSION_LOG(INFO) << "Input Type: " << input_node->get_unique_name(); + NNFUSION_LOG(INFO) << "Weight node: " << weight_node->get_id(); + // if the input_node or weight_node is dot, continue + if (input_node->get_op_type() == "Dot" || weight_node->get_op_type() == "Dot") + NNFUSION_LOG(ERROR) << "Currently do not support input node or weight node is dot"; + + // create a new Permutate Node; + nnfusion::op::OpConfig::any permutateConfig; + permutateConfig["type"] = 0; + permutateConfig["inner_i"] = kernel_i; + permutateConfig["inner_j"] = kernel_k; + auto generic_op = std::make_shared( + "Permutate", "Permutate", permutateConfig); + // convert op to GNode + auto permutate_node = graph->add_node_and_edge(generic_op, {input_node}); + auto edge = it->get_in_edge(0); + graph->remove_edge(edge); + graph->add_edge(permutate_node, 0, it, 0); + // replace dot with LayoutDot + nnfusion::op::OpConfig::any layoutDotConfig; + layoutDotConfig["output_type"] = 
0; + layoutDotConfig["inner_i"] = kernel_i; + layoutDotConfig["inner_j"] = kernel_j; + auto layoutDot_op = std::make_shared( + "LayoutDot", "LayoutDot", layoutDotConfig); + NNFUSION_LOG(INFO) << "Create layoutDot node"; + // add layoutDot_node behind layout_Dot_op + NNFUSION_LOG(INFO) << "permutate_node input shape is " << nnfusion::join(permutate_node->get_input_shape(0)); + NNFUSION_LOG(INFO) << "permutate_node shape is " + << nnfusion::join(permutate_node->get_shape()); + + auto layoutDot_node = graph->add_node_and_edge(layoutDot_op, {permutate_node, weight_node}); + NNFUSION_LOG(INFO) << "Replace it->output's input edge with layoutDot_node"; + for (auto& edge : it->get_out_edges()) + { + auto dst_node = edge->get_dst(); + auto dst_input = edge->get_dst_input(); + graph->remove_edge(edge); + graph->add_edge(layoutDot_node, 0, dst_node, dst_input); + } + graph->remove_node(it); + NNFUSION_LOG(INFO) << "Replace dot with layoutDot done"; + // apply layout transform into weight_node + NNFUSION_LOG(INFO) << "Apply layout transform into weight_node"; + auto weight_shape = weight_node->get_shape(); + auto weight_op = dynamic_pointer_cast(weight_node->get_op_ptr()); + // assert weight_op != nullptr + NNFUSION_CHECK_NOT_NULLPTR(weight_op); + NNFUSION_LOG(INFO) << "weight shape is " << weight_shape[0] << " " << weight_shape[1]; + // get element_type + auto element_type = weight_op->get_type(); + #define OFFSET2D(x, y, ld) ((x) * (ld) + (y)) + #define OFFSET4D(x, y, z, w, ld1, ld2, ld3) ((x) * (ld1) + (y) * (ld2) + (z) * (ld3) + (w)) + if (element_type == nnfusion::element::f16) + { + NNFUSION_LOG(INFO) << "weight_node's element_type is f16"; + // rewrite data as first transpose + // get data + half_float::half* data = (half_float::half *)weight_op->get_data_ptr(); + // create a temp storage + half_float::half* temp_data = (half_float::half*)(new char[weight_op->get_data_size()]); + // transpose + // if weight is transposed, direct assign + if (it_op->get_transpose_B()) + { + NNFUSION_LOG(INFO) << "weight_node is transposed"; + memcpy(temp_data, data, weight_op->get_data_size()); + } + else + { + NNFUSION_LOG(INFO) << "weight_node is not transposed"; + + for (int i = 0; i < weight_shape[0]; i++) + { + for (int j = 0; j < weight_shape[1]; j++) + { + temp_data[OFFSET2D(j, i, weight_shape[0])] = + data[OFFSET2D(i, j, weight_shape[1])]; + } + } + } + + // layout transform data[vi / 16, vj / 16, vi % 16, vj % 16] = temp_data[vi / 8 * 8 + vi % 4 * 2 + vj % 16 / 8, vj / 16 * 16 + vi % 8 / 4 * 8 + vj % 8 + for (int i = 0; i < weight_shape[1]; i++) + { + for (int j = 0; j < weight_shape[0]; j++) + { + data[OFFSET4D(i / 16, + j / 16, + i % 16, + j % 16, + kernel_j * weight_shape[0], + kernel_j * kernel_k, + kernel_k)] = + temp_data[OFFSET2D(i / 8 * 8 + i % 4 * 2 + j % 16 / 8, + j / 16 * 16 + i % 8 / 4 * 8 + j % 8, + weight_shape[0])]; + } + } + } + else{ + NNFUSION_LOG(ERROR) << "weight_node's element_type is not f16"; + } + } + else if (it->get_op_type() == "BatchMatMul"){ + NNFUSION_LOG(INFO) << "Find a BatchMatMul node: " << it->get_id(); + // get the input nodes + auto input_node = it->get_in_edge(0)->get_src(); + auto weight_node = it->get_in_edge(1)->get_src(); + NNFUSION_LOG(INFO) << "Input node: " << input_node->get_id(); + NNFUSION_LOG(INFO) << "Input Type: " << input_node->get_unique_name(); + NNFUSION_LOG(INFO) << "Weight node: " << weight_node->get_id(); + // get node's attr + auto generic_op = std::dynamic_pointer_cast(it->get_op_ptr()); + bool trans_A = 
generic_op->localOpConfig.getRoot()["adj_x"]["b"]; + bool trans_B = generic_op->localOpConfig.getRoot()["adj_y"]["b"]; + + // Currently do not support constant weight + NNFUSION_CHECK(weight_node->get_op_type() != "Constant") << "Constant weight is not supported for now"; + // Insert permutate node before BatchMatMul's input node and weight node + // Insert permutate node before BatchMatMul's input node + NNFUSION_LOG(INFO) << "Insert permutate node before BatchMatMul's input node"; + { + // permutate input + nnfusion::op::OpConfig::any permutateConfig; + permutateConfig["type"] = trans_A? 1 : 0; + permutateConfig["inner_i"] = kernel_i; + permutateConfig["inner_j"] = kernel_k; + auto generic_op = std::make_shared( + "BatchPermutate", "BatchPermutate", permutateConfig); + auto permutate_node = graph->add_node_and_edge(generic_op, {input_node}); + auto edge = it->get_in_edge(0); + graph->remove_edge(edge); + graph->add_edge(permutate_node, 0, it, 0); + } + { + // permutate weight + nnfusion::op::OpConfig::any permutateConfig; + permutateConfig["type"] = trans_B ? 1 : 0; + permutateConfig["inner_i"] = kernel_i; + permutateConfig["inner_j"] = kernel_k; + auto generic_op = std::make_shared( + "BatchPermutate", "BatchPermutate", permutateConfig); + auto permutate_node = graph->add_node_and_edge(generic_op, {weight_node}); + auto edge = it->get_in_edge(1); + graph->remove_edge(edge); + graph->add_edge(permutate_node, 0, it, 1); + } + // replace BatchMatMul with LayoutBMM + NNFUSION_LOG(INFO) << "Replace BatchMatMul with LayoutBMM"; + { + nnfusion::op::OpConfig::any layoutBMMConfig; + layoutBMMConfig["output_type"] = 0; + layoutBMMConfig["inner_i"] = kernel_i; + layoutBMMConfig["inner_j"] = kernel_j; + auto generic_op = std::make_shared( + "LayoutBMM", "LayoutBMM", layoutBMMConfig); + NNFUSION_LOG(INFO) << "Create LayoutBMM node"; + auto layoutBMM_node = graph->add_node_and_edge( + generic_op, {it->get_in_edge(0)->get_src(), it->get_in_edge(1)->get_src()}); + for (auto& edge : it->get_out_edges()) + { + auto dst_node = edge->get_dst(); + auto dst_input = edge->get_dst_input(); + graph->remove_edge(edge); + graph->add_edge(layoutBMM_node, 0, dst_node, dst_input); + } + graph->remove_node(it); + } + } + } +#undef OFFSET2D +#undef OFFSET4D + return true; +} diff --git a/src/nnfusion/engine/pass/graph/dot_permutation_pass.hpp b/src/nnfusion/engine/pass/graph/dot_permutation_pass.hpp new file mode 100644 index 000000000..934c9c551 --- /dev/null +++ b/src/nnfusion/engine/pass/graph/dot_permutation_pass.hpp @@ -0,0 +1,21 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
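+//
+// DotPermutationPass (off by default; enabled with -fdot_permutation)
+// rewrites eligible Dot nodes into Permutate + LayoutDot pairs and
+// BatchMatMul nodes into BatchPermutate + LayoutBMM pairs, so that
+// operands reach the matmul kernels in a 16x16-tiled, tensor-core
+// friendly layout; constant fp16 Dot weights are re-laid-out in place.
+// Op types named in the comma-separated -fpermutate_skiplist flag are
+// left untouched.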
+ +#pragma once + +#include "graph_pass_base.hpp" + +namespace nnfusion +{ + namespace pass + { + namespace graph + { + class DotPermutationPass : public GraphPassBase + { + public: + bool run_on_graph(std::shared_ptr& graph) override; + }; + } + } +} From b03e0a97a4d011e0c70724d558fec71fe640cb01 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sat, 15 Apr 2023 23:54:15 -0800 Subject: [PATCH 6/7] support layout of layoutdot --- .../generic_op_define/memfusion_new_ops.cpp | 34 +++++++++++++++---- .../pass/graph/dot_permutation_pass.cpp | 3 ++ 2 files changed, 31 insertions(+), 6 deletions(-) diff --git a/src/nnfusion/core/operators/generic_op/generic_op_define/memfusion_new_ops.cpp b/src/nnfusion/core/operators/generic_op/generic_op_define/memfusion_new_ops.cpp index fbeffa799..b49af6dcb 100644 --- a/src/nnfusion/core/operators/generic_op/generic_op_define/memfusion_new_ops.cpp +++ b/src/nnfusion/core/operators/generic_op/generic_op_define/memfusion_new_ops.cpp @@ -504,12 +504,19 @@ REGISTER_OP(BatchPermutate) }); REGISTER_OP(LayoutDot) + .attr("transpose_A") + .attr("transpose_B") .attr("inner_i") .attr("inner_j") .attr("output_layout") .infershape( [](std::shared_ptr gnode) -> void { + auto generic_op = + std::dynamic_pointer_cast(gnode->get_op_ptr()); + bool trans_a = generic_op->localOpConfig.getRoot()["transpose_A"]; + bool trans_b = generic_op->localOpConfig.getRoot()["transpose_B"]; + //TODO(leiwang1999):currently only support for NT Layout NNFUSION_CHECK(2 == gnode->get_input_size()); // input 0 shape is B, S, K, input 1 is K, N @@ -520,20 +527,35 @@ REGISTER_OP(LayoutDot) input1_shape.size() == 2); if (input0_shape.size() == 2) { - nnfusion::Shape output_shape{input0_shape[0], input1_shape[1]}; + nnfusion::Shape output_shape{trans_a ? input0_shape[1]: input0_shape[0], + trans_b ? input1_shape[0]: input1_shape[1] }; gnode->set_output_type_and_shape(0, gnode->get_input_element_type(0), output_shape); } - else + else if (input0_shape.size() == 3) { - nnfusion::Shape output_shape{input0_shape[0], input0_shape[1], input1_shape[1]}; + nnfusion::Shape output_shape{input0_shape[0], + trans_a ? input0_shape[2] : input0_shape[1], + trans_b ? input1_shape[0] : input1_shape[1] + }; gnode->set_output_type_and_shape(0, gnode->get_input_element_type(0), output_shape); } + // print trans_a and trans_b + NNFUSION_LOG(INFO) << "transa, b is " << trans_a << " " << trans_b; + // print input0 shape and input1 shape + NNFUSION_LOG(INFO) << "input0 shape is " << gnode->get_input_shape(0); + NNFUSION_LOG(INFO) << "input1 shape is " << gnode->get_input_shape(1); + NNFUSION_LOG(INFO) << "output shape is " << gnode->get_output_shape(0); }) .translate_v2( [](std::shared_ptr curr) -> std::string { + // todo(leiwang1999): apply correct experession. auto generic_op = std::dynamic_pointer_cast(curr->get_op_ptr()); + int output_layout = generic_op->localOpConfig.getRoot()["output_layout"]; + bool trans_a = generic_op->localOpConfig.getRoot()["transpose_A"]; + bool trans_b = generic_op->localOpConfig.getRoot()["transpose_B"]; + string fuse_template = R"( temp0@A_fused_layout@ +=! @input0@@A_layout@ where M in @M@;)"; string compute_template = @@ -541,9 +563,9 @@ REGISTER_OP(LayoutDot) string ir_template = fuse_template + compute_template; op::OpConfig::any op_config; op_config["M"] = 16384; - op_config["A_fused_layout"] = "[M, K]"; - op_config["B_layout"] = "[N, K]"; - int output_layout = generic_op->localOpConfig.getRoot()["output_layout"]; + op_config["A_fused_layout"] = trans_a? 
"[K, M]" : "[M, K]"; + op_config["B_layout"] = trans_b? "[N, K]" : "[K, N]"; + auto A_shape = curr->get_input_shape(0); int raxis = A_shape.size() - 1; string A_layout; diff --git a/src/nnfusion/engine/pass/graph/dot_permutation_pass.cpp b/src/nnfusion/engine/pass/graph/dot_permutation_pass.cpp index 8b4d79c3a..f1f159708 100644 --- a/src/nnfusion/engine/pass/graph/dot_permutation_pass.cpp +++ b/src/nnfusion/engine/pass/graph/dot_permutation_pass.cpp @@ -65,6 +65,7 @@ bool DotPermutationPass::run_on_graph(std::shared_ptr& g auto it_op = static_pointer_cast(it->get_op_ptr()); auto trans_a = it_op->get_transpose_A(); auto trans_b = it_op->get_transpose_B(); + NNFUSION_LOG(INFO) << "trans_a " << trans_a << " trans_b " << trans_b; // get the input nodes auto input_node = it->get_in_edge(0)->get_src(); auto weight_node = it->get_in_edge(1)->get_src(); @@ -92,6 +93,8 @@ bool DotPermutationPass::run_on_graph(std::shared_ptr& g layoutDotConfig["output_type"] = 0; layoutDotConfig["inner_i"] = kernel_i; layoutDotConfig["inner_j"] = kernel_j; + layoutDotConfig["transpose_A"] = trans_a; + layoutDotConfig["transpose_B"] = trans_b; auto layoutDot_op = std::make_shared( "LayoutDot", "LayoutDot", layoutDotConfig); NNFUSION_LOG(INFO) << "Create layoutDot node"; From bcbe7d0e8766abbd7fe80ee665ef5167a42ce813 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 29 May 2023 22:05:52 -0800 Subject: [PATCH 7/7] lowbit update --- CMakeLists.txt | 2 +- model.json | 840 ++++++++++++++++++ .../core/kernels/cuda_gpu/cuda_langunit.cpp | 5 + .../generic_op_define/GatherElements.cpp | 59 ++ .../generic_op_define/memfusion_new_ops.cpp | 86 ++ .../frontend/onnx_import/CMakeLists.txt | 2 + .../frontend/onnx_import/op/gather.cpp | 21 + .../frontend/onnx_import/op/gather.hpp | 8 +- .../frontend/onnx_import/op/quant_linear.cpp | 63 ++ .../frontend/onnx_import/op/quant_linear.hpp | 65 ++ .../frontend/onnx_import/op/reshape.cpp | 5 +- .../frontend/onnx_import/op/transpose.cpp | 7 +- .../frontend/onnx_import/ops_bridge.cpp | 9 +- .../onnx_import/util/graph_convert.cpp | 13 +- 14 files changed, 1179 insertions(+), 6 deletions(-) create mode 100644 model.json create mode 100644 src/nnfusion/core/operators/generic_op/generic_op_define/GatherElements.cpp create mode 100644 src/nnfusion/frontend/onnx_import/op/quant_linear.cpp create mode 100644 src/nnfusion/frontend/onnx_import/op/quant_linear.hpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 5a91476b7..2ef825e10 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -62,7 +62,7 @@ cmake_minimum_required (VERSION 3.10) set(CMAKE_EXPORT_COMPILE_COMMANDS ON) # STEP.3 Set compiler flags -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++14") if (${WARNINGS_AS_ERRORS}) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror") diff --git a/model.json b/model.json new file mode 100644 index 000000000..2cf935cb2 --- /dev/null +++ b/model.json @@ -0,0 +1,840 @@ +[ + [ + 204, + " - einstein_v2(\" output0[N0, N1, N2] = input0[input1[N0, N1].when(input1[N0, N1] >= 0, input1[N0, N1] + const(32000).cast(input1[N0, N1].dtype())), N2]; output1[N0, N1, N2] = output0[N0, N1, N2].cast(`float32`);\", input_dict={ \"input0\" : { \"dtype\" : \"float16\", \"shape\" : [32000, 8192]} , \"input1\" : { \"dtype\" : \"int64\", \"shape\" : [16, 512]} }, extra_outputs=[\"output0\", \"output1\"]) ## @: ", + "GatherV2_Convert", + [ + [ + 0, + 0 + ], + [ + 33, + 0 + ] + ] + ], + [ + 46, + " - einstein_v2(\" output0[N0] = input0[0] where N0 in 1; \", input_dict={ 
\"input0\" : { \"dtype\" : \"float32\", \"shape\" : [1]} }) ## @: memcpy ", + "Reshape", + [ + [ + 5, + 0 + ] + ] + ], + [ + 199, + " - einstein_v2(\" mediate0[N0, N1, N2] = input0[0] where N0 in 16, N1 in 512, N2 in 8192; output0[N0, N1, N2] = input1[N0, N1, N2].call(`pow`, [mediate0[N0, N1, N2]]);\", input_dict={ \"input0\" : { \"dtype\" : \"float32\", \"shape\" : [1]} , \"input1\" : { \"dtype\" : \"float32\", \"shape\" : [16, 512, 8192]} }) ## @: ", + "Broadcast_Power", + [ + [ + 19, + 0 + ], + [ + 204, + 1 + ] + ] + ], + [ + 200, + " - einstein_v2(\" mediate0[N0, N1] = input0[0] where N0 in 16, N1 in 512; mediate1[N0, N1] +=! input1[N0, N1, N2];output0[N0, N1] = mediate1[N0, N1] / mediate0[N0, N1];\", input_dict={ \"input0\" : { \"dtype\" : \"float32\", \"shape\" : [1]} , \"input1\" : { \"dtype\" : \"float32\", \"shape\" : [16, 512, 8192]} }) ## @: ", + "Sum_Broadcast_Divide", + [ + [ + 42, + 0 + ], + [ + 199, + 0 + ] + ] + ], + [ + 205, + " - einstein_v2(\" mediate0[N0, N1, N2] = input0[N2] where N0 in 16, N1 in 512; mediate1[N0, N1, N2] = input1[N2] where N0 in 16, N1 in 512; mediate2[N0, N1, N2] = input2[N0, N1] where N2 in 1; mediate3[N0, N1, N2] = mediate2[N0, N1, N2] + mediate1[N0, N1, N2]; mediate4[N0, N1, N2] = mediate3[N0, N1, N2].call(`sqrt`); mediate5[N0, N1] = mediate4[N0, N1, 0] ; mediate6[N0, N1, N2] = mediate5[N0, N1] where N2 in 8192; mediate7[N0, N1, N2] = input3[N0, N1, N2] / mediate6[N0, N1, N2];mediate8[N0, N1, N2] = mediate7[N0, N1, N2].cast(`float16`);output0[N0, N1, N2] = mediate0[N0, N1, N2] * mediate8[N0, N1, N2];\", input_dict={ \"input0\" : { \"dtype\" : \"float16\", \"shape\" : [8192]} , \"input1\" : { \"dtype\" : \"float32\", \"shape\" : [1]} , \"input2\" : { \"dtype\" : \"float32\", \"shape\" : [16, 512]} , \"input3\" : { \"dtype\" : \"float32\", \"shape\" : [16, 512, 8192]} }) ## @: ", + "Reshape_Broadcast_Add_Sqrt_Reshape_Broadcast_Divide_Convert_Broadcast_Multiply", + [ + [ + 1, + 0 + ], + [ + 46, + 0 + ], + [ + 200, + 0 + ], + [ + 204, + 1 + ] + ] + ], + [ + 177, + " - einstein_v2(\" mediate0[N0, N1] = input0[N0 // 512 , N0 % 512, N1] where N0 in 8192;output0[N0, N1, N2, N3] = mediate0[(N0 * 16 + N2) // 16 * 16 + (N0 * 16 + N2) % 8 * 2 + (N1 * 16 + N3) % 16 // 8, (N1 * 16 + N3) // 16 * 16 + (N0 * 16 + N2) % 16 // 8 * 8 + (N1 * 16 + N3) % 8] where N0 in 512, N1 in 512, N2 in 16, N3 in 16;\", input_dict={ \"input0\" : { \"dtype\" : \"float16\", \"shape\" : [16, 512, 8192]} }) ", + "Permutate", + [ + [ + 205, + 0 + ] + ] + ], + [ + 178, + " - einstein_v2(\" temp0[M, K] +=! input0[M//1024%16, M//2%512, K] where M in 16384; output0[M, N] +=! 
temp0[M, K] * input1[K, N]; \", input_dict={ \"input0\" : { \"dtype\" : \"float16\", \"shape\" : [16, 512, 8192]} , \"input1\" : { \"dtype\" : \"float16\", \"shape\" : [8192, 8192]} }) ## @: output_layout=0 |skip", + "LayoutDot", + [ + [ + 177, + 0 + ], + [ + 13, + 0 + ] + ] + ], + [ + 203, + " - einstein_v2(\" mediate0[N0, N1, N2, N3] = input0[N0, N1, ((N2) * 128 + N3)] where N2 in 64, N3 in 128; output0[N0, N2, N1, N3] = mediate0[N0, N1, N2, N3] ; mediate1[N0, N1, N2, N3] = output0[N0, N1, N2, N3] ; output1[N0, N1, N2, N3] = mediate1[N0, N1, N2, N3] ; \", input_dict={ \"input0\" : { \"dtype\" : \"float16\", \"shape\" : [16, 512, 8192]} }, extra_outputs=[\"output0\", \"output1\"]) ## @: ", + "Reshape_Reshape_Reshape_Broadcast", + [ + [ + 178, + 0 + ] + ] + ], + [ + 168, + "", + "Result", + [ + [ + 203, + 0 + ] + ] + ], + [ + 65, + " - einstein_v2(\" output0[N0, N1, N2, N3] = input0[N0 % 1, N1 % 1, N2 % 512, N3 % 128] where N0 in 1, N1 in 1, N2 in 512, N3 in 128; \", input_dict={ \"input0\" : { \"dtype\" : \"float16\", \"shape\" : [1, 1, 512, 128]} }) ", + "Tile", + [ + [ + 15, + 0 + ], + [ + 18, + 0 + ] + ] + ], + [ + 66, + " - einstein_v2(\" output0[N0, N1, N2, N3] = input0[N0, N1, input1[N0, N1, N2, N3].when(input1[N0, N1, N2, N3] >= 0, input1[N0, N1, N2, N3] + const(512).cast(input1[N0, N1, N2, N3].dtype())), N3]; \", input_dict={ \"input0\" : { \"dtype\" : \"float16\", \"shape\" : [1, 1, 512, 128]} , \"input1\" : { \"dtype\" : \"int64\", \"shape\" : [1, 1, 512, 128]} }) ", + "GatherElements", + [ + [ + 65, + 0 + ], + [ + 28, + 0 + ] + ] + ], + [ + 173, + " - einstein_v2(\" mediate0[N0, N1] = input0[N0 // 512 , N0 % 512, N1] where N0 in 8192;output0[N0, N1, N2, N3] = mediate0[(N0 * 16 + N2) // 16 * 16 + (N0 * 16 + N2) % 8 * 2 + (N1 * 16 + N3) % 16 // 8, (N1 * 16 + N3) // 16 * 16 + (N0 * 16 + N2) % 16 // 8 * 8 + (N1 * 16 + N3) % 8] where N0 in 512, N1 in 512, N2 in 16, N3 in 16;\", input_dict={ \"input0\" : { \"dtype\" : \"float16\", \"shape\" : [16, 512, 8192]} }) ", + "Permutate", + [ + [ + 205, + 0 + ] + ] + ], + [ + 174, + " - einstein_v2(\" temp0[M, K] +=! input0[M//1024%16, M//2%512, K] where M in 16384; output0[M, N] +=! 
temp0[M, K] * input1[K, N]; \", input_dict={ \"input0\" : { \"dtype\" : \"float16\", \"shape\" : [16, 512, 8192]} , \"input1\" : { \"dtype\" : \"float16\", \"shape\" : [8192, 8192]} }) ## @: output_layout=0 |skip", + "LayoutDot", + [ + [ + 173, + 0 + ], + [ + 11, + 0 + ] + ] + ], + [ + 196, + " - einstein_v2(\" mediate0[N0, N2, N1, N3] = input0[N0, N1, N2, N3] ; mediate1[N0, N1] = mediate0[0, N0, 0, N1] ; mediate2[N0, N1, N2, N3] = mediate1[N1, N3] where N0 in 16, N2 in 64; output0[N0, N1, N2, N3] = input1[N0, N1, ((N2) * 128 + N3)] where N2 in 64, N3 in 128; output1[N0, N1, N2, N3] = output0[N0, N1, N2, N3] * mediate2[N0, N1, N2, N3];\", input_dict={ \"input0\" : { \"dtype\" : \"float16\", \"shape\" : [1, 1, 512, 128]} , \"input1\" : { \"dtype\" : \"float16\", \"shape\" : [16, 512, 8192]} }, extra_outputs=[\"output0\", \"output1\"]) ## @: ", + "Reshape_Reshape_Reshape_Broadcast_Multiply", + [ + [ + 66, + 0 + ], + [ + 174, + 0 + ] + ] + ], + [ + 34, + " - einstein_v2(\" output0[N0, N1, N2, N3] = input0[N0 % 1, N1 % 1, N2 % 512, N3 % 128] where N0 in 1, N1 in 1, N2 in 512, N3 in 128; \", input_dict={ \"input0\" : { \"dtype\" : \"float16\", \"shape\" : [1, 1, 512, 128]} }) ", + "Tile", + [ + [ + 4, + 0 + ], + [ + 20, + 0 + ] + ] + ], + [ + 35, + " - einstein_v2(\" output0[N0, N1, N2, N3] = input0[N0, N1, input1[N0, N1, N2, N3].when(input1[N0, N1, N2, N3] >= 0, input1[N0, N1, N2, N3] + const(512).cast(input1[N0, N1, N2, N3].dtype())), N3]; \", input_dict={ \"input0\" : { \"dtype\" : \"float16\", \"shape\" : [1, 1, 512, 128]} , \"input1\" : { \"dtype\" : \"int64\", \"shape\" : [1, 1, 512, 128]} }) ", + "GatherElements", + [ + [ + 34, + 0 + ], + [ + 28, + 0 + ] + ] + ], + [ + 60, + " - einstein_v2(\" output0[N0, N1, N2, N3] = input0[N0 + 0, N1 + 0, N2 + 0, N3 + 0] where N0 in 16 , N1 in 512 , N2 in 64 , N3 in 64; \", input_dict={ \"input0\" : { \"dtype\" : \"float16\", \"shape\" : [16, 512, 64, 128]} }) ", + "Slice", + [ + [ + 196, + 0 + ] + ] + ], + [ + 201, + " - einstein_v2(\" mediate0[N0, N1, N2, N3] = input0[N0 + 0, N1 + 0, N2 + 0, N3 + 64] where N0 in 16 , N1 in 512 , N2 in 64 , N3 in 64; output0[N0, N1, N2, N3] = -mediate0[N0, N1, N2, N3];\", input_dict={ \"input0\" : { \"dtype\" : \"float16\", \"shape\" : [16, 512, 64, 128]} }) ## @: ", + "Slice_Negative", + [ + [ + 196, + 0 + ] + ] + ], + [ + 208, + " - einstein_v2(\" mediate0[N0, N2, N1, N3] = input0[N0, N1, N2, N3] ; mediate1[N0, N1] = mediate0[0, N0, 0, N1] ; mediate2[N0, N1, N2, N3] = mediate1[N1, N3] where N0 in 16, N2 in 64; mediate3[N0, N1, N2, N3] = input1[N0, N1, N2, N3 - 0].when(N3 < 64, input2[N0, N1, N2, N3 - 64]) where N3 in 128; output0[N0, N1, N2, N3] = mediate3[N0, N1, N2, N3] * mediate2[N0, N1, N2, N3];\", input_dict={ \"input0\" : { \"dtype\" : \"float16\", \"shape\" : [1, 1, 512, 128]} , \"input1\" : { \"dtype\" : \"float16\", \"shape\" : [16, 512, 64, 64]} , \"input2\" : { \"dtype\" : \"float16\", \"shape\" : [16, 512, 64, 64]} }) ## @: ", + "Reshape_Concat_Reshape_Broadcast_Multiply", + [ + [ + 35, + 0 + ], + [ + 201, + 0 + ], + [ + 60, + 0 + ] + ] + ], + [ + 71, + " - einstein_v2(\" output0[N0, N1, N2, N3] = input0[N0, N1, N2, N3] + input1[N0, N1, N2, N3]; \", input_dict={ \"input0\" : { \"dtype\" : \"float16\", \"shape\" : [16, 512, 64, 128]} , \"input1\" : { \"dtype\" : \"float16\", \"shape\" : [16, 512, 64, 128]} }) ", + "Add", + [ + [ + 196, + 1 + ], + [ + 208, + 0 + ] + ] + ], + [ + 72, + " - einstein_v2(\" output0[N0, N2, N1, N3] = input0[N0, N1, N2, N3] ; \", input_dict={ \"input0\" : { \"dtype\" : 
\"float16\", \"shape\" : [16, 512, 64, 128]} }) ", + "Reshape", + [ + [ + 71, + 0 + ] + ] + ], + [ + 167, + "", + "Result", + [ + [ + 72, + 0 + ] + ] + ], + [ + 107, + " - einstein_v2(\" output0[N0, N1, N2, N3] = input0[0] where N0 in 16, N1 in 64, N2 in 512, N3 in 512; \", input_dict={ \"input0\" : { \"dtype\" : \"float16\", \"shape\" : [1]} }) ", + "Broadcast", + [ + [ + 31, + 0 + ] + ] + ], + [ + 211, + " - einstein_v2(\" mediate0[N0] = input0[0] where N0 in 1; mediate1[N0, N1, N2, N3] = mediate0[N1] where N0 in 16, N2 in 512, N3 in 512; output0[N0, N1, N2, N3] = mediate1[N0, N1, N2, N3] - input1[N0, N1, N2, N3];output1[N0, N1, N2, N3] = (output0[N0, N1, N2, N3] != 0).cast(`int16`);\", input_dict={ \"input0\" : { \"dtype\" : \"float16\", \"shape\" : [1]} , \"input1\" : { \"dtype\" : \"float16\", \"shape\" : [16, 1, 512, 512]} }, extra_outputs=[\"output0\", \"output1\"]) ## @: ", + "Reshape_Broadcast_Subtract_Convert", + [ + [ + 30, + 0 + ], + [ + 16, + 0 + ] + ] + ], + [ + 210, + " - einstein_v2(\" mediate0[N0] = input0[0] where N0 in 1; output0[N0, N1, N2, N3] = mediate0[N1] where N0 in 16, N2 in 512, N3 in 512; \", input_dict={ \"input0\" : { \"dtype\" : \"float16\", \"shape\" : [1]} }) ## @: ", + "Reshape_Broadcast", + [ + [ + 31, + 0 + ] + ] + ], + [ + 202, + " - einstein_v2(\"mediate0[N0, N1, N2, N3] = input2[N0, N1, N2, N3].when([input0[N0, N1, N2, N3] == 0], input1[N0, N1, N2, N3]); output0[N0, N1, N2, N3] = mediate0[N0, N1, N2, N3] + input3[N0, N1, N2, N3]; \", input_dict={ \"input0\" : { \"dtype\" : \"int16\", \"shape\" : [16, 1, 512, 512]} , \"input1\" : { \"dtype\" : \"float16\", \"shape\" : [16, 1, 512, 512]} , \"input2\" : { \"dtype\" : \"float16\", \"shape\" : [16, 1, 512, 512]} , \"input3\" : { \"dtype\" : \"float16\", \"shape\" : [16, 1, 512, 512]} }) ## @: ", + "Select_Add", + [ + [ + 211, + 1 + ], + [ + 210, + 0 + ], + [ + 211, + 0 + ], + [ + 17, + 0 + ] + ] + ], + [ + 175, + " - einstein_v2(\" mediate0[N0, N1] = input0[N0 // 512 , N0 % 512, N1] where N0 in 8192;output0[N0, N1, N2, N3] = mediate0[(N0 * 16 + N2) // 16 * 16 + (N0 * 16 + N2) % 8 * 2 + (N1 * 16 + N3) % 16 // 8, (N1 * 16 + N3) // 16 * 16 + (N0 * 16 + N2) % 16 // 8 * 8 + (N1 * 16 + N3) % 8] where N0 in 512, N1 in 512, N2 in 16, N3 in 16;\", input_dict={ \"input0\" : { \"dtype\" : \"float16\", \"shape\" : [16, 512, 8192]} }) ", + "Permutate", + [ + [ + 205, + 0 + ] + ] + ], + [ + 176, + " - einstein_v2(\" temp0[M, K] +=! input0[M//1024%16, M//2%512, K] where M in 16384; output0[M, N] +=! 
temp0[M, K] * input1[K, N]; \", input_dict={ \"input0\" : { \"dtype\" : \"float16\", \"shape\" : [16, 512, 8192]} , \"input1\" : { \"dtype\" : \"float16\", \"shape\" : [8192, 8192]} }) ## @: output_layout=0 |skip", + "LayoutDot", + [ + [ + 175, + 0 + ], + [ + 9, + 0 + ] + ] + ], + [ + 193, + " - einstein_v2(\" mediate0[N0, N1, N2, N3] = input0[N0, N1, ((N2) * 128 + N3)] where N2 in 64, N3 in 128; output0[N0, N2, N1, N3] = mediate0[N0, N1, N2, N3] ; mediate1[N0, N1] = input1[0, 0, N0, N1] ; mediate2[N0, N1, N2, N3] = mediate1[N2, N3] where N0 in 16, N1 in 64; output1[N0, N1, N2, N3] = output0[N0, N1, N2, N3] * mediate2[N0, N1, N2, N3];\", input_dict={ \"input0\" : { \"dtype\" : \"float16\", \"shape\" : [16, 512, 8192]} , \"input1\" : { \"dtype\" : \"float16\", \"shape\" : [1, 1, 512, 128]} }, extra_outputs=[\"output0\", \"output1\"]) ## @: ", + "Reshape_Reshape_Reshape_Broadcast_Multiply", + [ + [ + 176, + 0 + ], + [ + 66, + 0 + ] + ] + ], + [ + 79, + " - einstein_v2(\" output0[N0, N1, N2, N3] = input0[N0 + 0, N1 + 0, N2 + 0, N3 + 0] where N0 in 16 , N1 in 64 , N2 in 512 , N3 in 64; \", input_dict={ \"input0\" : { \"dtype\" : \"float16\", \"shape\" : [16, 64, 512, 128]} }) ", + "Slice", + [ + [ + 193, + 0 + ] + ] + ], + [ + 198, + " - einstein_v2(\" mediate0[N0, N1, N2, N3] = input0[N0 + 0, N1 + 0, N2 + 0, N3 + 64] where N0 in 16 , N1 in 64 , N2 in 512 , N3 in 64; output0[N0, N1, N2, N3] = -mediate0[N0, N1, N2, N3];\", input_dict={ \"input0\" : { \"dtype\" : \"float16\", \"shape\" : [16, 64, 512, 128]} }) ## @: ", + "Slice_Negative", + [ + [ + 193, + 0 + ] + ] + ], + [ + 192, + " - einstein_v2(\" mediate0[N0, N1] = input0[0, 0, N0, N1] ; mediate1[N0, N1, N2, N3] = mediate0[N2, N3] where N0 in 16, N1 in 64; mediate2[N0, N1, N2, N3] = input1[N0, N1, N2, N3 - 0].when(N3 < 64, input2[N0, N1, N2, N3 - 64]) where N3 in 128; output0[N0, N1, N2, N3] = mediate2[N0, N1, N2, N3] * mediate1[N0, N1, N2, N3];\", input_dict={ \"input0\" : { \"dtype\" : \"float16\", \"shape\" : [1, 1, 512, 128]} , \"input1\" : { \"dtype\" : \"float16\", \"shape\" : [16, 64, 512, 64]} , \"input2\" : { \"dtype\" : \"float16\", \"shape\" : [16, 64, 512, 64]} }) ## @: ", + "Concat_Reshape_Broadcast_Multiply", + [ + [ + 35, + 0 + ], + [ + 198, + 0 + ], + [ + 79, + 0 + ] + ] + ], + [ + 197, + " - einstein_v2(\" mediate0[N0, N1, N2, N3] = input0[N0, N1, N2, N3] + input1[N0, N1, N2, N3]; mediate1[N0, N1, N2, N3] = mediate0[N0, N1, N2, N3] ; output0[N0, N1, N2, N3] = mediate1[N0, N1, N2, N3] ; \", input_dict={ \"input0\" : { \"dtype\" : \"float16\", \"shape\" : [16, 64, 512, 128]} , \"input1\" : { \"dtype\" : \"float16\", \"shape\" : [16, 64, 512, 128]} }) ## @: ", + "Add_Reshape_Broadcast", + [ + [ + 193, + 1 + ], + [ + 192, + 0 + ] + ] + ], + [ + 212, + " - einstein_v2(\" mediate0[N0, N2, N3, N1] = input0[N0, N1, N2, N3] ; mediate1[N0, N1, N2, N3] = mediate0[N0, N1, N2, N3] ; output0[N0, N1, N2, N3] = mediate1[N0, N1, N2, N3] ; \", input_dict={ \"input0\" : { \"dtype\" : \"float16\", \"shape\" : [16, 512, 64, 128]} }) ## @: ", + "Reshape_Reshape_Broadcast", + [ + [ + 71, + 0 + ] + ] + ], + [ + 92, + " - einstein_v2(\" output0[B0, B1, N, M] +=! 
input0[B0, B1, N, K] * input1[B0, B1, K, M]; \", input_dict={ \"input0\" : { \"dtype\" : \"float16\", \"shape\" : [16, 64, 512, 128]} , \"input1\" : { \"dtype\" : \"float16\", \"shape\" : [16, 64, 128, 512]} }) ## @: tensorCoreConfig=(2, 3) |skip", + "BatchMatMul", + [ + [ + 197, + 0 + ], + [ + 212, + 0 + ] + ] + ], + [ + 195, + " - einstein_v2(\" mediate0[N0, N1, N2] = input0[N0, 0, N1, N2] ; mediate1[N0, N1, N2, N3] = mediate0[N0, N2, N3] where N1 in 64; mediate2[N0, N1, N2, N3] = input1[0] where N0 in 16, N1 in 64, N2 in 512, N3 in 512; mediate3[N0, N1, N2, N3] = input2[N0, N1, N2, N3] ; mediate4[N0, N1, N2, N3] = mediate3[N0, N1, N2, N3] / mediate2[N0, N1, N2, N3]; mediate5[N0, N1, N2, N3] = mediate4[N0, N1, N2, N3] + mediate1[N0, N1, N2, N3]; output0[N0, N1, N2, N3] = mediate5[N0, N1, N2, N3].call(`max`, [input3[N0, N1, N2, N3]]);\", input_dict={ \"input0\" : { \"dtype\" : \"float16\", \"shape\" : [16, 1, 512, 512]} , \"input1\" : { \"dtype\" : \"float16\", \"shape\" : [1]} , \"input2\" : { \"dtype\" : \"float16\", \"shape\" : [16, 64, 512, 512]} , \"input3\" : { \"dtype\" : \"float16\", \"shape\" : [16, 64, 512, 512]} }) ## @: ", + "Reshape_Broadcast_Divide_Reshape_Broadcast_Add_Maximum", + [ + [ + 202, + 0 + ], + [ + 21, + 0 + ], + [ + 92, + 0 + ], + [ + 107, + 0 + ] + ] + ], + [ + 169, + " - einstein_v2(\" output0[N0, N1, N2] >=! input0[N0, N1, N2, N3]; \", input_dict={ \"input0\" : { \"dtype\" : \"float16\", \"shape\" : [16, 64, 512, 512]} }) ", + "SoftmaxBasic", + [ + [ + 195, + 0 + ] + ] + ], + [ + 170, + " - einstein_v2(\" output0[N0, N1, N2, N3] = (input0[N0, N1, N2, N3] - input1[N0, N1, N2]).call(`exp`); \", input_dict={ \"input0\" : { \"dtype\" : \"float16\", \"shape\" : [16, 64, 512, 512]} , \"input1\" : { \"dtype\" : \"float16\", \"shape\" : [16, 64, 512]} }) ", + "SoftmaxBasic", + [ + [ + 195, + 0 + ], + [ + 169, + 0 + ] + ] + ], + [ + 171, + " - einstein_v2(\" output0[N0, N1, N2] +=! input0[N0, N1, N2, N3]; \", input_dict={ \"input0\" : { \"dtype\" : \"float16\", \"shape\" : [16, 64, 512, 512]} }) ", + "SoftmaxBasic", + [ + [ + 170, + 0 + ] + ] + ], + [ + 191, + " - einstein_v2(\" mediate0[N0, N1, N2, N3] = input0[N0, N1, N2, N3] / input1[N0, N1, N2]; mediate1[N0, N1, N2, N3] = mediate0[N0, N1, N2, N3] ; output0[N0, N1, N2, N3] = mediate1[N0, N1, N2, N3] ; \", input_dict={ \"input0\" : { \"dtype\" : \"float16\", \"shape\" : [16, 64, 512, 512]} , \"input1\" : { \"dtype\" : \"float16\", \"shape\" : [16, 64, 512]} }) ## @: ", + "Reshape_Broadcast_SoftmaxBasic", + [ + [ + 170, + 0 + ], + [ + 171, + 0 + ] + ] + ], + [ + 117, + " - einstein_v2(\" output0[B0, B1, N, M] +=! 
input0[B0, B1, N, K] * input1[B0, B1, K, M]; \", input_dict={ \"input0\" : { \"dtype\" : \"float16\", \"shape\" : [16, 64, 512, 512]} , \"input1\" : { \"dtype\" : \"float16\", \"shape\" : [16, 64, 512, 128]} }) ## @: tensorCoreConfig=(2, 3) |skip", + "BatchMatMul", + [ + [ + 191, + 0 + ], + [ + 203, + 1 + ] + ] + ], + [ + 194, + " - einstein_v2(\" mediate0[N0, N1, N2, N3] = input0[N0, N1, N2, N3] ; mediate1[N0, N2, N1, N3] = mediate0[N0, N1, N2, N3] ; output0[N0, N1, N2] = mediate1[N0, N1, N2 / 128 % 64, N2 / 1 % 128] where N2 in 8192; \", input_dict={ \"input0\" : { \"dtype\" : \"float16\", \"shape\" : [16, 64, 512, 128]} }) ## @: ", + "Reshape_Reshape_Reshape", + [ + [ + 117, + 0 + ] + ] + ], + [ + 179, + " - einstein_v2(\" mediate0[N0, N1] = input0[N0 // 512 , N0 % 512, N1] where N0 in 8192;output0[N0, N1, N2, N3] = mediate0[(N0 * 16 + N2) // 16 * 16 + (N0 * 16 + N2) % 8 * 2 + (N1 * 16 + N3) % 16 // 8, (N1 * 16 + N3) // 16 * 16 + (N0 * 16 + N2) % 16 // 8 * 8 + (N1 * 16 + N3) % 8] where N0 in 512, N1 in 512, N2 in 16, N3 in 16;\", input_dict={ \"input0\" : { \"dtype\" : \"float16\", \"shape\" : [16, 512, 8192]} }) ", + "Permutate", + [ + [ + 194, + 0 + ] + ] + ], + [ + 180, + " - einstein_v2(\" temp0[M, K] +=! input0[M//1024%16, M//2%512, K] where M in 16384; output0[M, N] +=! temp0[M, K] * input1[K, N]; \", input_dict={ \"input0\" : { \"dtype\" : \"float16\", \"shape\" : [16, 512, 8192]} , \"input1\" : { \"dtype\" : \"float16\", \"shape\" : [8192, 8192]} }) ## @: output_layout=0 |skip", + "LayoutDot", + [ + [ + 179, + 0 + ], + [ + 23, + 0 + ] + ] + ], + [ + 206, + " - einstein_v2(\" mediate0[N0, N1, N2] = input0[0] where N0 in 16, N1 in 512, N2 in 8192; output0[N0, N1, N2] = input1[N0, N1, N2] + input2[N0, N1, N2]; output1[N0, N1, N2] = output0[N0, N1, N2].cast(`float32`);output2[N0, N1, N2] = output1[N0, N1, N2].call(`pow`, [mediate0[N0, N1, N2]]);\", input_dict={ \"input0\" : { \"dtype\" : \"float32\", \"shape\" : [1]} , \"input1\" : { \"dtype\" : \"float16\", \"shape\" : [16, 512, 8192]} , \"input2\" : { \"dtype\" : \"float16\", \"shape\" : [16, 512, 8192]} }, extra_outputs=[\"output0\", \"output1\", \"output2\"]) ## @: ", + "Add_Convert_Broadcast_Power", + [ + [ + 19, + 0 + ], + [ + 204, + 0 + ], + [ + 180, + 0 + ] + ] + ], + [ + 189, + " - einstein_v2(\" mediate0[N0, N1] = input0[0] where N0 in 16, N1 in 512; mediate1[N0, N1] +=! 
input1[N0, N1, N2];output0[N0, N1] = mediate1[N0, N1] / mediate0[N0, N1];\", input_dict={ \"input0\" : { \"dtype\" : \"float32\", \"shape\" : [1]} , \"input1\" : { \"dtype\" : \"float32\", \"shape\" : [16, 512, 8192]} }) ## @: ", + "Sum_Broadcast_Divide", + [ + [ + 127, + 0 + ], + [ + 206, + 2 + ] + ] + ], + [ + 190, + " - einstein_v2(\" mediate0[N0, N1, N2] = input0[N2] where N0 in 16, N1 in 512; mediate1[N0] = input1[0] where N0 in 1; mediate2[N0, N1, N2] = mediate1[N2] where N0 in 16, N1 in 512; mediate3[N0, N1, N2] = input2[N0, N1] where N2 in 1; mediate4[N0, N1, N2] = mediate3[N0, N1, N2] + mediate2[N0, N1, N2]; mediate5[N0, N1, N2] = mediate4[N0, N1, N2].call(`sqrt`); mediate6[N0, N1] = mediate5[N0, N1, 0] ; mediate7[N0, N1, N2] = mediate6[N0, N1] where N2 in 8192; mediate8[N0, N1, N2] = input3[N0, N1, N2] / mediate7[N0, N1, N2];mediate9[N0, N1, N2] = mediate8[N0, N1, N2].cast(`float16`);output0[N0, N1, N2] = mediate0[N0, N1, N2] * mediate9[N0, N1, N2];\", input_dict={ \"input0\" : { \"dtype\" : \"float16\", \"shape\" : [8192]} , \"input1\" : { \"dtype\" : \"float32\", \"shape\" : [1]} , \"input2\" : { \"dtype\" : \"float32\", \"shape\" : [16, 512]} , \"input3\" : { \"dtype\" : \"float32\", \"shape\" : [16, 512, 8192]} }) ## @: ", + "Reshape_Reshape_Broadcast_Add_Sqrt_Reshape_Broadcast_Divide_Convert_Broadcast_Multiply", + [ + [ + 2, + 0 + ], + [ + 5, + 0 + ], + [ + 189, + 0 + ], + [ + 206, + 1 + ] + ] + ], + [ + 181, + " - einstein_v2(\" mediate0[N0, N1] = input0[N0 // 512 , N0 % 512, N1] where N0 in 8192;output0[N0, N1, N2, N3] = mediate0[(N0 * 16 + N2) // 16 * 16 + (N0 * 16 + N2) % 8 * 2 + (N1 * 16 + N3) % 16 // 8, (N1 * 16 + N3) // 16 * 16 + (N0 * 16 + N2) % 16 // 8 * 8 + (N1 * 16 + N3) % 8] where N0 in 512, N1 in 512, N2 in 16, N3 in 16;\", input_dict={ \"input0\" : { \"dtype\" : \"float16\", \"shape\" : [16, 512, 8192]} }) ", + "Permutate", + [ + [ + 190, + 0 + ] + ] + ], + [ + 182, + " - einstein_v2(\" temp0[M, K] +=! input0[M//1024%16, M//2%512, K] where M in 16384; output0[M, N] +=! temp0[M, K] * input1[K, N]; \", input_dict={ \"input0\" : { \"dtype\" : \"float16\", \"shape\" : [16, 512, 8192]} , \"input1\" : { \"dtype\" : \"float16\", \"shape\" : [8192, 22016]} }) ## @: output_layout=0 |skip", + "LayoutDot", + [ + [ + 181, + 0 + ], + [ + 25, + 0 + ] + ] + ], + [ + 183, + " - einstein_v2(\" mediate0[N0, N1] = input0[N0 // 512 , N0 % 512, N1] where N0 in 8192;output0[N0, N1, N2, N3] = mediate0[(N0 * 16 + N2) // 16 * 16 + (N0 * 16 + N2) % 8 * 2 + (N1 * 16 + N3) % 16 // 8, (N1 * 16 + N3) // 16 * 16 + (N0 * 16 + N2) % 16 // 8 * 8 + (N1 * 16 + N3) % 8] where N0 in 512, N1 in 512, N2 in 16, N3 in 16;\", input_dict={ \"input0\" : { \"dtype\" : \"float16\", \"shape\" : [16, 512, 8192]} }) ", + "Permutate", + [ + [ + 190, + 0 + ] + ] + ], + [ + 184, + " - einstein_v2(\" temp0[M, K] +=! input0[M//1024%16, M//2%512, K] where M in 16384; output0[M, N] +=! 
temp0[M, K] * input1[K, N]; \", input_dict={ \"input0\" : { \"dtype\" : \"float16\", \"shape\" : [16, 512, 8192]} , \"input1\" : { \"dtype\" : \"float16\", \"shape\" : [8192, 22016]} }) ## @: output_layout=0 |skip", + "LayoutDot", + [ + [ + 183, + 0 + ], + [ + 24, + 0 + ] + ] + ], + [ + 188, + " - einstein_v2(\"mediate0[N0, N1, N2] = const(1).cast(input1[N0, N1, N2].dtype()) / (const(1).cast(input1[N0, N1, N2].dtype()) + (-input1[N0, N1, N2]).call(`exp`));mediate1[N0, N1, N2] = input1[N0, N1, N2] * mediate0[N0, N1, N2];output0[N0, N1, N2] = mediate1[N0, N1, N2] * input2[N0, N1, N2];\", input_dict={ \"input1\" : { \"dtype\" : \"float16\", \"shape\" : [16, 512, 22016]} , \"input2\" : { \"dtype\" : \"float16\", \"shape\" : [16, 512, 22016]} }) ## @: ", + "Sigmoid_Multiply_Multiply", + [ + [ + 184, + 0 + ], + [ + 184, + 0 + ], + [ + 182, + 0 + ] + ] + ], + [ + 185, + " - einstein_v2(\" mediate0[N0, N1] = input0[N0 // 512 , N0 % 512, N1] where N0 in 8192;output0[N0, N1, N2, N3] = mediate0[(N0 * 16 + N2) // 16 * 16 + (N0 * 16 + N2) % 8 * 2 + (N1 * 16 + N3) % 16 // 8, (N1 * 16 + N3) // 16 * 16 + (N0 * 16 + N2) % 16 // 8 * 8 + (N1 * 16 + N3) % 8] where N0 in 512, N1 in 1376, N2 in 16, N3 in 16;\", input_dict={ \"input0\" : { \"dtype\" : \"float16\", \"shape\" : [16, 512, 22016]} }) ", + "Permutate", + [ + [ + 188, + 0 + ] + ] + ], + [ + 186, + " - einstein_v2(\" temp0[M, K] +=! input0[M//1024%16, M//2%512, K] where M in 16384; output0[M, N] +=! temp0[M, K] * input1[K, N]; \", input_dict={ \"input0\" : { \"dtype\" : \"float16\", \"shape\" : [16, 512, 22016]} , \"input1\" : { \"dtype\" : \"float16\", \"shape\" : [22016, 8192]} }) ## @: output_layout=0 |skip", + "LayoutDot", + [ + [ + 185, + 0 + ], + [ + 26, + 0 + ] + ] + ], + [ + 207, + " - einstein_v2(\" mediate0[N0, N1, N2] = input0[0] where N0 in 16, N1 in 512, N2 in 8192; mediate1[N0, N1, N2] = input1[N0, N1, N2] + input2[N0, N1, N2]; output0[N0, N1, N2] = mediate1[N0, N1, N2].cast(`float32`);output1[N0, N1, N2] = output0[N0, N1, N2].call(`pow`, [mediate0[N0, N1, N2]]);\", input_dict={ \"input0\" : { \"dtype\" : \"float32\", \"shape\" : [1]} , \"input1\" : { \"dtype\" : \"float16\", \"shape\" : [16, 512, 8192]} , \"input2\" : { \"dtype\" : \"float16\", \"shape\" : [16, 512, 8192]} }, extra_outputs=[\"output0\", \"output1\"]) ## @: ", + "Add_Convert_Broadcast_Power", + [ + [ + 19, + 0 + ], + [ + 206, + 0 + ], + [ + 186, + 0 + ] + ] + ], + [ + 187, + " - einstein_v2(\" mediate0[N0, N1] = input0[0] where N0 in 16, N1 in 512; mediate1[N0, N1] +=! 
input1[N0, N1, N2];output0[N0, N1] = mediate1[N0, N1] / mediate0[N0, N1];\", input_dict={ \"input0\" : { \"dtype\" : \"float32\", \"shape\" : [1]} , \"input1\" : { \"dtype\" : \"float32\", \"shape\" : [16, 512, 8192]} }) ## @: ",
      "Sum_Broadcast_Divide",
      [
        [
          152,
          0
        ],
        [
          207,
          1
        ]
      ]
    ],
    [
      209,
      " - einstein_v2(\" mediate0[N0, N1, N2] = input0[N2] where N0 in 16, N1 in 512; mediate1[N0] = input1[0] where N0 in 1; mediate2[N0, N1, N2] = mediate1[N2] where N0 in 16, N1 in 512; mediate3[N0, N1, N2] = input2[N0, N1] where N2 in 1; mediate4[N0, N1, N2] = mediate3[N0, N1, N2] + mediate2[N0, N1, N2]; mediate5[N0, N1, N2] = mediate4[N0, N1, N2].call(`sqrt`); mediate6[N0, N1] = mediate5[N0, N1, 0] ; mediate7[N0, N1, N2] = mediate6[N0, N1] where N2 in 8192; mediate8[N0, N1, N2] = input3[N0, N1, N2] / mediate7[N0, N1, N2];mediate9[N0, N1, N2] = mediate8[N0, N1, N2].cast(`float16`);output0[N0, N1, N2] = mediate0[N0, N1, N2] * mediate9[N0, N1, N2];\", input_dict={ \"input0\" : { \"dtype\" : \"float16\", \"shape\" : [8192]} , \"input1\" : { \"dtype\" : \"float32\", \"shape\" : [1]} , \"input2\" : { \"dtype\" : \"float32\", \"shape\" : [16, 512]} , \"input3\" : { \"dtype\" : \"float32\", \"shape\" : [16, 512, 8192]} }) ## @: ",
      "Reshape_Reshape_Broadcast_Add_Sqrt_Reshape_Broadcast_Divide_Convert_Broadcast_Multiply",
      [
        [
          3,
          0
        ],
        [
          5,
          0
        ],
        [
          187,
          0
        ],
        [
          207,
          0
        ]
      ]
    ],
    [
      166,
      "",
      "Result",
      [
        [
          209,
          0
        ]
      ]
    ]
]
\ No newline at end of file
diff --git a/src/nnfusion/core/kernels/cuda_gpu/cuda_langunit.cpp b/src/nnfusion/core/kernels/cuda_gpu/cuda_langunit.cpp
index 9eb570935..6241c823b 100644
--- a/src/nnfusion/core/kernels/cuda_gpu/cuda_langunit.cpp
+++ b/src/nnfusion/core/kernels/cuda_gpu/cuda_langunit.cpp
@@ -59,6 +59,11 @@ CUDA_UNSUPPORTED_HALF_MATH_UNARY(herf, erf)
 
 LU_DEFINE(macro::TVM_PACK_VALUES, R"(
+__device__ __half max(__half a, __half b)
+{
+    return (float)a > (float)b ? a : b;
+}
+
 inline __device__ longlong4 make_int8(int x0, int x1, int x2, int x3, int x4, int x5, int x6, int x7)
 {
     int2 i0 = make_int2(x0, x1);
     int2 i1 = make_int2(x2, x3);
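[Note] The hunk above injects a `max(__half, __half)` overload into the generated-kernel prelude because generated fp16 reductions call `max` unqualified and older CUDA headers provide no such overload. A minimal standalone check of the same trick (the `half_max` name and the host scaffolding here are illustrative, not part of the patch):

// halfmax_check.cu: sanity check for a float-upcast half max, as used above.
#include <cuda_fp16.h>
#include <cstdio>

__device__ __forceinline__ __half half_max(__half a, __half b)
{
    // Same strategy as the patch: compare after upcasting to float so the
    // overload also works on architectures without native half comparisons.
    return (float)a > (float)b ? a : b;
}

__global__ void check(__half* out)
{
    *out = half_max(__float2half(1.5f), __float2half(2.5f));
}

int main()
{
    __half *d, h;
    cudaMalloc(&d, sizeof(__half));
    check<<<1, 1>>>(d);
    cudaMemcpy(&h, d, sizeof(__half), cudaMemcpyDeviceToHost);
    printf("%f\n", __half2float(h)); // expected: 2.500000
    cudaFree(d);
    return 0;
}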
diff --git a/src/nnfusion/core/operators/generic_op/generic_op_define/GatherElements.cpp b/src/nnfusion/core/operators/generic_op/generic_op_define/GatherElements.cpp
new file mode 100644
index 000000000..42bfd99ad
--- /dev/null
+++ b/src/nnfusion/core/operators/generic_op/generic_op_define/GatherElements.cpp
@@ -0,0 +1,59 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+#include "nnfusion/core/operators/generic_op/generic_op.hpp"
+
+REGISTER_OP(GatherElements)
+    .attr<int>("axis", 0)
+    .infershape([](std::shared_ptr<graph::GNode> gnode) -> void {
+        NNFUSION_CHECK(gnode->get_input_size() == 2);
+        const nnfusion::Shape& input_shape_0 = gnode->get_input_shape(0);
+        const nnfusion::Shape& input_shape_1 = gnode->get_input_shape(1);
+
+        // GatherElements output has the same shape as the indices tensor
+        gnode->set_output_type_and_shape(0, gnode->get_input_element_type(0), input_shape_1);
+    })
+    .translate_v2([](std::shared_ptr<graph::GNode> curr) -> std::string {
+        auto generic_op = std::dynamic_pointer_cast<nnfusion::op::GenericOp>(curr->get_op_ptr());
+        int axis = generic_op->localOpConfig.getRoot()["axis"];
+        // e.g. Antares type is int64 rather than C++'s int64_t
+        std::string dtype;
+        bool ret = element::Type::nnfusion_element_type_to_dtype_string(
+            curr->get_input_element_type(1), dtype);
+        NNFUSION_CHECK(ret);
+
+        auto ir_template =
+            R"( @output0@@output0_layout@ = @input0@[@input0_layout_left@@input1@@input1_layout@.when(@input1@@input1_layout@ >= 0, @input1@@input1_layout@ + const(@gather_dim@).cast(@input1@@input1_layout@.dtype()))@input0_layout_right@]; )";
+
+        auto output0_shape = curr->get_output_shape(0);
+        auto output0_layout = op::create_layout_from_dims(output0_shape);
+        auto input0_shape = curr->get_input_shape(0);
+        auto input1_shape = curr->get_input_shape(1);
+        std::string input0_layout_left;
+        std::string input0_layout_right;
+        std::vector<std::string> input1_layout;
+        for (size_t d = 0; d < axis; ++d)
+        {
+            input0_layout_left += output0_layout[d] + ", ";
+        }
+
+        for (size_t d = 0; d < input1_shape.size(); ++d)
+        {
+            input1_layout.push_back(output0_layout[d]);
+        }
+
+        for (size_t d = axis + 1; d < input0_shape.size(); ++d)
+        {
+            input0_layout_right += ", " + output0_layout[d];
+        }
+
+        input1_layout = input1_layout.empty() ? std::vector<std::string>({"0"}) : input1_layout;
+        output0_layout = output0_layout.empty() ? std::vector<std::string>({"N0"}) : output0_layout;
+        op::OpConfig::any op_config;
+        op_config["output0_layout"] = vector_to_string<std::vector<std::string>>(output0_layout);
+        op_config["input0_layout_left"] = input0_layout_left;
+        op_config["input1_layout"] = vector_to_string<std::vector<std::string>>(input1_layout);
+        op_config["input0_layout_right"] = input0_layout_right;
+        op_config["gather_dim"] = std::to_string(input0_shape[axis]);
+
+        return op::create_code_from_template(ir_template, op_config);
+    });
\ No newline at end of file
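[Note] For a concrete feel of what this template emits, take a hypothetical 2-D case (values mine, not from the patch): data `input0` of shape [4, 5] (float16), indices `input1` of shape [4, 3] (int64), axis = 1. Then `output0_layout` is [N0, N1], `input0_layout_left` is "N0, ", `input0_layout_right` is empty, and `gather_dim` is 5, so the generated einstein_v2 line is roughly:

output0[N0, N1] = input0[N0, input1[N0, N1].when(input1[N0, N1] >= 0, input1[N0, N1] + const(5).cast(input1[N0, N1].dtype()))];

The `.when(...)` clause implements ONNX's negative-index semantics: an index is used as-is when non-negative, otherwise the gathered dimension's extent is added to it.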
diff --git a/src/nnfusion/core/operators/generic_op/generic_op_define/memfusion_new_ops.cpp b/src/nnfusion/core/operators/generic_op/generic_op_define/memfusion_new_ops.cpp
index b49af6dcb..47dcc1445 100644
--- a/src/nnfusion/core/operators/generic_op/generic_op_define/memfusion_new_ops.cpp
+++ b/src/nnfusion/core/operators/generic_op/generic_op_define/memfusion_new_ops.cpp
@@ -721,3 +721,89 @@ REGISTER_OP(LayoutBMM)
         }
         return ir;
     });
+
+REGISTER_OP(QuantLinear)
+    .attr<int>("bits", 4)
+    .attr<bool>("transpose_A", false)
+    .attr<bool>("transpose_B", true)
+    .infershape(
+        [](std::shared_ptr<graph::GNode> gnode) -> void
+        {
+            auto generic_op =
+                std::dynamic_pointer_cast<nnfusion::op::GenericOp>(gnode->get_op_ptr());
+            bool trans_a = generic_op->localOpConfig.getRoot()["transpose_A"];
+            bool trans_b = generic_op->localOpConfig.getRoot()["transpose_B"];
+            NNFUSION_CHECK(trans_a == false) << "Currently only support non-transpose A";
+            NNFUSION_CHECK(trans_b == true) << "Currently only support transpose B";
+            NNFUSION_CHECK(4 == gnode->get_input_size());
+            // input 0 shape is (B, S, K); input 1 is (N, K) since transpose_B is true;
+            // output shape is (B, S, N)
+            auto input_shape = nnfusion::Shape(gnode->get_input_shape(0));
+            auto qweight_shape = nnfusion::Shape(gnode->get_input_shape(1));
+            NNFUSION_CHECK(input_shape.size() == 2 || input_shape.size() == 3);
+            NNFUSION_CHECK(qweight_shape.size() == 2);
+            if (input_shape.size() == 2)
+            {
+                nnfusion::Shape output_shape{trans_a ? input_shape[1] : input_shape[0],
+                                             trans_b ? qweight_shape[0] : qweight_shape[1]};
+                gnode->set_output_type_and_shape(0, gnode->get_input_element_type(0), output_shape);
+            }
+            else if (input_shape.size() == 3)
+            {
+                nnfusion::Shape output_shape{input_shape[0],
+                                             trans_a ? input_shape[2] : input_shape[1],
+                                             trans_b ? qweight_shape[0] : qweight_shape[1]};
+                gnode->set_output_type_and_shape(0, gnode->get_input_element_type(0), output_shape);
+            }
+
+            // log input and output shapes for debugging
+            NNFUSION_LOG(INFO) << "input0 shape is " << gnode->get_input_shape(0);
+            NNFUSION_LOG(INFO) << "input1 shape is " << gnode->get_input_shape(1);
+            NNFUSION_LOG(INFO) << "output shape is " << gnode->get_output_shape(0);
+        })
+    .translate_v2(
+        [](std::shared_ptr<graph::GNode> curr) -> std::string
+        {
+            auto generic_op =
+                std::dynamic_pointer_cast<nnfusion::op::GenericOp>(curr->get_op_ptr());
+            NNFUSION_CHECK_NOT_NULLPTR(generic_op)
+                << "Node type is not " << curr->get_op_ptr()->get_op_type();
+            bool trans_a = generic_op->localOpConfig.getRoot()["transpose_A"];
+            bool trans_b = generic_op->localOpConfig.getRoot()["transpose_B"];
+            auto input_shape = curr->get_input_shape(0);
+            auto qweight_shape = curr->get_input_shape(1);
+            auto scales_shape = curr->get_input_shape(2);
+            auto zeros_shape = curr->get_input_shape(3);
+
+            // NOTE: this template is not a real dequantize-matmul; it only records
+            // the per-input index layouts (the op is expected to be lowered to a
+            // tuned kernel elsewhere).
+            auto ir_template =
+                R"( @output0@@output0_layout@ = @input0@@input0_layout@ + @input1@@input1_layout@ + @input2@@input2_layout@ + @input3@@input3_layout@; )";
+
+            vector<std::string> input_layout, qweight_layout, scales_layout, zeros_layout,
+                output_layout;
+
+            // batch dims come from the activation input, not the 2-D qweight
+            for (size_t i = 0; i + 2 < input_shape.size(); i++)
+            {
+                input_layout.push_back("S" + std::to_string(i));
+                output_layout.push_back("S" + std::to_string(i));
+            }
+
+            output_layout.push_back("N");
+            output_layout.push_back("M");
+            input_layout.push_back(trans_a ? "K" : "N");
+            input_layout.push_back(trans_a ? "N" : "K");
+            qweight_layout.push_back(trans_b ? "M" : "K");
+            qweight_layout.push_back(trans_b ? "K" : "M");
+            scales_layout.push_back("M");
+            zeros_layout.push_back("M");
+
+            // never runs for the 2-D qweight enforced by infershape above
+            for (size_t i = 0; i + 2 < qweight_shape.size(); i++)
+            {
+                qweight_layout.push_back("E" + std::to_string(i));
+                output_layout.push_back("E" + std::to_string(i));
+            }
+
+            op::OpConfig::any op_config;
+            op_config["input0_layout"] = nnfusion::vector_to_string(input_layout);
+            op_config["input1_layout"] = nnfusion::vector_to_string(qweight_layout);
+            op_config["input2_layout"] = nnfusion::vector_to_string(scales_layout);
+            op_config["input3_layout"] = nnfusion::vector_to_string(zeros_layout);
+            op_config["output0_layout"] = nnfusion::vector_to_string(output_layout);
+            auto ir = op::create_code_from_template(ir_template, op_config);
+            return ir;
+        });
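[Note] Since the registered IR above is only a layout placeholder, here is a sketch of the on-the-fly dequantization a real QuantLinear kernel would perform in its GEMM inner loop. Everything here is an assumption for illustration: a GPTQ/AWQ-style packing of two 4-bit values per byte along K, per-output-channel half-precision scales and zeros, and the hypothetical name `dequant_int4`; the patch itself does not contain this code.

// Sketch: recover one fp16 weight element W[k][m] from 4-bit packed storage.
__device__ __forceinline__ half dequant_int4(const uint8_t* qweight,
                                             const half* scales,
                                             const half* zeros,
                                             int k, int m, int ldm)
{
    uint8_t packed = qweight[(k / 2) * ldm + m];          // two int4 values per byte along K
    uint8_t q = (k % 2) ? (packed >> 4) : (packed & 0xF); // pick the nibble for this k
    // W = (q - zero[m]) * scale[m]
    return __hmul(__hsub(__int2half_rn(q), zeros[m]), scales[m]);
}

A fused kernel would call this per (k, m) pair while accumulating `A[.., n, k] * W[k][m]`, trading a few ALU ops for a 4x reduction in weight bandwidth.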
diff --git a/src/nnfusion/frontend/onnx_import/CMakeLists.txt b/src/nnfusion/frontend/onnx_import/CMakeLists.txt
index 2032ea44e..8afe2c5f3 100644
--- a/src/nnfusion/frontend/onnx_import/CMakeLists.txt
+++ b/src/nnfusion/frontend/onnx_import/CMakeLists.txt
@@ -88,6 +88,8 @@ add_library(onnx_import STATIC
     op/gru.cpp
     op/einsum.hpp
     op/einsum.cpp
+    op/quant_linear.cpp
+    op/quant_linear.hpp
     # ${ops_source}
     util/broadcasting.cpp
     util/graph_convert.cpp
diff --git a/src/nnfusion/frontend/onnx_import/op/gather.cpp b/src/nnfusion/frontend/onnx_import/op/gather.cpp
index 59abd5e84..8a4785013 100644
--- a/src/nnfusion/frontend/onnx_import/op/gather.cpp
+++ b/src/nnfusion/frontend/onnx_import/op/gather.cpp
@@ -155,6 +155,27 @@ namespace nnfusion
 
                     return {{node_proto.output(0), generic_gnode}};
                 }
+
+                NamedNodeVector
+                    TranslateGatherElementsOp(const onnx::NodeProto& node_proto,
+                                              const NodeMap& all_ng_nodes,
+                                              std::shared_ptr<nnfusion::graph::Graph> m_graph)
+                {
+                    auto input_indexes = GetAllInputIndex(all_ng_nodes, node_proto);
+
+                    Node node(node_proto);
+                    auto axis = node.get_attribute_value<int64_t>("axis", 0);
+                    axis += axis < 0 ? input_indexes[0].get_shape().size() : 0;
+
+                    nnfusion::op::OpConfig::any myConfig;
+                    myConfig["axis"] = axis;
+
+                    auto generic_op = std::make_shared<nnfusion::op::GenericOp>(
+                        node_proto.output(0), "GatherElements", myConfig);
+                    auto generic_gnode = m_graph->add_node_and_edge(generic_op, input_indexes);
+
+                    return {{node_proto.output(0), generic_gnode}};
+                }
             } // namespace set_11
diff --git a/src/nnfusion/frontend/onnx_import/op/gather.hpp b/src/nnfusion/frontend/onnx_import/op/gather.hpp
index db6ffbe77..a74cfec5b 100644
--- a/src/nnfusion/frontend/onnx_import/op/gather.hpp
+++ b/src/nnfusion/frontend/onnx_import/op/gather.hpp
@@ -55,6 +55,11 @@ namespace nnfusion
                     TranslateGatherNDGradOp(const onnx::NodeProto& node_proto,
                                             const NodeMap& all_ng_nodes,
                                             std::shared_ptr<nnfusion::graph::Graph> m_graph);
+
+                NamedNodeVector
+                    TranslateGatherElementsOp(const onnx::NodeProto& node_proto,
+                                              const NodeMap& all_ng_nodes,
+                                              std::shared_ptr<nnfusion::graph::Graph> m_graph);
             }
 
             namespace set_12
@@ -65,8 +70,9 @@ namespace nnfusion
             namespace set_13
             {
                 using set_1::TranslateGatherOp;
+                using set_11::TranslateGatherElementsOp;
                 using set_11::TranslateGatherNDOp;
-            }
+            } // namespace set_13
         } //namespace onnx_import
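[Note] The `axis` normalization in the translator above follows the ONNX convention that `axis` may lie in [-r, r-1] for a rank-r tensor. A quick illustration (values mine):

// ONNX allows negative axes; the translator maps them into [0, r-1]
// before storing the value in the GenericOp config.
int64_t rank = 3;   // e.g. a [16, 512, 8192] tensor
int64_t axis = -1;  // ONNX shorthand for the last dimension
axis += axis < 0 ? rank : 0;
// axis == 2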
diff --git a/src/nnfusion/frontend/onnx_import/op/quant_linear.cpp b/src/nnfusion/frontend/onnx_import/op/quant_linear.cpp
new file mode 100644
index 000000000..728ba6d3a
--- /dev/null
+++ b/src/nnfusion/frontend/onnx_import/op/quant_linear.cpp
@@ -0,0 +1,63 @@
+//*****************************************************************************
+// Copyright 2017-2020 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+
+//----------------------------------------------------------------------------------------------
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License. See License.txt in the project root for license information.
+//----------------------------------------------------------------------------------------------
+
+#include "quant_linear.hpp"
+#include <memory>
+#include <vector>
+#include "../util/util.hpp"
+#include "nnfusion/core/operators/generic_op/generic_op.hpp"
+
+namespace nnfusion
+{
+    namespace frontend
+    {
+        namespace onnx_import
+        {
+            namespace set_7
+            {
+                NamedNodeVector
+                    TranslateQuantLinearOp(const onnx::NodeProto& node_proto,
+                                           const NodeMap& all_ng_nodes,
+                                           std::shared_ptr<nnfusion::graph::Graph> m_graph)
+                {
+                    NNFUSION_LOG(INFO) << "Translating QuantLinear";
+
+                    auto input_indexes = GetAllInputIndex(all_ng_nodes, node_proto);
+                    auto A = input_indexes[0];
+                    auto Qweight = input_indexes[1];
+                    auto Scales = input_indexes[2];
+                    auto Zeros = input_indexes[3];
+
+                    Node node(node_proto);
+                    nnfusion::op::OpConfig::any myConfig;
+                    auto generic_op = std::make_shared<nnfusion::op::GenericOp>(
+                        node_proto.name(), "QuantLinear", myConfig);
+                    auto generic_gnode = m_graph->add_node_and_edge(generic_op, input_indexes);
+                    return NamedNodeVector{{node_proto.output(0), generic_gnode}};
+                }
+
+            } // namespace set_7
+
+        } //namespace onnx_import
+
+    } // namespace frontend
+
+} // namespace nnfusion
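[Note] The translator above assumes the custom QuantLinear node always carries exactly four inputs in the fixed order (A, qweight, scales, zeros). A possible hardening, not part of the patch, would fail fast on malformed exports:

// Sketch only: guard the assumed four-input ordering before indexing.
NNFUSION_CHECK(input_indexes.size() == 4)
    << "QuantLinear expects (A, qweight, scales, zeros), got "
    << input_indexes.size() << " inputs";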
diff --git a/src/nnfusion/frontend/onnx_import/op/quant_linear.hpp b/src/nnfusion/frontend/onnx_import/op/quant_linear.hpp
new file mode 100644
index 000000000..cdbcb666e
--- /dev/null
+++ b/src/nnfusion/frontend/onnx_import/op/quant_linear.hpp
@@ -0,0 +1,65 @@
+//*****************************************************************************
+// Copyright 2017-2020 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+
+//----------------------------------------------------------------------------------------------
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License. See License.txt in the project root for license information.
+//----------------------------------------------------------------------------------------------
+
+#pragma once
+
+#include "core/node.hpp"
+
+namespace nnfusion
+{
+    namespace frontend
+    {
+        namespace onnx_import
+        {
+            namespace set_7
+            {
+                NamedNodeVector
+                    TranslateQuantLinearOp(const onnx::NodeProto& node_proto,
+                                           const NodeMap& all_ng_nodes,
+                                           std::shared_ptr<nnfusion::graph::Graph> m_graph);
+
+            } // namespace set_7
+
+            namespace set_1
+            {
+                using set_7::TranslateQuantLinearOp;
+            } // namespace set_1
+
+            namespace set_9
+            {
+                using set_7::TranslateQuantLinearOp;
+            } // namespace set_9
+
+            namespace set_11
+            {
+                using set_7::TranslateQuantLinearOp;
+            } // namespace set_11
+
+            namespace set_13
+            {
+                using set_7::TranslateQuantLinearOp;
+            } // namespace set_13
+
+        } //namespace onnx_import
+
+    } // namespace frontend
+
+} // namespace nnfusion
diff --git a/src/nnfusion/frontend/onnx_import/op/reshape.cpp b/src/nnfusion/frontend/onnx_import/op/reshape.cpp
index c4e932274..d173b1858 100644
--- a/src/nnfusion/frontend/onnx_import/op/reshape.cpp
+++ b/src/nnfusion/frontend/onnx_import/op/reshape.cpp
@@ -45,6 +45,9 @@ namespace nnfusion
                     NNFUSION_CHECK(std::count(output_shape.begin(), output_shape.end(), -1) <= 1)
                         << "Shape should have at most 1 dynamic dimension";
+                    NNFUSION_LOG(INFO) << "Reshape Name " << node_proto.name() << " "
+                                       << node_proto.output(0) << " " << join(input_shape)
+                                       << " -> " << join(output_shape);
                     size_t num_input_elements = nnfusion::shape_size(input_shape);
 
                     // infer the dimension of -1 and 0
@@ -69,7 +72,7 @@ namespace nnfusion
                     if (dynamic_dim == output_shape.end())
                     {
                         NNFUSION_CHECK(static_size == num_input_elements)
-                            << "Reshape size doesn\'t match";
+                            << "Reshape size doesn\'t match, static_size " << static_size
+                            << " vs " << num_input_elements;
                     }
                     else
                     {
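[Note] The check above enforces the usual ONNX Reshape rule: when no -1 is present, the product of the requested dims must equal the element count of the input; when a -1 is present, that dimension is inferred by division. A worked illustration (values mine):

// How the -1 output dimension is resolved for a [16, 512, 8192] input
// reshaped to [16, 512, -1]:
size_t num_input_elements = 16UL * 512 * 8192; // product of input dims
size_t static_size = 16UL * 512;               // product of non-negative output dims
size_t inferred = num_input_elements / static_size; // -> 8192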
REGISTER_OPERATOR("Gelu", 1, TranslateUnaryOp); REGISTER_OPERATOR("GlobalAveragePool", 1, @@ -344,6 +346,11 @@ namespace nnfusion REGISTER_OPERATOR("Pow", 12, TranslateBinaryOp); REGISTER_OPERATOR("Pow", 13, TranslateBinaryOp); REGISTER_OPERATOR("Pow", 15, TranslateBinaryOp); + REGISTER_DOMAIN_OPERATOR("nnfusion", "QuantLinear", 1, TranslateQuantLinearOp); + REGISTER_DOMAIN_OPERATOR("nnfusion", "QuantLinear", 7, TranslateQuantLinearOp); + REGISTER_DOMAIN_OPERATOR("nnfusion", "QuantLinear", 9, TranslateQuantLinearOp); + REGISTER_DOMAIN_OPERATOR("nnfusion", "QuantLinear", 11, TranslateQuantLinearOp); + REGISTER_DOMAIN_OPERATOR("nnfusion", "QuantLinear", 13, TranslateQuantLinearOp); //REGISTER_OPERATOR("PRelu", 1, prelu); REGISTER_OPERATOR("Range", 11, TranslateRangeOp); REGISTER_OPERATOR("Reciprocal", 1, TranslateReciprocalOp); diff --git a/src/nnfusion/frontend/onnx_import/util/graph_convert.cpp b/src/nnfusion/frontend/onnx_import/util/graph_convert.cpp index 76984839b..5f5adc12a 100644 --- a/src/nnfusion/frontend/onnx_import/util/graph_convert.cpp +++ b/src/nnfusion/frontend/onnx_import/util/graph_convert.cpp @@ -521,7 +521,10 @@ namespace nnfusion NamedNodeVector GraphConvert::convert_node(const onnx::NodeProto& node_proto) { - NNFUSION_LOG(DEBUG) << "convert node: " << node_proto.name(); + NNFUSION_LOG(INFO) << "convert node: " << node_proto.name(); + // op_type + NNFUSION_LOG(INFO) << "op_type: " << node_proto.op_type()<< " domain: " << node_proto.domain(); + const auto& convert_func = get_convert_func(node_proto.op_type(), node_proto.domain()); @@ -551,6 +554,14 @@ namespace nnfusion const ConvertFunc& GraphConvert::get_convert_func(const std::string& name, const std::string& domain) const { + if (m_domain_convert_func_map.find(domain) == m_domain_convert_func_map.end() ) + { + NNFUSION_LOG(NNFUSION_WARNING) << "No domain: " << domain << " found"; + } + if (m_domain_convert_func_map.at(domain).find(name) == + m_domain_convert_func_map.at(domain).end()){ + NNFUSION_LOG(NNFUSION_WARNING) << "No op: " << name << " found"; + } if (m_domain_convert_func_map.find(domain) == m_domain_convert_func_map.end() || m_domain_convert_func_map.at(domain).find(name) == m_domain_convert_func_map.at(domain).end())