diff --git a/src/plugins/intel_gpu/include/intel_gpu/graph/network.hpp b/src/plugins/intel_gpu/include/intel_gpu/graph/network.hpp index ff63909d049a5a..eaffaad2281b01 100644 --- a/src/plugins/intel_gpu/include/intel_gpu/graph/network.hpp +++ b/src/plugins/intel_gpu/include/intel_gpu/graph/network.hpp @@ -193,6 +193,8 @@ struct network { const ov::intel_gpu::VariableStateInfo& get_variable_info(const std::string &variable_id) const; const ov::intel_gpu::VariablesMap& get_variables() const; const ov::intel_gpu::VariablesInfoMap& get_variables_info() const; + void set_reuse_variable_mem(bool reuse = false); + bool is_reuse_variable_mem() { return _reuse_variable_mem; } const ExecutionConfig& get_config() const { return _config; } @@ -216,6 +218,7 @@ struct network { bool _is_dynamic = false; bool _enable_profiling = false; bool _reset_arguments; + bool _reuse_variable_mem = false; std::unordered_map> _primitives; std::vector _in_out_shared_mem_types; diff --git a/src/plugins/intel_gpu/src/graph/network.cpp b/src/plugins/intel_gpu/src/graph/network.cpp index e6cbd0dc15728f..eef58068f1ab7e 100644 --- a/src/plugins/intel_gpu/src/graph/network.cpp +++ b/src/plugins/intel_gpu/src/graph/network.cpp @@ -1028,5 +1028,9 @@ void network::set_variables_state_info(const std::string& variable_id, _variables_state_info.at(variable_id).m_primitives.insert(p); } +void network::set_reuse_variable_mem(bool reuse) { + _reuse_variable_mem = reuse; +} + } // namespace cldnn diff --git a/src/plugins/intel_gpu/src/graph/primitive_inst.cpp b/src/plugins/intel_gpu/src/graph/primitive_inst.cpp index ce7803978a05f4..480de4803f2e5c 100644 --- a/src/plugins/intel_gpu/src/graph/primitive_inst.cpp +++ b/src/plugins/intel_gpu/src/graph/primitive_inst.cpp @@ -624,16 +624,24 @@ void primitive_inst::realloc_if_needed(bool prev_execution_skipped) { _max_output_layout_count[j] = 0; } } else { - _outputs[0] = variable.get_memory(); + GPU_DEBUG_TRACE_DETAIL + << id() << " : realloc_if_needed: can_be_optimized = false and memories are not being shared" + << std::endl; + if (!get_network().is_reuse_variable_mem()) { + GPU_DEBUG_TRACE_DETAIL << "Update output mem with new variable mem" << std::endl; + _outputs[0] = variable.get_memory(); + _max_output_layout_count[0] = variable.get_actual_mem_size() / dt_sizes_in_B[0]; - if (auto compressed_cache_variable = dynamic_cast(&variable)) { - _outputs[2] = compressed_cache_variable->get_compression_scale_state()->get_memory(); + if (auto compressed_cache_variable = dynamic_cast(&variable)) { + _outputs[2] = compressed_cache_variable->get_compression_scale_state()->get_memory(); - if (compressed_cache_variable->has_zp_state()) { - _outputs[3] = compressed_cache_variable->get_compression_zp_state()->get_memory(); + if (compressed_cache_variable->has_zp_state()) { + _outputs[3] = compressed_cache_variable->get_compression_zp_state()->get_memory(); + } } + } else { + GPU_DEBUG_TRACE_DETAIL << "Can reuse variable mem of prev request" << std::endl; } - GPU_DEBUG_TRACE_DETAIL << id() << " : realloc_if_needed: can_be_optimized = false and memories are not being shared" << std::endl; } } else { variable.set_layout(_impl_params->output_layouts[0]); diff --git a/src/plugins/intel_gpu/src/plugin/sync_infer_request.cpp b/src/plugins/intel_gpu/src/plugin/sync_infer_request.cpp index f87f9af5275722..676e37294c818d 100644 --- a/src/plugins/intel_gpu/src/plugin/sync_infer_request.cpp +++ b/src/plugins/intel_gpu/src/plugin/sync_infer_request.cpp @@ -295,13 +295,21 @@ void SyncInferRequest::enqueue() { std::move(events.begin(), events.end(), std::back_inserter(dependencies)); } + auto network = m_graph->get_network(); for (const auto& it : m_variables) { const auto& name = it.first; const auto& variable = it.second; + if (network->has_variable(name)) { + const auto& prev_var = network->get_variable(name); + if (prev_var.get_memory() == variable->get_memory()) { + network->set_reuse_variable_mem(true); + continue; + } + } + network->set_reuse_variable_mem(false); prepare_state(name, variable); } - auto network = m_graph->get_network(); network->set_shape_predictor(m_shape_predictor); m_internal_outputs.clear();