From aee4795704cc02eeb0d813c198768f5a579b022d Mon Sep 17 00:00:00 2001
From: Yu Zhentao
Date: Thu, 1 Aug 2024 02:53:39 +0000
Subject: [PATCH 1/4] cpu_kv and cpu_sdpa on llama

Signed-off-by: Yu Zhentao
---
 examples/text-generation/run_generation.py  |  10 ++
 examples/text-generation/utils.py           |   1 +
 .../generation/configuration_utils.py       |   3 +
 .../habana/transformers/generation/utils.py |  16 ++-
 .../models/llama/modeling_llama.py          | 112 +++++++++++++-----
 5 files changed, 112 insertions(+), 30 deletions(-)

diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py
index 0a16543c2a..59ee692d05 100755
--- a/examples/text-generation/run_generation.py
+++ b/examples/text-generation/run_generation.py
@@ -231,6 +231,11 @@ def setup_parser(parser):
         action="store_true",
         help="Whether to reuse key/value cache for decoding. It should save memory.",
     )
+    parser.add_argument(
+        "--kv_cache_on_host",
+        action="store_true",
+        help="Store the key/value cache on the CPU (host) instead of the HPU device (currently only supported for Llama). This saves HPU memory in long-context scenarios.",
+    )
     parser.add_argument("--verbose_workers", action="store_true", help="Enable output from non-master workers")
     parser.add_argument(
         "--simulate_dyn_prompt",
@@ -328,6 +333,11 @@ def setup_parser(parser):
         logger.warning(
             "`--disk_offload` was tested only with fp8, it may not work with full precision. If error raises try to remove the --disk_offload flag."
         )
+    if args.quant_config != "" and args.kv_cache_on_host:
+        logger.warning(
+            "`--kv_cache_on_host` is not supported with FP8 quantization. Disabling it."
+        )
+        args.kv_cache_on_host = False
 
     return args
 
diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py
index 5898b26671..51f91798eb 100644
--- a/examples/text-generation/utils.py
+++ b/examples/text-generation/utils.py
@@ -573,6 +573,7 @@ def setup_generation_config(args, model, assistant_model, tokenizer):
     generation_config.flash_attention_causal_mask = args.flash_attention_causal_mask
     generation_config.flash_attention_fast_softmax = args.flash_attention_fast_softmax
     generation_config.trust_remote_code = args.trust_remote_code
+    generation_config.kv_cache_on_host = args.kv_cache_on_host
     return generation_config
 
 
diff --git a/optimum/habana/transformers/generation/configuration_utils.py b/optimum/habana/transformers/generation/configuration_utils.py
index ce38a07ed9..852b3800db 100644
--- a/optimum/habana/transformers/generation/configuration_utils.py
+++ b/optimum/habana/transformers/generation/configuration_utils.py
@@ -37,6 +37,8 @@ class GaudiGenerationConfig(GenerationConfig):
             Whether to enable causal_mask if use Habana flash attention.
         flash_attention_fast_softmax_mode (`bool`, *optional*):
             Whether to use fast softmax with reduced precision if use Habana flash attention.
+        kv_cache_on_host (`bool`, *optional*):
+            Whether to store the key/value cache on the host (CPU) instead of the HPU device.
""" def __init__(self, **kwargs): @@ -55,3 +57,4 @@ def __init__(self, **kwargs): self.flash_attention_causal_mask = kwargs.get("flash_attention_causal_mask", None) self.flash_attention_fast_softmax = kwargs.get("flash_attention_fast_softmax", None) self.use_fused_rope = kwargs.get("use_fused_rope", None) + self.kv_cache_on_host = kwargs.get("kv_cache_on_host", False) diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index d4baf44c06..41040234d2 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -905,6 +905,10 @@ def generate( ), "please set bucket_internal along with reuse_cache and bucket_size" else: assert generation_config.bucket_size >= 0, "please set valid bucket_size to use bucket_internal" + if generation_config.kv_cache_on_host: + assert self.config.model_type in [ + "llama", + ], "kv_cache_on_host only supported by llama at the moment" if generation_config.static_shapes: # Pad inputs to have static shapes during generation, this gives better performance than dynamic shapes on HPUs @@ -1089,10 +1093,18 @@ def generate( calculated_max_length = input_ids.shape[-1] + generation_config.max_new_tokens + num_virtual_tokens if generation_config.use_cache and generation_config.reuse_cache: bs, _ = input_ids.shape + cache_device = "cpu" if generation_config.kv_cache_on_host else "hpu" if not is_greedy_or_beam_and_bucket: - unwrap_deepspeed_model(self).allocate_kv_cache( - bs * generation_config.num_beams, calculated_max_length, token_idx + num_virtual_tokens + if self.config.model_type in ["llama"]: + print("Allocate KV Cache on CPU...") + unwrap_deepspeed_model(self).allocate_kv_cache( + bs * generation_config.num_beams, calculated_max_length, token_idx + num_virtual_tokens, + device=cache_device ) + else: + unwrap_deepspeed_model(self).allocate_kv_cache( + bs * generation_config.num_beams, calculated_max_length, token_idx + num_virtual_tokens + ) if generation_config.use_cache: model_kwargs["kv_cache_len"] = calculated_max_length model_kwargs["kv_cache_pad_len"] = generation_config.max_new_tokens diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index 1abbfab12d..56aa08fe62 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -326,6 +326,9 @@ def gaudi_llama_repeat_kv( The query states go from (batch, num_heads, seqlen, head_dim) to (batch, num_key_value_heads, n_rep, seqlen, head_dim) The key/value states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_key_value_heads, 1, seqlen, head_dim) """ + query_states = query_states.to("hpu") + key_states = key_states.to("hpu") + value_states = value_states.to("hpu") batch, num_key_value_heads, kv_len, head_dim = key_states.shape if n_rep == 1 or num_key_value_heads == 1: return query_states, key_states, value_states, attention_mask @@ -344,6 +347,41 @@ def gaudi_llama_repeat_kv( return query_states, key_states, value_states, attention_mask +def gaudi_llama_repeat_kv_cpu( + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + attention_mask: torch.Tensor, + n_rep: int, +): + """ + PyTorch SDPA CPU (flash-atten) kernel does not support GQA/MQA for now. 
+ So, expand k and v to num_query_heads + """ + query_states = query_states.to("cpu") + key_states = key_states.to("cpu") + value_states = value_states.to("cpu") + if attention_mask is not None: + attention_mask = attention_mask.to("cpu") + + batch, num_key_value_heads, kv_len, head_dim = key_states.shape + if n_rep == 1 or num_key_value_heads == 1: + return query_states, key_states, value_states, attention_mask + + key_states = key_states[:, :, None, :, :].expand(batch, + num_key_value_heads, + n_rep, + kv_len, + head_dim) + value_states = value_states[:, :, None, :, :].expand(batch, + num_key_value_heads, + n_rep, + kv_len, + head_dim) + key_states = key_states.reshape(batch, num_key_value_heads * n_rep, kv_len, head_dim) + value_states = value_states.reshape(batch, num_key_value_heads * n_rep, kv_len, head_dim) + + return query_states, key_states, value_states, attention_mask # FusedScaledDotProductAttention class ModuleFusedSDPA(torch.nn.Module): @@ -381,6 +419,8 @@ def allocate(self, inp_seq_len, dtype, device, shape): @staticmethod def update(prev, cur, dim, idx, inp_seq_len): + cur = cur.to(prev.device) + idx = idx.to(prev.device) orig_cur = cur if prev.shape == cur.shape: prev.copy_(cur) @@ -442,9 +482,8 @@ def get_k_proj_weight_dtype(self): return self.k_proj.scales.dtype return self.k_proj.weight.dtype - def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len, device="hpu"): cache_shape = (batch_size, self.num_key_value_heads, max_seq_len, self.head_dim) - device = self.get_k_proj_weight().device dtype = self.config.torch_dtype self.k_cache.allocate(inp_seq_len, dtype, device, cache_shape) self.v_cache.allocate(inp_seq_len, dtype, device, cache_shape) @@ -617,7 +656,8 @@ def pre_attn_forward( else: past_key_value = None - if use_flash_attention and FusedSDPA is not None: + bool kv_cache_on_host = (key_states.device() == "cpu" and value_states.device() == "cpu") + if use_flash_attention and FusedSDPA is not None and not kv_cache_on_host: import habana_frameworks.torch.hpu as ht softmax_mode = "fast" if flash_attention_fast_softmax else "None" @@ -644,28 +684,44 @@ def pre_attn_forward( ) else: - query_states, key_states, value_states, attention_mask = gaudi_llama_repeat_kv( - query_states, key_states, value_states, attention_mask, self.num_key_value_groups - ) + if q_len == 1 and kv_cache_on_host: + # CPU SDPA fot next token + query_states, key_states, value_states, attention_mask = gaudi_llama_repeat_kv_cpu( + query_states, key_states, value_states, attention_mask, self.num_key_value_groups + ) + # pytorch https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html + # dispatch to flash attention implementation + attn_output = F.scaled_dot_product_attention(query_states, + key_states, + value_states, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + scale=self.norm_factor) + attn_output = attn_output.to("hpu") + else: + query_states, key_states, value_states, attention_mask = gaudi_llama_repeat_kv( + query_states, key_states, value_states, attention_mask, self.num_key_value_groups + ) - attn_weights = self.matmul_qk(query_states, key_states.transpose(-2, -1)) * self.norm_factor + attn_weights = self.matmul_qk(query_states, key_states.transpose(-2, -1)) * self.norm_factor - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask - if cache_position is not None: - causal_mask = attention_mask[:, :, 
cache_position, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask + if cache_position is not None: + causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask - if attn_softmax_bf16: - attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=query_states.dtype) - else: - # upcast attention to fp32 - attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( - query_states.dtype - ) - attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = self.matmul_av(attn_weights, value_states) - attn_output = attn_output.reshape(bsz, -1, q_len, self.head_dim) + if attn_softmax_bf16: + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=query_states.dtype) + else: + # upcast attention to fp32 + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( + query_states.dtype + ) + attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = self.matmul_av(attn_weights, value_states) + attn_output = attn_output.reshape(bsz, -1, q_len, self.head_dim) if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): raise ValueError( @@ -811,8 +867,8 @@ def __init__(self, config: LlamaConfig, layer_idx: int): self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): - self.self_attn.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len, device="hpu"): + self.self_attn.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len, device=device) def reorder_kv_cache(self, beam_idx: torch.LongTensor): return self.self_attn.reorder_kv_cache(beam_idx) @@ -988,9 +1044,9 @@ def __init__(self, config: LlamaConfig): # Initialize weights and apply final processing self.post_init() - def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len, device="hpu"): for layer in self.layers: - layer.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) + layer.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len, device=device) def reorder_kv_cache(self, beam_idx: torch.LongTensor): return tuple(layer.reorder_kv_cache(beam_idx) for layer in self.layers) @@ -1223,8 +1279,8 @@ def __init__(self, config, parallel_strategy: DistributedStrategy = NoOpStrategy config.parallel_strategy = parallel_strategy super().__init__(config) - def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): - self.model.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len, device="hpu"): + self.model.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len, device=device) def reorder_kv_cache(self, beam_idx: torch.LongTensor): return self.model.reorder_kv_cache(beam_idx) From 1b4ee20fdf591976746cf617461d2a63946321b2 Mon Sep 17 00:00:00 2001 From: Yu Zhentao Date: Thu, 1 Aug 2024 08:38:52 +0000 Subject: [PATCH 2/4] refact code and add README Signed-off-by: Yu Zhentao --- examples/text-generation/README.md | 21 +++++ 
 examples/text-generation/run_generation.py  |  5 ++
 .../habana/transformers/generation/utils.py |  2 +-
 .../models/llama/modeling_llama.py          | 83 ++++++++++---------
 4 files changed, 69 insertions(+), 42 deletions(-)

diff --git a/examples/text-generation/README.md b/examples/text-generation/README.md
index b720936ff4..360d1fe07a 100755
--- a/examples/text-generation/README.md
+++ b/examples/text-generation/README.md
@@ -539,6 +539,27 @@ python ../gaudi_spawn.py --use_deepspeed --world_size 8 run_generation.py \
 
 For more details see [documentation](https://docs.habana.ai/en/latest/PyTorch/Model_Optimization_PyTorch/Optimization_in_PyTorch_Models.html#using-fused-sdpa).
 
+### Store KV Cache on CPU
+Keeping the key/value cache on the CPU (host) side reduces HPU memory usage at the cost of some generation latency. It is a practical option for long-context serving with a large LLM on a single card.
+
+Add the `--kv_cache_on_host` argument to enable it. The [PyTorch SDPA operator](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) is then used automatically for next-token generation to save data-transfer time. First-token (prefill) generation is not affected.
+
+For example:
+```bash
+python run_generation.py \
+--model_name_or_path meta-llama/Llama-2-7b-hf \
+--use_kv_cache \
+--bf16 \
+--attn_softmax_bf16 \
+--max_new_tokens 128 \
+--reuse_cache \
+--do_sample \
+--prompt "Here is my prompt" \
+--kv_cache_on_host
+```
+
+> [!NOTE]
+> `--kv_cache_on_host` only supports Llama models for now, and it does not work with `--use_hpu_graphs` or the FP8 data type.
 
 ## Language Model Evaluation Harness
 
diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py
index 59ee692d05..713bef08a5 100755
--- a/examples/text-generation/run_generation.py
+++ b/examples/text-generation/run_generation.py
@@ -338,6 +338,11 @@ def setup_parser(parser):
             "`--kv_cache_on_host` is not supported with FP8 quantization. Disabling it."
         )
         args.kv_cache_on_host = False
+    if args.kv_cache_on_host and args.use_hpu_graphs:
+        logger.warning(
+            "`--kv_cache_on_host` is not supported with HPU graphs. Disabling it."
+ ) + args.kv_cache_on_host = False return args diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index 41040234d2..6f1dcdd886 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -1095,7 +1095,7 @@ def generate( bs, _ = input_ids.shape cache_device = "cpu" if generation_config.kv_cache_on_host else "hpu" if not is_greedy_or_beam_and_bucket: - if self.config.model_type in ["llama"]: + if generation_config.kv_cache_on_host and self.config.model_type in ["llama"]: print("Allocate KV Cache on CPU...") unwrap_deepspeed_model(self).allocate_kv_cache( bs * generation_config.num_beams, calculated_max_length, token_idx + num_virtual_tokens, diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index 56aa08fe62..a57b2b04fb 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -326,9 +326,6 @@ def gaudi_llama_repeat_kv( The query states go from (batch, num_heads, seqlen, head_dim) to (batch, num_key_value_heads, n_rep, seqlen, head_dim) The key/value states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_key_value_heads, 1, seqlen, head_dim) """ - query_states = query_states.to("hpu") - key_states = key_states.to("hpu") - value_states = value_states.to("hpu") batch, num_key_value_heads, kv_len, head_dim = key_states.shape if n_rep == 1 or num_key_value_heads == 1: return query_states, key_states, value_states, attention_mask @@ -656,49 +653,53 @@ def pre_attn_forward( else: past_key_value = None - bool kv_cache_on_host = (key_states.device() == "cpu" and value_states.device() == "cpu") - if use_flash_attention and FusedSDPA is not None and not kv_cache_on_host: - import habana_frameworks.torch.hpu as ht - - softmax_mode = "fast" if flash_attention_fast_softmax else "None" + kv_cache_on_host = (key_states.device == "cpu" and value_states.device == "cpu") + # CPU SDPA fot next token + if kv_cache_on_host and q_len == 1 and not self.training: + query_states, key_states, value_states, attention_mask = gaudi_llama_repeat_kv_cpu( + query_states, key_states, value_states, attention_mask, self.num_key_value_groups + ) + # pytorch https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html + # dispatch to flash attention implementation + attn_output = F.scaled_dot_product_attention(query_states, + key_states, + value_states, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + scale=self.norm_factor) + attn_output = attn_output.to("hpu") - if q_len == 1: - # next token - use_recompute = True if os.getenv("QUANT_CONFIG", "") else False - with ht.sdp_kernel(enable_recompute=use_recompute): - attn_output = self.fused_scaled_dot_product_attention( - query_states, key_states, value_states, attention_mask, 0.0, False, None, "None" - ) - else: - # first token - if flash_attention_causal_mask: - # causal masking on first token requires inputs to be of the same length - with ht.sdp_kernel(enable_recompute=flash_attention_recompute): - attn_output = self.fused_scaled_dot_product_attention( - query_states, key_states, value_states, None, 0.0, True, None, softmax_mode - ) - else: - with ht.sdp_kernel(enable_recompute=flash_attention_recompute): + else: + if kv_cache_on_host: + key_states = key_states.to("hpu") + value_states = value_states.to("hpu") + if 
use_flash_attention and FusedSDPA is not None: + import habana_frameworks.torch.hpu as ht + + softmax_mode = "fast" if flash_attention_fast_softmax else "None" + + if q_len == 1: + # next token + use_recompute = True if os.getenv("QUANT_CONFIG", "") else False + with ht.sdp_kernel(enable_recompute=use_recompute): attn_output = self.fused_scaled_dot_product_attention( query_states, key_states, value_states, attention_mask, 0.0, False, None, softmax_mode ) + else: + # first token + if flash_attention_causal_mask: + # causal masking on first token requires inputs to be of the same length + with ht.sdp_kernel(enable_recompute=flash_attention_recompute): + attn_output = self.fused_scaled_dot_product_attention( + query_states, key_states, value_states, None, 0.0, True, None, softmax_mode + ) + else: + with ht.sdp_kernel(enable_recompute=flash_attention_recompute): + attn_output = self.fused_scaled_dot_product_attention( + query_states, key_states, value_states, attention_mask, 0.0, False, None, softmax_mode + ) - else: - if q_len == 1 and kv_cache_on_host: - # CPU SDPA fot next token - query_states, key_states, value_states, attention_mask = gaudi_llama_repeat_kv_cpu( - query_states, key_states, value_states, attention_mask, self.num_key_value_groups - ) - # pytorch https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html - # dispatch to flash attention implementation - attn_output = F.scaled_dot_product_attention(query_states, - key_states, - value_states, - attn_mask=attention_mask, - dropout_p=0.0, - is_causal=False, - scale=self.norm_factor) - attn_output = attn_output.to("hpu") else: query_states, key_states, value_states, attention_mask = gaudi_llama_repeat_kv( query_states, key_states, value_states, attention_mask, self.num_key_value_groups From fd29d4e65b26784801174bbc3bcf6b3562721a83 Mon Sep 17 00:00:00 2001 From: Yu Zhentao Date: Thu, 12 Sep 2024 06:14:16 +0000 Subject: [PATCH 3/4] fix kv_cache_on_host if statement and add non_blocking copy Signed-off-by: Yu Zhentao --- .../habana/transformers/models/llama/modeling_llama.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index a57b2b04fb..78871174d7 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -653,7 +653,7 @@ def pre_attn_forward( else: past_key_value = None - kv_cache_on_host = (key_states.device == "cpu" and value_states.device == "cpu") + kv_cache_on_host = (key_states.device == torch.device("cpu") and value_states.device == torch.device("cpu")) # CPU SDPA fot next token if kv_cache_on_host and q_len == 1 and not self.training: query_states, key_states, value_states, attention_mask = gaudi_llama_repeat_kv_cpu( @@ -668,12 +668,12 @@ def pre_attn_forward( dropout_p=0.0, is_causal=False, scale=self.norm_factor) - attn_output = attn_output.to("hpu") + attn_output = attn_output.to("hpu", non_blocking=True) else: if kv_cache_on_host: - key_states = key_states.to("hpu") - value_states = value_states.to("hpu") + key_states = key_states.to("hpu", non_blocking=True) + value_states = value_states.to("hpu", non_blocking=True) if use_flash_attention and FusedSDPA is not None: import habana_frameworks.torch.hpu as ht From 74e94ff520fefcf0b5109bf9a7d5acfec6f1347d Mon Sep 17 00:00:00 2001 From: Yu Zhentao Date: Wed, 18 Sep 2024 08:10:31 +0000 Subject: [PATCH 4/4] add 
long-context example in README

Signed-off-by: Yu Zhentao
---
 examples/text-generation/README.md              | 15 +++++++++++----
 optimum/habana/transformers/generation/utils.py |  2 +-
 2 files changed, 12 insertions(+), 5 deletions(-)

diff --git a/examples/text-generation/README.md b/examples/text-generation/README.md
index 360d1fe07a..e593b42801 100755
--- a/examples/text-generation/README.md
+++ b/examples/text-generation/README.md
@@ -547,19 +547,26 @@ Add the `--kv_cache_on_host` argument to enable it. The [PyTorch SDPA operator](https
 For example:
 ```bash
 python run_generation.py \
---model_name_or_path meta-llama/Llama-2-7b-hf \
+--model_name_or_path 01-ai/Yi-34B-Chat \
 --use_kv_cache \
 --bf16 \
 --attn_softmax_bf16 \
---max_new_tokens 128 \
 --reuse_cache \
 --do_sample \
---prompt "Here is my prompt" \
+--dataset_name emozilla/pg19-test \
+--batch_size 1 \
+--max_input_tokens 11200 \
+--column_name "text" \
+--dataset_max_samples 1 \
+--warmup 0 \
+--n_iterations 1 \
+--max_new_tokens 5000 \
 --kv_cache_on_host
 ```
 
 > [!NOTE]
-> `--kv_cache_on_host` only supports Llama models for now, and it does not work with `--use_hpu_graphs` or the FP8 data type.
+> 1. `--kv_cache_on_host` only supports Llama models for now, and it does not work with `--use_hpu_graphs` or the FP8 data type.
+> 2. Use it only when you would otherwise hit an HPU workspace allocation error (`OOM`), since it increases latency.
 
 ## Language Model Evaluation Harness
 
diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py
index 6f1dcdd886..ee26d54dbc 100755
--- a/optimum/habana/transformers/generation/utils.py
+++ b/optimum/habana/transformers/generation/utils.py
@@ -1093,10 +1093,10 @@ def generate(
         calculated_max_length = input_ids.shape[-1] + generation_config.max_new_tokens + num_virtual_tokens
         if generation_config.use_cache and generation_config.reuse_cache:
             bs, _ = input_ids.shape
-            cache_device = "cpu" if generation_config.kv_cache_on_host else "hpu"
             if not is_greedy_or_beam_and_bucket:
                 if generation_config.kv_cache_on_host and self.config.model_type in ["llama"]:
                     print("Allocate KV Cache on CPU...")
+                    cache_device = "cpu"
                     unwrap_deepspeed_model(self).allocate_kv_cache(
                         bs * generation_config.num_beams, calculated_max_length, token_idx + num_virtual_tokens,
                         device=cache_device
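
For reference, the host-side decode path implemented in patches 1-3 boils down to: keep the K/V cache on the CPU, expand the grouped K/V heads to the query head count (the CPU SDPA flash kernel does not handle GQA/MQA directly), run `torch.nn.functional.scaled_dot_product_attention` on the CPU for the single new token, and copy only the small attention output back to the device. The sketch below is illustrative only and is not part of the patches: the helper names and shapes are invented for the example, the target device defaults to `"cpu"` so it runs anywhere, and the scaling is assumed to be the usual `1/sqrt(head_dim)` norm factor.

```python
# Illustrative sketch only (not part of the patches above). On Gaudi, `device`
# would be "hpu" and the inputs would come from the attention forward pass.
import torch
import torch.nn.functional as F


def expand_kv_for_cpu_sdpa(key: torch.Tensor, value: torch.Tensor, n_rep: int):
    """Expand grouped K/V heads to the query head count, mirroring
    gaudi_llama_repeat_kv_cpu: the CPU SDPA flash kernel does not handle
    GQA/MQA, so K/V are materialized per query head."""
    if n_rep == 1:
        return key, value
    bsz, n_kv_heads, kv_len, head_dim = key.shape
    key = key[:, :, None].expand(bsz, n_kv_heads, n_rep, kv_len, head_dim)
    value = value[:, :, None].expand(bsz, n_kv_heads, n_rep, kv_len, head_dim)
    return (
        key.reshape(bsz, n_kv_heads * n_rep, kv_len, head_dim),
        value.reshape(bsz, n_kv_heads * n_rep, kv_len, head_dim),
    )


def decode_step_attention_on_host(query, k_cache, v_cache, attn_mask, n_rep, device="cpu"):
    """Single-token (q_len == 1) attention against a CPU-resident cache.
    Only the one-token query goes to the host and only the one-token
    attention output comes back; the cache itself never moves."""
    query = query.to("cpu")
    if attn_mask is not None:
        attn_mask = attn_mask.to("cpu")
    key, value = expand_kv_for_cpu_sdpa(k_cache, v_cache, n_rep)
    # SDPA's default scaling is 1/sqrt(head_dim), assumed here to match the
    # norm_factor the model passes explicitly.
    attn_output = F.scaled_dot_product_attention(
        query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False
    )
    return attn_output.to(device, non_blocking=True)


if __name__ == "__main__":
    bsz, n_heads, n_kv_heads, head_dim, kv_len = 1, 32, 8, 128, 4096
    q = torch.randn(bsz, n_heads, 1, head_dim)                # one new token
    k_cache = torch.randn(bsz, n_kv_heads, kv_len, head_dim)  # host-resident cache
    v_cache = torch.randn(bsz, n_kv_heads, kv_len, head_dim)
    out = decode_step_attention_on_host(q, k_cache, v_cache, None, n_heads // n_kv_heads)
    print(out.shape)  # torch.Size([1, 32, 1, 128])
```

This is why only the decode step benefits from the CPU SDPA dispatch: per token, just the query and the attention output cross the host-device link, while the prefill step still runs with the fused HPU kernels on device-side inputs.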