Skip to content

Commit

Permalink
[GPU] Reuse kv cache mem if it is not changed from previous infer (op…
Browse files Browse the repository at this point in the history
…envinotoolkit#28361)

### Details:
 - When the KV cache variable is reset, it allocates new memory.
- However, if the variable memory has not changed since the previous iteration,
we can reuse the previously allocated memory.

### Tickets:
 - *ticket-id*
  • Loading branch information
yeonbok authored Jan 10, 2025
1 parent f616896 commit a8dfb18
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 7 deletions.
3 changes: 3 additions & 0 deletions src/plugins/intel_gpu/include/intel_gpu/graph/network.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,8 @@ struct network {
const ov::intel_gpu::VariableStateInfo& get_variable_info(const std::string &variable_id) const;
const ov::intel_gpu::VariablesMap& get_variables() const;
const ov::intel_gpu::VariablesInfoMap& get_variables_info() const;
void set_reuse_variable_mem(bool reuse = false);
bool is_reuse_variable_mem() { return _reuse_variable_mem; }

const ExecutionConfig& get_config() const { return _config; }

Expand All @@ -216,6 +218,7 @@ struct network {
bool _is_dynamic = false;
bool _enable_profiling = false;
bool _reset_arguments;
bool _reuse_variable_mem = false;

std::unordered_map<primitive_id, std::shared_ptr<primitive_inst>> _primitives;
std::vector<shared_mem_type> _in_out_shared_mem_types;
Expand Down
4 changes: 4 additions & 0 deletions src/plugins/intel_gpu/src/graph/network.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1028,5 +1028,9 @@ void network::set_variables_state_info(const std::string& variable_id,
_variables_state_info.at(variable_id).m_primitives.insert(p);
}

// Record whether the variable (KV cache) memory from the previous inference
// request may be reused for this one. Read by primitive_inst during
// realloc_if_needed() to decide whether the output memory must be re-fetched
// from the variable state. Set per-request by SyncInferRequest::enqueue().
void network::set_reuse_variable_mem(bool reuse) {
_reuse_variable_mem = reuse;
}


} // namespace cldnn
20 changes: 14 additions & 6 deletions src/plugins/intel_gpu/src/graph/primitive_inst.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -624,16 +624,24 @@ void primitive_inst::realloc_if_needed(bool prev_execution_skipped) {
_max_output_layout_count[j] = 0;
}
} else {
_outputs[0] = variable.get_memory();
GPU_DEBUG_TRACE_DETAIL
<< id() << " : realloc_if_needed: can_be_optimized = false and memories are not being shared"
<< std::endl;
if (!get_network().is_reuse_variable_mem()) {
GPU_DEBUG_TRACE_DETAIL << "Update output mem with new variable mem" << std::endl;
_outputs[0] = variable.get_memory();
_max_output_layout_count[0] = variable.get_actual_mem_size() / dt_sizes_in_B[0];

if (auto compressed_cache_variable = dynamic_cast<ov::intel_gpu::VariableStateIndirectKVCacheCompressed*>(&variable)) {
_outputs[2] = compressed_cache_variable->get_compression_scale_state()->get_memory();
if (auto compressed_cache_variable = dynamic_cast<ov::intel_gpu::VariableStateIndirectKVCacheCompressed*>(&variable)) {
_outputs[2] = compressed_cache_variable->get_compression_scale_state()->get_memory();

if (compressed_cache_variable->has_zp_state()) {
_outputs[3] = compressed_cache_variable->get_compression_zp_state()->get_memory();
if (compressed_cache_variable->has_zp_state()) {
_outputs[3] = compressed_cache_variable->get_compression_zp_state()->get_memory();
}
}
} else {
GPU_DEBUG_TRACE_DETAIL << "Can reuse variable mem of prev request" << std::endl;
}
GPU_DEBUG_TRACE_DETAIL << id() << " : realloc_if_needed: can_be_optimized = false and memories are not being shared" << std::endl;
}
} else {
variable.set_layout(_impl_params->output_layouts[0]);
Expand Down
10 changes: 9 additions & 1 deletion src/plugins/intel_gpu/src/plugin/sync_infer_request.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -295,13 +295,21 @@ void SyncInferRequest::enqueue() {
std::move(events.begin(), events.end(), std::back_inserter(dependencies));
}

auto network = m_graph->get_network();
for (const auto& it : m_variables) {
const auto& name = it.first;
const auto& variable = it.second;
if (network->has_variable(name)) {
const auto& prev_var = network->get_variable(name);
if (prev_var.get_memory() == variable->get_memory()) {
network->set_reuse_variable_mem(true);
continue;
}
}
network->set_reuse_variable_mem(false);
prepare_state(name, variable);
}

auto network = m_graph->get_network();
network->set_shape_predictor(m_shape_predictor);

m_internal_outputs.clear();
Expand Down

0 comments on commit a8dfb18

Please sign in to comment.