Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support folding onnx>2GB with onnxruntime #528

Merged
merged 4 commits into from
Oct 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions models/pytorch2onnx/ort_run_frozen.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,10 @@ def check_shape(shape):
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
if args.optimized_model_filepath != '':
sess_options.optimized_model_filepath = args.optimized_model_filepath
sess_options.add_session_config_entry(
"session.optimized_model_external_initializers_file_name", os.path.basename(args.optimized_model_filepath) + ".data"
)
sess_options.add_session_config_entry("session.optimized_model_external_initializers_min_size_in_bytes", "100")

for k, v in args.symbolic_dims.items():
sess_options.add_free_dimension_override_by_name(k, int(v))
Expand Down
1 change: 1 addition & 0 deletions src/nnfusion/common/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@
#include "nnfusion/core/operators/op_define/result.hpp"
#include "nnfusion/core/operators/op_define/reverse.hpp"
#include "nnfusion/core/operators/op_define/reverse_sequence.hpp"
#include "nnfusion/core/operators/op_define/round.hpp"
#include "nnfusion/core/operators/op_define/rsqrt.hpp"
#include "nnfusion/core/operators/op_define/select.hpp"
#include "nnfusion/core/operators/op_define/select_and_scatter.hpp"
Expand Down
4 changes: 2 additions & 2 deletions src/nnfusion/core/kernels/cuda_gpu/kernels/gather_1d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ LanguageUnit_p cuda::Gather1D::emit_function_body()
}

lu << "int64_t gather_i = __ldg(indices + indices_i);\n";
lu << "if (gather_i < 0) gather_i += " << gather_dim_size <<";\n";
lu << "if (gather_i < 0) gather_i += " << gather_dim_size << ";\n";
lu << "if (gather_i >= " << gather_dim_size << ")\n"
<< " out[i] = 0;\n"
<< "else\n";
Expand Down Expand Up @@ -194,7 +194,7 @@ LanguageUnit_p cuda::Gather1DGrad::emit_function_body()
}

lu << "int64_t gather_i = __ldg(indices + indices_i);\n";
lu << "if (gather_i < 0) gather_i += " << gather_dim_size <<";\n";
lu << "if (gather_i < 0) gather_i += " << gather_dim_size << ";\n";
lu << "if (gather_i < " << gather_dim_size << ")\n";
lu.block_begin();
{
Expand Down
1 change: 1 addition & 0 deletions src/nnfusion/core/operators/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ set(SRC
op_define/result.cpp
op_define/reverse_sequence.cpp
op_define/reverse.cpp
op_define/round.cpp
op_define/rsqrt.cpp
op_define/select_and_scatter.cpp
op_define/select.cpp
Expand Down
7 changes: 6 additions & 1 deletion src/nnfusion/core/operators/generic_op/generic_op.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -313,10 +313,15 @@ namespace nnfusion
{
config[alias_name + "_dtype"] = "int64";
}
else if (d_type == element::u8)
{
// hack!!!
config[alias_name + "_dtype"] = "int8";
}
else
{
NNFUSION_CHECK_FAIL()
<< "Unhandled type: " << d_type
<< "Unhandled type for " << input_name << ": " << d_type
<< ", antares currently supports int8/16/32/64, float16/32/64";
}
auto shape = tensor->get_shape();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ static const std::unordered_map<std::string, element_op> ElementOpMap = {
{"Sin", element_op("sin", "")},
{"Sinh", element_op("sinh", "")},
{"Sqrt", element_op("sqrt", "")},
{"Round", element_op("round", "x0.call(`round`)")},
{"Rsqrt", element_op("rsqrt", "")},
{"Tan", element_op("tan", "")},
{"Tanh", element_op("tanh", "")},
Expand Down Expand Up @@ -196,6 +197,7 @@ REGISTER_ELEM_OP(Relu)
REGISTER_ELEM_OP(Relu6)
REGISTER_ELEM_OP(ReluBackprop)
REGISTER_ELEM_OP(Relu6Backprop)
REGISTER_ELEM_OP(Round)
REGISTER_ELEM_OP(Sigmoid)
REGISTER_ELEM_OP(SigmoidBackprop)
REGISTER_ELEM_OP(Equal)
Expand Down
21 changes: 10 additions & 11 deletions src/nnfusion/core/operators/generic_op/generic_op_define/Trilu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,18 @@

REGISTER_OP(Trilu)
.infershape([](std::shared_ptr<graph::GNode> curr) -> void {
curr->set_output_type_and_shape(0, curr->get_input_element_type(0), curr->get_input_shape(0));
})
curr->set_output_type_and_shape(
0, curr->get_input_element_type(0), curr->get_input_shape(0));
})
.translate_v2([](std::shared_ptr<graph::GNode> curr) -> std::string {
auto input_shape_0 = curr->get_input_shape(0);
assert(input_shape_0.size() >= 2);
std::string k_str = "";
if(curr->get_input_size() == 2)
k_str = "+ input1[0]";
if (curr->get_input_size() == 2)
k_str = "+ input1[0]";
auto op = static_pointer_cast<nnfusion::op::GenericOp>(curr->get_op_ptr());
auto& cfg = op->localOpConfig.getRoot();
bool upper = cfg["upper"].is_null()?true:int64_t(cfg["upper"])!=0;
bool upper = cfg["upper"].is_null() ? true : int64_t(cfg["upper"]) != 0;
auto input_layout = op::create_layout_from_dims(input_shape_0);
auto dim_a = input_layout[input_layout.size() - 2];
auto dim_b = input_layout[input_layout.size() - 1];
Expand All @@ -28,13 +29,11 @@ REGISTER_OP(Trilu)
element::Type::nnfusion_element_type_to_dtype_string(curr->get_element_type(), dtype);
NNFUSION_CHECK(ret);

std::string condition = upper?dim_b+">="+dim_a+k_str:dim_a+k_str+">="+dim_b;
std::string condition = upper ? dim_b + ">=" + dim_a + k_str : dim_a + k_str + ">=" + dim_b;

auto expression = op::create_code_from_template(
"@output0@[@input_layout@] = @input0@[@input_layout@].when(@condition@, const(0).cast(`@dtype@`));", {
{"input_layout", join(input_layout)},
{"condition", condition},
{"dtype", dtype}
});
"@output0@[@input_layout@] = @input0@[@input_layout@].when(@condition@, "
"const(0).cast(`@dtype@`));",
{{"input_layout", join(input_layout)}, {"condition", condition}, {"dtype", dtype}});
return expression;
});
4 changes: 2 additions & 2 deletions src/nnfusion/core/operators/op_define/fused.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ namespace nnfusion
std::shared_ptr<graph::GNode> fused_node);
std::string get_fused_ir2() { return fused_op_ir2; };
std::string get_plan_rule();

bool get_is_memcpy() { return is_memcpy; }
protected:
void assemble_inputs_and_outputs();

Expand All @@ -41,4 +41,4 @@ namespace nnfusion
bool is_memcpy;
};
}
}
}
26 changes: 26 additions & 0 deletions src/nnfusion/core/operators/op_define/round.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

// Microsoft (c) 2020, NNFusion Team

#include "round.hpp"

using namespace nnfusion::op;

// Constructs an elementwise Round operation. Output shape/type inference
// and validation are inherited from the ElementwiseArithmetic base class.
Round::Round()
    : ElementwiseArithmetic("Round")
{
}
35 changes: 35 additions & 0 deletions src/nnfusion/core/operators/op_define/round.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

// Microsoft (c) 2020, NNFusion Team

#pragma once

#include "nnfusion/core/operators/util/elementwise_arithmetic.hpp"

namespace nnfusion
{
    namespace op
    {
        /// \brief Elementwise round operation (rounds each element of the
        ///        input tensor to the nearest integer).
        // NOTE(review): the original comment said "cosine" — a copy-paste
        // artifact from another elementwise op header.
        class Round : public ElementwiseArithmetic
        {
        public:
            /// \brief Constructs a round operation.
            Round();
        };
    }
}
4 changes: 2 additions & 2 deletions src/nnfusion/engine/pass/codegen/cuda_codegen_pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -879,8 +879,8 @@ nnfusion::LanguageUnit_p CudaCodegenPass::func_call_codegen(nnfusion::ir::Instru
lu << "Debug(\"" << node_name << ", " << out_name << member_name << "_f32\", "
<< "fp32tensors, \"" << join(kernel->m_context->input_names) << "\", "
<< kernel->m_context->outputs[i]->size(false) << ");\n";
lu << "CUDA_SAFE_CALL(cudaMemset((void*)fp32tensors, 0, "
<< max_tensor_size <<"));\n";
lu << "CUDA_SAFE_CALL(cudaMemset((void*)fp32tensors, 0, " << max_tensor_size
<< "));\n";
}
else if (element::get_backend_cstring(
kernel->m_context->outputs[i]->get_element_type()) == "float")
Expand Down
112 changes: 102 additions & 10 deletions src/nnfusion/engine/pass/graph/register_fusion_pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ using namespace nnfusion::kernels;

DEFINE_string(ftune_output_file, "", "the output json file path");
DEFINE_string(ftune_input_file, "", "the input json file path");
DEFINE_bool(fnofuse, false, "Disable element-wise fusion");
DEFINE_string(ffusion_skiplist, "", "List of op types that skips in fusion");
DECLARE_string(fdefault_device);

Expand Down Expand Up @@ -84,6 +85,14 @@ namespace
});
return nodes;
}

string ir_add_tag(const string& ir, const string& tag)
{
if (ir.find("## @:") != string::npos)
return ir + "|" + tag;
else
return ir + "## @: " + tag;
}
}

class RegisterFusionOptimizer
Expand Down Expand Up @@ -138,11 +147,13 @@ class RegisterFusionOptimizer
fuse_from_node(tnode, true);
}
}
inline_lightweighted_ops();
auto groups = extract_fusion_group();
for (auto group : groups)
{
insert_fuse_group(group);
}
if (!FLAGS_fnofuse)
for (auto group : groups)
{
insert_fuse_group(group);
}
auto nodes = nlohmann::json().array();
for (auto& node : find_topo_sort_priority(m_graph))
{
Expand All @@ -151,10 +162,12 @@ class RegisterFusionOptimizer
auto str = nnfusion::op::get_translation_v2(node);
if (skip_ops.count(node->get_op_type()))
{
if (str.find("## @:") != string::npos)
str += "|skip";
else
str += "## @: skip";
str = ir_add_tag(str, "skip");
}
if (node->get_op_type() == "Fused" &&
std::dynamic_pointer_cast<op::Fused>(node->get_op_ptr())->get_is_memcpy())
{
str = ir_add_tag(str, "memcpy");
}
auto edge = nlohmann::json().array();
for (auto& e : node->get_in_edges())
Expand All @@ -173,7 +186,7 @@ class RegisterFusionOptimizer
}

private:
vector<shared_ptr<FuseGroup>> extract_fusion_group()
vector<shared_ptr<FuseGroup>> extract_fusion_group() const
{
unordered_map<int, shared_ptr<FuseGroup>> groups;
vector<shared_ptr<FuseGroup>> result;
Expand All @@ -195,6 +208,85 @@ class RegisterFusionOptimizer
return result;
}

bool is_lightweighted_op(const shared_ptr<GNode>& node)
{
auto type = node->get_op_type();
if (type == "Slice" || type == "Broadcast")
return true;
if (type == "Reshape")
{
auto op = std::dynamic_pointer_cast<op::Reshape>(node->get_op_ptr());
auto order = op->get_input_order();
if (order.empty())
return true;

bool is_lower_dim_kept = order.back() == order.size() - 1;
return is_lower_dim_kept;
}
return false;
}

    // Merges groups made up entirely of lightweight ops (see
    // is_lightweighted_op) into the fusion group of their single external
    // consumer, so trivial data-movement ops are fused with the op that
    // consumes them rather than emitted as standalone kernels.
    void inline_lightweighted_ops()
    {
        // Iterate over all independent groups
        // inline first group into second if:
        // 1. first group has one output
        // 2. first group are all light weighted ops
        // 3. all ops not in skip lists
        unordered_map<int, shared_ptr<FuseGroup>> map;
        vector<shared_ptr<FuseGroup>> groups;
        // Partition all nodes: each ungrouped node (group_id_ < 0) becomes
        // its own singleton group; grouped nodes are bucketed by group id.
        // Tensor ops (constants/parameters) are never fusion candidates.
        for (auto& tnode : node_list_)
        {
            if (tnode->node_->get_op_ptr()->is_tensor_op())
                continue;
            if (tnode->group_id_ < 0)
            {
                auto f = make_shared<FuseGroup>();
                f->nodes.insert(tnode->node_);
                groups.push_back(f);
            }
            else
            {
                if (!map.count(tnode->group_id_))
                {
                    map[tnode->group_id_] = make_shared<FuseGroup>();
                }
                map[tnode->group_id_]->nodes.insert(tnode->node_);
            }
        }
        for (auto& kv : map)
            groups.push_back(kv.second);

        for (auto& group : groups)
        {
            bool group_is_lightweighted = true;
            unordered_set<shared_ptr<GNode>> group_outputs;
            // Collect consumers outside the group; the group is inlinable
            // only if every member op is lightweight.
            for (auto& node : group->nodes)
            {
                group_is_lightweighted &= is_lightweighted_op(node);
                for (auto& edge : node->get_out_edges())
                {
                    if (!group->nodes.count(edge->get_dst()))
                        group_outputs.insert(edge->get_dst());
                }
            }
            // A group with no external consumer (pure sink) is left as-is.
            if (group_outputs.size() == 0)
                continue;
            auto& output_node = *group_outputs.begin();
            auto& tag_output_node = node_map_[output_node];
            // Respect the user-provided fusion skip list: neither the
            // consumer nor any group member may be a skipped op type.
            bool op_skip = skip_ops.count(output_node->get_op_type());
            for (auto& node : group->nodes)
                op_skip |= skip_ops.count(node->get_op_type());
            if (group_is_lightweighted && !op_skip && group_outputs.size() == 1)
            {
                // Adopt the consumer's group id (minting a fresh id if the
                // consumer was ungrouped) and assign it to every node in
                // the lightweight producer group.
                if (tag_output_node->group_id_ < 0)
                    tag_output_node->group_id_ = cur_group_++;
                for (auto& node : group->nodes)
                    node_map_[node]->group_id_ = tag_output_node->group_id_;
            }
        }
    }

void insert_fuse_group(shared_ptr<FuseGroup> group)
{
// get a meaningful name
Expand Down Expand Up @@ -453,4 +545,4 @@ bool RegisterFusionPass::run_on_graph(std::shared_ptr<Graph>& graph)
applier.apply(FLAGS_ftune_input_file);
NNFUSION_LOG(INFO) << "RegisterFusionPass Done";
return true;
}
}
Loading