This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

reduce fuse reduce #1393

Open
wants to merge 14 commits into base: develop
5 changes: 3 additions & 2 deletions cinn/hlir/framework/op_lowering.cc
@@ -1225,8 +1225,9 @@ void OpLowerer::IRSchedule(ir::IRSchedule& ir_sch,
}
}

+   auto masters = GetMasters(node, nodes_inline, nodes_set);
    // node can be inline.
-   if (CanbeInline(node, consumers, reducer, nodes_in_order.front(), group, nodes_set, this->shape_dict_)) {
+   if (CanbeInline(node, consumers, reducer, masters, group, nodes_set, this->shape_dict_)) {
auto block = ir_sch.GetBlock(GetNodeData(node)->id());
ir::ComputeInlineChecker checker(ir_sch, block);
if (!checker.Check()) {
@@ -1326,7 +1327,7 @@ void OpLowerer::IRSchedule(ir::IRSchedule& ir_sch,
}

VLOG(3) << "Before Sync IRLowerOp schedule, ir is: \n" << ir_sch.GetModule().GetExprs().at(0);
- SyncThreadWithShared(ir_sch, nodes_inline, nodes_set, this->shape_dict_, tensor_map);
+ SyncThreadWithShared(ir_sch, nodes_inline, nodes_set, this->shape_dict_, tensor_map, group);
VLOG(4) << "After IRSchedule, ir is: \n" << ir_sch.GetModule().GetExprs().at(0);
}

74 changes: 41 additions & 33 deletions cinn/hlir/framework/op_lowering_util.cc
@@ -165,11 +165,10 @@ bool IsConstOp(const framework::Node* node) {
}

std::vector<int> GetInputShape(const Node* node, const absl::flat_hash_map<std::string, shape_t>& shape_dict) {
-   auto producers = GetProducers(node);
-   CHECK(producers.size());
+   auto input_data = GetInputNodeData(node);
+   CHECK(input_data.size());

-   auto producer_data = GetNodeData(producers.front());
-   return shape_dict.at(producer_data->id());
+   return shape_dict.at(input_data.front()->id());
}

std::vector<int> GetOutputShape(const Node* node, const absl::flat_hash_map<std::string, shape_t>& shape_dict) {
@@ -636,7 +635,7 @@ void LoopAssignReduceWithLast(ir::IRSchedule& ir_sch,
bool CanbeInline(Node* node,
const std::vector<Node*> consumers,
const Node* reducer,
-                 const Node* laster,
+                 const std::unordered_set<Node*> masters,
const GroupPtr& group,
const std::unordered_set<Node*>& nodes_set,
const absl::flat_hash_map<std::string, shape_t>& shape_dict) {
@@ -678,10 +677,14 @@ bool CanbeInline(Node* node,
return false;
} else {
auto node_shape = GetOutputShape(node, shape_dict);
-   auto last_shape = GetOutputShape(laster, shape_dict);
-   if (std::accumulate(node_shape.begin(), node_shape.end(), 1, std::multiplies<int>()) !=
-       std::accumulate(last_shape.begin(), last_shape.end(), 1, std::multiplies<int>())) {
-     return true;
+   auto node_size = std::accumulate(node_shape.begin(), node_shape.end(), 1, std::multiplies<int>());
+
+   for (auto master : masters) {
+     auto master_shape = GetOutputShape(master, shape_dict);
+     auto master_size = std::accumulate(master_shape.begin(), master_shape.end(), 1, std::multiplies<int>());
+     if (node_size != master_size) {
+       return true;
+     }
}

return false;
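
The loop above is the heart of the change to CanbeInline: a node becomes inlinable as soon as its flattened output size differs from that of any master, where the old code compared against a single `laster` node. A minimal standalone sketch of just this size test, with plain shape vectors standing in for CINN's Node and shape_dict machinery (FlatSize and CanInlineToy are hypothetical names for illustration, not CINN APIs):

#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

// Flattened element count of a shape.
static int FlatSize(const std::vector<int>& shape) {
  return std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
}

// Toy version of the size test in CanbeInline: the node is considered
// inlinable as soon as its flattened size differs from ANY master's.
bool CanInlineToy(const std::vector<int>& node_shape,
                  const std::vector<std::vector<int>>& master_shapes) {
  const int node_size = FlatSize(node_shape);
  for (const auto& master_shape : master_shapes) {
    if (FlatSize(master_shape) != node_size) return true;
  }
  return false;
}

int main() {
  // A single equal-sized master: the node stays materialized.
  std::cout << CanInlineToy({32, 32}, {{32, 32}}) << "\n";        // 0
  // A second, smaller master (e.g. a reduce output) now triggers inlining,
  // which the old single-node comparison could miss.
  std::cout << CanInlineToy({32, 32}, {{32, 32}, {32}}) << "\n";  // 1
}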
@@ -1313,7 +1316,7 @@ void LoopComputeAt(ir::IRSchedule& ir_sch,
auto& op_pattern_dict = Operator::GetAttrs<OpPatternKind>("OpPattern");
if (!group->output_nodes.count(node)) {
auto block = ir_sch.GetBlock(GetNodeData(node)->id());
ir_sch.SetBuffer(block, "local", true);
ir_sch.SetBuffer(block, "local");
}

if (op_pattern_dict[node->op()] == framework::kReduction) {
@@ -1370,11 +1373,14 @@ std::unordered_map<std::string, NodeData*> GetNodeDataSet(const std::unordered_s
return node_data_set;
}

- Node* GetMaster(Node* node, const std::unordered_set<Node*>& nodes_inline, const std::unordered_set<Node*>& nodes_set) {
+ std::unordered_set<Node*> GetMasters(Node* node,
+                                      const std::unordered_set<Node*>& nodes_inline,
+                                      const std::unordered_set<Node*>& nodes_set) {
// find consumer
std::unordered_set<Node*> visited;
std::queue<Node*> candidates;
candidates.push(node);
+ std::unordered_set<Node*> masters;

while (!candidates.empty()) {
auto candidate = candidates.front();
@@ -1389,19 +1395,20 @@ Node* GetMaster(Node* node, const std::unordered_set<Node*>& nodes_inline, const
candidates.push(consumer);
visited.insert(consumer);
} else {
-       return consumer;
+       masters.insert(consumer);
}
}
}

- return nullptr;
+ return masters;
}

void SyncThreadWithShared(ir::IRSchedule& ir_sch,
const std::unordered_set<Node*>& nodes_inline,
const std::unordered_set<Node*>& nodes_set,
const absl::flat_hash_map<std::string, shape_t>& shape_dict,
-                           const std::unordered_map<std::string, ir::Tensor>& tensor_map) {
+                           const std::unordered_map<std::string, ir::Tensor>& tensor_map,
+                           const GroupPtr& group) {
auto exprs_inorder = ir_sch.GetAllBlocks();
auto node_data_set = GetNodeDataSet(nodes_set);
auto& op_pattern_dict = Operator::GetAttrs<OpPatternKind>("OpPattern");
@@ -1438,34 +1445,35 @@ void SyncThreadWithShared(ir::IRSchedule& ir_sch,
auto node = node_data->source_node.get();
auto node_shape = shape_dict.at(node_data->id());

auto master = GetMaster(node, nodes_inline, nodes_set);
if (!master) {
auto masters = GetMasters(node, nodes_inline, nodes_set);
if (masters.empty()) {
continue;
}

-   auto master_data = GetNodeData(master);
-   auto master_shape = shape_dict.at(master_data->id());
-   if (op_pattern_dict[master->op()] == framework::kReduction) {
-     master_shape = shape_dict.at(master->inlinks_in_order()[0]->source()->id());
-   }
-
-   auto node_size = std::accumulate(node_shape.begin(), node_shape.end(), 1, std::multiplies<int>());
-   auto master_size = std::accumulate(master_shape.begin(), master_shape.end(), 1, std::multiplies<int>());
-
-   if (node_size == master_size) {
-     continue;
-   }
-
-   {
-     auto block = ir_sch.GetBlock(node_data->id());
-     ir_sch.SetBuffer(block, "shared", true);
-   }
-
-   if (check_sync_mark(idx, master_data->id())) {
-     auto loops = ir_sch.GetLoops(master_data->id());
-     ir_sch.SyncThreads(loops.back(), false);
-     sync_mark.insert(master_data->id());
-   }
+   bool do_set_buffer_to_shared = false;
+   for (auto master : masters) {
+     auto master_data = GetNodeData(master);
+     auto master_shape = shape_dict.at(master_data->id());
+     if (op_pattern_dict[master->op()] == framework::kReduction) {
+       master_shape = shape_dict.at(master->inlinks_in_order()[0]->source()->id());
+     }
+
+     auto node_size = std::accumulate(node_shape.begin(), node_shape.end(), 1, std::multiplies<int>());
+     auto master_size = std::accumulate(master_shape.begin(), master_shape.end(), 1, std::multiplies<int>());
+
+     if (node_size != master_size) {
+       if (check_sync_mark(idx, master_data->id())) {
+         auto loops = ir_sch.GetLoops(master_data->id());
+         ir_sch.SyncThreads(loops.back(), false);
+         sync_mark.insert(master_data->id());
+       }
+       do_set_buffer_to_shared = true;
+     }
+   }
+
+   if (do_set_buffer_to_shared && group->output_nodes.find(node) == group->output_nodes.end()) {
+     auto block = ir_sch.GetBlock(node_data->id());
+     ir_sch.SetBuffer(block, "shared", true);
+   }
  }
}

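
Taken together, GetMasters replaces the old single-result GetMaster with a breadth-first walk: consumers outside the group are skipped, inlined consumers are traversed through, and every materialized consumer reached is recorded as a master (the old function stopped at the first one). A self-contained sketch of the same traversal over a toy string-keyed graph; the adjacency-map representation and all names below are illustrative assumptions, not CINN's Node graph:

#include <queue>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

// Toy GetMasters: BFS through a node's consumers. Inlined consumers are
// walked through; every non-inlined consumer in the group is a master.
std::unordered_set<std::string> GetMastersToy(
    const std::string& node,
    const std::unordered_map<std::string, std::vector<std::string>>& consumers,
    const std::unordered_set<std::string>& nodes_inline,
    const std::unordered_set<std::string>& nodes_set) {
  std::unordered_set<std::string> visited = {node};
  std::unordered_set<std::string> masters;
  std::queue<std::string> candidates;
  candidates.push(node);
  while (!candidates.empty()) {
    const auto candidate = candidates.front();
    candidates.pop();
    const auto it = consumers.find(candidate);
    if (it == consumers.end()) continue;
    for (const auto& consumer : it->second) {
      // Skip consumers outside the group, and anything already seen.
      if (!nodes_set.count(consumer) || !visited.insert(consumer).second) continue;
      if (nodes_inline.count(consumer)) {
        candidates.push(consumer);  // inlined: keep searching behind it
      } else {
        masters.insert(consumer);   // materialized: this is a master
      }
    }
  }
  return masters;
}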
9 changes: 7 additions & 2 deletions cinn/hlir/framework/op_lowering_util.h
@@ -60,7 +60,7 @@ Node* FindNearestReducer(const Node* node, const std::unordered_set<Node*>& node
bool CanbeInline(Node* node,
const std::vector<Node*> consumers,
const Node* reducer,
-                 const Node* laster,
+                 const std::unordered_set<Node*> masters,
const GroupPtr& group,
const std::unordered_set<Node*>& nodes_set,
const absl::flat_hash_map<std::string, shape_t>& shape_dict);
@@ -72,6 +72,10 @@ Node* GetMasterToComputeAt(Node* node,
const std::unordered_map<Node*, Node*>& virtual_consumers,
const absl::flat_hash_map<std::string, shape_t>& shape_dict);

+ std::unordered_set<Node*> GetMasters(Node* node,
+                                      const std::unordered_set<Node*>& nodes_inline,
+                                      const std::unordered_set<Node*>& nodes_set);

void LoopAssignReduce(ir::IRSchedule& ir_sch,
const Node* node,
const Node* reducer,
@@ -90,7 +94,8 @@ void SyncThreadWithShared(ir::IRSchedule& ir_sch,
const std::unordered_set<Node*>& nodes_inline,
const std::unordered_set<Node*>& nodes_set,
const absl::flat_hash_map<std::string, shape_t>& shape_dict,
-                           const std::unordered_map<std::string, ir::Tensor>& tensor_map);
+                           const std::unordered_map<std::string, ir::Tensor>& tensor_map,
+                           const GroupPtr& group);

} // namespace framework
} // namespace hlir
11 changes: 11 additions & 0 deletions cinn/hlir/pass/fusion_helper_base.h
@@ -112,6 +112,17 @@ class FusionHelperBase {
return producer_node;
}

+ std::vector<Node*> GetConsumerNode(const Node* node) const {
+   std::vector<Node*> consumer_nodes;
+   auto node_data = GetNodeData(node);
+   for (auto& link : node_data->outlinks()) {
+     auto consumer = link->sink()->safe_as<Node>();
+     CHECK(consumer);
+     consumer_nodes.push_back(consumer);
+   }
+   return consumer_nodes;
+ }

bool WithoutLastDimInReduce(const std::vector<int>& inshape, const std::vector<int>& axes) const {
// if last axis is in reduce.
if (std::find(axes.begin(), axes.end(), inshape.size() - 1) != axes.end() ||
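
GetConsumerNode is the consumer-side counterpart of the producer helpers already in this class: it hops from an op node to its output NodeData and collects the sink of every outlink. A toy restatement of that Node -> NodeData -> consumer indirection (ToyNode and ToyNodeData are hypothetical stand-ins for CINN's graph types):

#include <string>
#include <vector>

// In CINN, an op node owns an output NodeData, and each outlink of that
// NodeData sinks into a consumer op node (link->sink()->safe_as<Node>()).
struct ToyNode;
struct ToyNodeData {
  std::vector<ToyNode*> outlinks;
};
struct ToyNode {
  std::string id;
  ToyNodeData output;
};

// Same shape as GetConsumerNode: every consumer reachable through the
// node's output data.
std::vector<ToyNode*> GetConsumerNodeToy(const ToyNode& node) {
  return node.output.outlinks;
}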
29 changes: 17 additions & 12 deletions cinn/hlir/pass/fusion_merge_pass.cc
@@ -617,10 +617,6 @@ class FusionMergePassHelper : public FusionHelperBase {

void RecomputeWithCostModel(const GroupPtr& producer,
std::unordered_set<GroupPtr, Hasher, Comparator>& fusionable_consumers) {
-   if (producer->op_pattern_kind == framework::kReduction) {
-     CHECK_EQ(fusionable_consumers.size(), 1) << "Find more than one consumer can fuse to " << producer->group_id;
-   }
-
// if is const op
if (is_const_group(this, producer)) {
std::unordered_set<GroupPtr, Hasher, Comparator> candidates;
@@ -818,14 +814,23 @@ class FusionMergePassHelper : public FusionHelperBase {
auto& consumers = input_consumers.second;
std::unordered_set<GroupPtr, Hasher, Comparator> updated_consumers;
for (auto& consumer : consumers) {
-       // if group is sub group
-       if (consumer->belong_groups.size()) {
-         // inset belong group to consumers.
-         for (auto& belong_group : consumer->belong_groups) {
-           updated_consumers.insert(belong_group);
-         }
-       } else {
-         updated_consumers.insert(consumer);
-       }
+       std::queue<GroupPtr> fused_groups;
+       fused_groups.push(consumer);
+       while (!fused_groups.empty()) {
+         auto& cur = fused_groups.front();
+         fused_groups.pop();
+         // if group is sub group
+         if (cur->belong_groups.empty()) {
+           updated_consumers.insert(cur);
+         } else {
+           for (auto& belong_group : cur->belong_groups) {
+             if (belong_group->group_id == cur->group_id) {
+               updated_consumers.insert(belong_group);
+             } else {
+               fused_groups.push(belong_group);
+             }
+           }
+         }
+       }
}
}
consumers = updated_consumers;
@@ -976,7 +981,7 @@ class FusionMergePassHelper : public FusionHelperBase {
relation.vertical_relation = {// reduce and elementwise can be horizontal/vertical relation.
{OpPatternKind::kElementWise, reduce_fuse_elementwise},
// reduce and broadcast op must be horizontal relation.
-       {OpPatternKind::kBroadcast, is_same_size},
+       {OpPatternKind::kBroadcast, reduce_fuse_broadcast},
// reduce and injective op must be horizontal relation.
{OpPatternKind::kInjective, horizontal_with_injective},
// reduce and reduce must be horizontal relation.
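
The loop this hunk replaces inserted a consumer's belong_groups directly, which could leave a stale intermediate group in the consumer set whenever that group had itself been fused into a larger one. The queue-based walk keeps climbing until it reaches groups that were never absorbed, or that own their own group_id. A sketch with plain string ids standing in for GroupPtr (the owner == cur comparison simplifies the diff's group_id check; all names are illustrative):

#include <queue>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

// Resolve a consumer to the top-most fused group(s) that now contain it.
// belong[g] lists the groups that absorbed g; an empty or missing entry
// means g was never fused away and is already top-level.
std::unordered_set<std::string> ResolveTopGroups(
    const std::string& consumer,
    const std::unordered_map<std::string, std::vector<std::string>>& belong) {
  std::unordered_set<std::string> resolved;
  std::queue<std::string> pending;
  pending.push(consumer);
  while (!pending.empty()) {
    const auto cur = pending.front();
    pending.pop();
    const auto it = belong.find(cur);
    if (it == belong.end() || it->second.empty()) {
      resolved.insert(cur);  // never absorbed: already a top-level group
      continue;
    }
    for (const auto& owner : it->second) {
      if (owner == cur) {
        resolved.insert(owner);  // a group owning its own id is a root
      } else {
        pending.push(owner);     // otherwise keep climbing
      }
    }
  }
  return resolved;
}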
2 changes: 1 addition & 1 deletion cinn/hlir/pass/fusion_merge_pass_test.cc
@@ -401,7 +401,7 @@ TEST(FusionMergePass, Reduce_Test_2) {

auto graph = std::make_shared<hlir::framework::Graph>(program, target);
hlir::framework::ApplyPass(graph.get(), "OpFusionPass");
-   CHECK_EQ(graph->fusion_groups.size(), 4);
+   CHECK_EQ(graph->fusion_groups.size(), 3);
hlir::framework::ApplyPass(graph.get(), "FusionMergePass");
CHECK_EQ(graph->fusion_groups.size(), 2);
}
106 changes: 102 additions & 4 deletions cinn/hlir/pass/fusion_merge_pass_util.h
@@ -285,11 +285,109 @@ CONDITION_FUNC(injective_horizontal_with_reduce) {
return elementwise_fuse_reduce(helper, first, second);
}

- CONDITION_FUNC(reduce_fuse_reduce) {
-   // check reduce horizontal with reduce.
-   if (!horizontal_relation(helper, first, second, framework::OpPatternKind::kReduction)) {
-     return false;
+ CONDITION_FUNC(reduce_fuse_broadcast) {
+   // if same shape with horizontal relation
+   if (is_same_size(helper, first, second)) {
+     return true;
}

+   // Every reducer in the first (producer) group must satisfy two kinds of
+   // conditions. The first kind constrains the reducer itself; the second kind
+   // constrains the relationship between the reducer and each of its consumers
+   // of type Broadcast: every such Broadcast consumer must restore the shape
+   // the data had before the reduce.
+   for (auto& node_in_master : first->master_nodes) {
+     if (helper->GetOpKind(node_in_master) != OpPatternKind::kReduction) {
+       continue;
+     }
+     Node* reducer = node_in_master;
+     // First type of condition.
+     // Get some reduce information.
+     auto reducer_input_shape = helper->GetNodeInputShape(reducer);
+     auto reducer_output_shape = helper->GetNodeDataShape(reducer);
+     auto reduce_axes = absl::get<std::vector<int>>(reducer->attrs.attr_store.at("dim"));
+     auto keep_dim = absl::get<bool>(reducer->attrs.attr_store.at("keep_dim"));
+     for (auto& axis : reduce_axes) {
+       if (axis == -1) {
+         axis = reducer_input_shape.size() - 1;
+       }
+     }
+     // Check that the reduce axes are contiguous.
+     int reduce_size = reducer_input_shape.back();
+     for (auto idx = reduce_axes.size() - 1; idx >= 1; --idx) {
+       if (reduce_axes[idx] != reduce_axes[idx - 1] + 1) {
+         return false;
+       }
+       reduce_size *= reducer_input_shape[idx - 1];
+     }
+     // Check that the reduce size does not exceed the hardware limit.
+     if (helper->target_ == common::DefaultNVGPUTarget() && reduce_size > helper->target_.max_num_threads()) {
+       return false;
+     }
+
+     // Second type of condition.
+     // Find direct or indirect consumers of type Broadcast in the second group.
+     auto find_broadcasters_in_descendants = [&](const Node* producer) -> std::unordered_set<const Node*> {
+       std::queue<const Node*> candidates;
+       std::unordered_set<const Node*> visited_set;
+       std::unordered_set<const Node*> broadcasters;
+       candidates.push(producer);
+
+       while (!candidates.empty()) {
+         auto candidate = candidates.front();
+         candidates.pop();
+
+         for (auto consumer : helper->GetConsumerNode(candidate)) {
+           if (helper->GetOpKind(consumer) == OpPatternKind::kBroadcast &&
+               second->NodeSet().find(consumer) != second->NodeSet().end()) {
+             broadcasters.insert(consumer);
+           } else if (!visited_set.count(consumer)) {
+             visited_set.insert(consumer);
+             candidates.push(consumer);
+           }
+         }
+       }
+
+       return broadcasters;
+     };
+
+     // Check that each broadcast node meets the conditions.
+     std::unordered_set<const Node*> broadcasters_in_consumers = find_broadcasters_in_descendants(reducer);
+     for (auto broadcaster : broadcasters_in_consumers) {
+       auto broadcaster_output_shape = absl::get<std::vector<int>>(broadcaster->attrs.attr_store.at("out_shape"));
+       auto broadcast_axes = absl::get<std::vector<int>>(broadcaster->attrs.attr_store.at("broadcast_axes"));
+       for (auto& axis : broadcast_axes) {
+         if (axis == -1) {
+           axis = broadcaster_output_shape.size() - 1;
+         }
+       }
+
+       if (reducer_input_shape != broadcaster_output_shape) {
+         return false;
+       }
+
+       if (keep_dim) {
+         continue;
+       } else {
+         // if reducer_output_shape == [1]
+         if (reducer_output_shape.size() == 1 && reducer_output_shape[0] == 1) {
+           continue;
+         }
+         // Check that the union of reduce_axes and broadcast_axes covers every
+         // axis of reducer_input_shape exactly once.
+         for (int idx = 0; idx < reducer_input_shape.size(); ++idx) {
+           if (!(std::find(broadcast_axes.begin(), broadcast_axes.end(), idx) == broadcast_axes.end()) ^
+               std::find(reduce_axes.begin(), reduce_axes.end(), idx) == reduce_axes.end()) {
+             return false;
+           }
+         }
+       }
+     }
+   }
+
+   return true;
+ }

+ CONDITION_FUNC(reduce_fuse_reduce) {
+   if (!limit_args(helper, first, second)) {
+     return false;
}
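
The XOR expression near the end of the keep_dim == false branch above is compact but easy to misread; because == binds tighter than ^, it rejects fusion unless every axis of the reducer's input is claimed by exactly one of reduce_axes and broadcast_axes. An equivalent standalone restatement of that condition, as a sketch assuming the axes have already been normalized to non-negative values:

#include <algorithm>
#include <vector>

// Fusion is only legal (keep_dim == false, output shape != [1]) when
// reduce_axes and broadcast_axes partition the reducer's input axes:
// each axis must appear in exactly one of the two sets.
bool AxesPartitionInput(int input_rank,
                        const std::vector<int>& reduce_axes,
                        const std::vector<int>& broadcast_axes) {
  const auto contains = [](const std::vector<int>& axes, int axis) {
    return std::find(axes.begin(), axes.end(), axis) != axes.end();
  };
  for (int idx = 0; idx < input_rank; ++idx) {
    const bool reduced = contains(reduce_axes, idx);
    const bool broadcasted = contains(broadcast_axes, idx);
    if (reduced == broadcasted) return false;  // both or neither: reject
  }
  return true;
}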
2 changes: 1 addition & 1 deletion cinn/hlir/pass/op_fusion_pass.cc
@@ -267,7 +267,7 @@ class OpFusionPassHelper : public FusionHelperBase {
// producer -> fusion
relation.fusion_op_kind = {
// horizontal or vertical relation(Reduce + Elementwise*), check without last dimension in reduce.
-       {framework::kElementWise, without_last_dimension_in_reduce},
+       {framework::kElementWise, is_same_size},
// must be horizontal relation, check with same output shape and without last dimension in reduce.
{framework::kBroadcast, reduce_fuse_broadcast},
// must be horizontal relation and with same reduce attr.