Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit

Permalink
add global outputs of graph to groups
Browse files Browse the repository at this point in the history
  • Loading branch information
BiynXu authored and 6clc committed May 5, 2023
1 parent 179cb6e commit 2fe484a
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 5 deletions.
2 changes: 1 addition & 1 deletion cinn/hlir/framework/op_lowering.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1326,7 +1326,7 @@ void OpLowerer::IRSchedule(ir::IRSchedule& ir_sch,
}

VLOG(3) << "Before Sync IRLowerOp schedule, ir is: \n" << ir_sch.GetModule().GetExprs().at(0);
SyncThreadWithShared(ir_sch, nodes_inline, nodes_set, this->shape_dict_, tensor_map);
SyncThreadWithShared(ir_sch, nodes_inline, nodes_set, this->shape_dict_, tensor_map, group);
VLOG(4) << "After IRSchedule, ir is: \n" << ir_sch.GetModule().GetExprs().at(0);
}

Expand Down
5 changes: 3 additions & 2 deletions cinn/hlir/framework/op_lowering_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1407,7 +1407,8 @@ void SyncThreadWithShared(ir::IRSchedule& ir_sch,
const std::unordered_set<Node*>& nodes_inline,
const std::unordered_set<Node*>& nodes_set,
const absl::flat_hash_map<std::string, shape_t>& shape_dict,
const std::unordered_map<std::string, ir::Tensor>& tensor_map) {
const std::unordered_map<std::string, ir::Tensor>& tensor_map,
const GroupPtr& group) {
auto exprs_inorder = ir_sch.GetAllBlocks();
auto node_data_set = GetNodeDataSet(nodes_set);
auto& op_pattern_dict = Operator::GetAttrs<OpPatternKind>("OpPattern");
Expand Down Expand Up @@ -1469,7 +1470,7 @@ void SyncThreadWithShared(ir::IRSchedule& ir_sch,
do_set_buffer_to_shared = true;
}
}
if (do_set_buffer_to_shared) {
if (do_set_buffer_to_shared && group->output_nodes.find(node) == group->output_nodes.end()) {
auto block = ir_sch.GetBlock(node_data->id());
ir_sch.SetBuffer(block, "shared", true);
}
Expand Down
3 changes: 2 additions & 1 deletion cinn/hlir/framework/op_lowering_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ void SyncThreadWithShared(ir::IRSchedule& ir_sch,
const std::unordered_set<Node*>& nodes_inline,
const std::unordered_set<Node*>& nodes_set,
const absl::flat_hash_map<std::string, shape_t>& shape_dict,
const std::unordered_map<std::string, ir::Tensor>& tensor_map);
const std::unordered_map<std::string, ir::Tensor>& tensor_map,
const GroupPtr& group);

} // namespace framework
} // namespace hlir
Expand Down
16 changes: 15 additions & 1 deletion cinn/hlir/pass/fusion_merge_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ using ConditionFunction = std::function<bool(const FusionHelperBase*, const Grou
// code generation.
class FusionMergePassHelper : public FusionHelperBase {
public:
FusionMergePassHelper(const Graph* graph) : FusionHelperBase(graph) {
FusionMergePassHelper(const Graph* graph) : FusionHelperBase(graph), graph_output_node_data_(graph->outputs) {
fusion_groups_ = graph->fusion_groups;
// init fusion relation.
InitFusionRelation();
Expand All @@ -56,6 +56,7 @@ class FusionMergePassHelper : public FusionHelperBase {
GroupList operator()() {
// run fusion merge until no update.
DoFusionMerge();
AddGlobalOutputNodesToGroups();
for (auto& group : fusion_groups_) {
VLOG(3) << "Fusion Group -> " << group->group_id;
for (auto& sub_group : group->fused_sub_groups) {
Expand All @@ -72,6 +73,18 @@ class FusionMergePassHelper : public FusionHelperBase {
}

private:
void AddGlobalOutputNodesToGroups() {
for (auto group : fusion_groups_) {
for (const auto& output_node_data : graph_output_node_data_) {
Node* node = output_node_data->source_node.get();
std::unordered_set<Node*> node_set = group->NodeSet();
if (node_set.find(node) != node_set.end()) {
group->output_nodes.insert(node);
}
}
}
}

void DoFusionMerge() {
VLOG(3) << "DoFusionMerge...!";
while (DoHorizontalFusion()) {
Expand Down Expand Up @@ -981,6 +994,7 @@ class FusionMergePassHelper : public FusionHelperBase {
}

GroupList fusion_groups_;
const std::vector<NodeData*>& graph_output_node_data_;
std::unordered_map<GroupPtr, int, Hasher, Comparator> fusion_groups_index_;
std::unordered_map<NodeData*, std::unordered_set<GroupPtr, Hasher, Comparator>> input_to_consumers_;

Expand Down

0 comments on commit 2fe484a

Please sign in to comment.