diff --git a/src/nnfusion/core/graph/gnode.cpp b/src/nnfusion/core/graph/gnode.cpp index ca6078711..f7b224041 100644 --- a/src/nnfusion/core/graph/gnode.cpp +++ b/src/nnfusion/core/graph/gnode.cpp @@ -416,6 +416,7 @@ void FusedGNode::set_inputs_and_outputs(std::shared_ptr graph) m_op_ctxs.push_back(ctx); } + std::unordered_map, std::unordered_map> input_id_map; // Register input tensors for (const auto& m_node : m_order_nodes) { @@ -430,6 +431,7 @@ void FusedGNode::set_inputs_and_outputs(std::shared_ptr graph) set_input(input_id, m_node->get_inputs().at(in_edge->get_dst_input())); graph->add_edge( in_edge->get_src(), in_edge->get_src_output(), shared_from_this(), input_id); + input_id_map[m_node][in_edge->get_dst_input()] = input_id; } } // Add control-edges as inputs of fused node @@ -461,6 +463,29 @@ void FusedGNode::set_inputs_and_outputs(std::shared_ptr graph) has_output = true; set_output(get_output_size(), m_node->get_outputs().at(out_edge->get_src_output())); + + // get inplace annotation + auto op = std::dynamic_pointer_cast(m_node->get_op_ptr()); + auto op_annotations = op->get_op_annotations(); + if (op_annotations) + { + auto oi_pairs = op_annotations->get_in_place_oi_pairs(); + for (auto oi_pair : oi_pairs) + { + auto iter = input_id_map.find(m_node); + if (iter != input_id_map.end() && iter->second.count(oi_pair.input) > 0) + { + auto fused_op = + std::dynamic_pointer_cast(shared_from_this()->get_op_ptr()); + AddInplace(fused_op, + get_output_size() - 1, + iter->second[oi_pair.input], + oi_pair.destructive, + oi_pair.force_inplace); + //NNFUSION_LOG(INFO) << "========================: node=" << m_node->get_op_type() << ", oi: <" << oi_pair.output << ", " << oi_pair.input << ">"; + } + } + } } graph->add_edge(shared_from_this(), get_output_size() - 1,