diff --git a/examples/text-generation/README.md b/examples/text-generation/README.md
index fad94bdbcb..a0949c8694 100755
--- a/examples/text-generation/README.md
+++ b/examples/text-generation/README.md
@@ -525,6 +525,27 @@ python ../gaudi_spawn.py --use_deepspeed --world_size 8 run_generation.py \
 For more details see [documentation](https://docs.habana.ai/en/latest/PyTorch/Model_Optimization_PyTorch/Optimization_in_PyTorch_Models.html#using-fused-sdpa).
 
+### Store KV Cache on CPU
+Keeping the key/value cache on the CPU (host) side reduces HPU VRAM usage, although it may increase generation latency. It is a practical option for long-context serving with a large LLM on a single card.
+
+Add the `--kv_cache_on_host` argument to enable it. The [PyTorch SDPA operator](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) is then used automatically to generate the next token, which saves data transfer time. First-token generation is not affected.
+
+For example:
+```bash
+python run_generation.py \
+--model_name_or_path meta-llama/Llama-2-7b-hf \
+--use_kv_cache \
+--bf16 \
+--attn_softmax_bf16 \
+--max_new_tokens 128 \
+--reuse_cache \
+--do_sample \
+--prompt "Here is my prompt" \
+--kv_cache_on_host
+```
+
+> [!NOTE]
+> `--kv_cache_on_host` currently only supports Llama models, and it cannot be combined with `--use_hpu_graphs` or the FP8 data type.
 
 ## Language Model Evaluation Harness
 
diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py
index 290b7e7273..738ce5b3f6 100755
--- a/examples/text-generation/run_generation.py
+++ b/examples/text-generation/run_generation.py
@@ -318,6 +318,11 @@ def setup_parser(parser):
             "`--kv_cache_on_host` is not supported with FP8 quantization. Set this flag to False."
         )
         args.kv_cache_on_host = False
+    if args.kv_cache_on_host and args.use_hpu_graphs:
+        logger.warning(
+            "`--kv_cache_on_host` is not supported with HPU graphs. Set this flag to False."
+        )
+        args.kv_cache_on_host = False
 
     return args
 
diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py
index 2bfad489c8..f5488a6ced 100755
--- a/optimum/habana/transformers/generation/utils.py
+++ b/optimum/habana/transformers/generation/utils.py
@@ -1014,7 +1014,7 @@ def generate(
                 bs, _ = input_ids.shape
                 cache_device = "cpu" if generation_config.kv_cache_on_host else "hpu"
                 if not is_greedy_or_beam_and_bucket:
-                    if self.config.model_type in ["llama"]:
+                    if generation_config.kv_cache_on_host and self.config.model_type in ["llama"]:
                         print("Allocate KV Cache on CPU...")
                         unwrap_deepspeed_model(self).allocate_kv_cache(
                             bs * generation_config.num_beams, calculated_max_length, token_idx + num_virtual_tokens,
diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py
index 509be77ff3..035bbbe7ca 100755
--- a/optimum/habana/transformers/models/llama/modeling_llama.py
+++ b/optimum/habana/transformers/models/llama/modeling_llama.py
@@ -246,9 +246,6 @@ def gaudi_llama_repeat_kv(
     The query states go from (batch, num_heads, seqlen, head_dim) to (batch, num_key_value_heads, n_rep, seqlen, head_dim)
     The key/value states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_key_value_heads, 1, seqlen, head_dim)
     """
-    query_states = query_states.to("hpu")
-    key_states = key_states.to("hpu")
-    value_states = value_states.to("hpu")
     batch, num_key_value_heads, kv_len, head_dim = key_states.shape
     if n_rep == 1 or num_key_value_heads == 1:
         return query_states, key_states, value_states, attention_mask
@@ -545,49 +542,53 @@ def pre_attn_forward(
         else:
            past_key_value = None
 
-        bool kv_cache_on_host = (key_states.device() == "cpu" and value_states.device() == "cpu")
-        if use_flash_attention and FusedSDPA and not kv_cache_on_host:
-            import habana_frameworks.torch.hpu as ht
-
-            softmax_mode = "fast" if flash_attention_fast_softmax else "None"
+        kv_cache_on_host = key_states.device.type == "cpu" and value_states.device.type == "cpu"
+        # CPU SDPA for the next token when the KV cache lives on the host
+        if kv_cache_on_host and q_len == 1 and not self.training:
+            query_states, key_states, value_states, attention_mask = gaudi_llama_repeat_kv_cpu(
+                query_states, key_states, value_states, attention_mask, self.num_key_value_groups
+            )
+            # PyTorch SDPA (https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
+            # dispatches to its flash attention implementation when possible
+            attn_output = F.scaled_dot_product_attention(
+                query_states, key_states, value_states,
+                attn_mask=attention_mask,
+                dropout_p=0.0,
+                is_causal=False,
+                scale=self.norm_factor,
+            )
+            attn_output = attn_output.to("hpu")
-            if q_len == 1:
-                # next token
-                use_recompute = True if os.getenv("QUANT_CONFIG", "") else False
-                with ht.sdp_kernel(enable_recompute=use_recompute):
-                    attn_output = self.fused_scaled_dot_product_attention(
-                        query_states, key_states, value_states, attention_mask, 0.0, False, None, "None"
-                    )
-            else:
-                # first token
-                if flash_attention_causal_mask:
-                    # causal masking on first token requires inputs to be of the same length
-                    with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
-                        attn_output = self.fused_scaled_dot_product_attention(
-                            query_states, key_states, value_states, None, 0.0, True, None, softmax_mode
-                        )
-                else:
-                    with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
+        else:
+            if kv_cache_on_host:
+                key_states = key_states.to("hpu")
+                value_states = value_states.to("hpu")
+            if use_flash_attention and FusedSDPA:
+                import habana_frameworks.torch.hpu as ht
+
+                softmax_mode = "fast" if flash_attention_fast_softmax else "None"
+
+                if q_len == 1:
+                    # next token
+                    use_recompute = True if os.getenv("QUANT_CONFIG", "") else False
+                    with ht.sdp_kernel(enable_recompute=use_recompute):
                         attn_output = self.fused_scaled_dot_product_attention(
                             query_states, key_states, value_states, attention_mask, 0.0, False, None, softmax_mode
                         )
+                else:
+                    # first token
+                    if flash_attention_causal_mask:
+                        # causal masking on first token requires inputs to be of the same length
+                        with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
+                            attn_output = self.fused_scaled_dot_product_attention(
+                                query_states, key_states, value_states, None, 0.0, True, None, softmax_mode
+                            )
+                    else:
+                        with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
+                            attn_output = self.fused_scaled_dot_product_attention(
+                                query_states, key_states, value_states, attention_mask, 0.0, False, None, softmax_mode
+                            )
-        else:
-            if q_len == 1 and kv_cache_on_host:
-                # CPU SDPA fot next token
-                query_states, key_states, value_states, attention_mask = gaudi_llama_repeat_kv_cpu(
-                    query_states, key_states, value_states, attention_mask, self.num_key_value_groups
-                )
-                # pytorch https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
-                # dispatch to flash attention implementation
-                attn_output = F.scaled_dot_product_attention(query_states,
-                                                             key_states,
-                                                             value_states,
-                                                             attn_mask=attention_mask,
-                                                             dropout_p=0.0,
-                                                             is_causal=False,
-                                                             scale=self.norm_factor)
-                attn_output = attn_output.to("hpu")
             else:
                 query_states, key_states, value_states, attention_mask = gaudi_llama_repeat_kv(
                     query_states, key_states, value_states, attention_mask, self.num_key_value_groups
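
For reference, below is a minimal, self-contained sketch of the host-side decode path this patch implements: the KV cache stays in CPU memory, the KV heads are repeated on CPU (standing in for `gaudi_llama_repeat_kv_cpu`), and `torch.nn.functional.scaled_dot_product_attention` runs on CPU so only the single-token output needs to move back to the device. The shapes, the additive mask, and the final `"hpu"` target are illustrative assumptions, not the patched module itself.

```python
# Illustrative sketch only (not part of the patch). Assumes PyTorch >= 2.1 for the `scale` argument.
import math

import torch
import torch.nn.functional as F

bsz, n_heads, n_kv_heads, head_dim, kv_len = 1, 32, 8, 128, 4096
n_rep = n_heads // n_kv_heads  # GQA: number of query heads sharing one KV head

query = torch.randn(bsz, n_heads, 1, head_dim)            # single new token (q_len == 1)
k_cache = torch.randn(bsz, n_kv_heads, kv_len, head_dim)  # KV cache kept in host (CPU) memory
v_cache = torch.randn(bsz, n_kv_heads, kv_len, head_dim)
attn_mask = torch.zeros(bsz, 1, 1, kv_len)                # additive mask, 0.0 = attend

# Repeat KV heads on CPU so they match the query heads (the role of gaudi_llama_repeat_kv_cpu).
key = k_cache.repeat_interleave(n_rep, dim=1)
value = v_cache.repeat_interleave(n_rep, dim=1)

# CPU SDPA for the decode step; on Gaudi only this small output would be copied back to the HPU.
attn_output = F.scaled_dot_product_attention(
    query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False, scale=1.0 / math.sqrt(head_dim)
)
print(attn_output.shape)  # torch.Size([1, 32, 1, 128])
# attn_output = attn_output.to("hpu")  # on a Gaudi device
```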