diff --git a/examples/text-generation/README.md b/examples/text-generation/README.md index b720936ff..e593b4280 100755 --- a/examples/text-generation/README.md +++ b/examples/text-generation/README.md @@ -539,6 +539,34 @@ python ../gaudi_spawn.py --use_deepspeed --world_size 8 run_generation.py \ For more details see [documentation](https://docs.habana.ai/en/latest/PyTorch/Model_Optimization_PyTorch/Optimization_in_PyTorch_Models.html#using-fused-sdpa). +### Store KV Cache on CPU +Keeping key/value cache on CPU (host) side can decrease hpu vram in spite of it may damage generation latency. It's a practical solution in long context serving scenario with a large LLM on single card. + +You can add `--kv_cache_on_host` arg to enable it. [Pytorch SDPA operator](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) will be automatically used to generate next token for saving data transfer time. First token is not be affected. + +For exmaple: +```bash +python run_generation.py \ +--model_name_or_path 01-ai/Yi-34B-Chat \ +--use_kv_cache \ +--bf16 \ +--attn_softmax_bf16 \ +--reuse_cache \ +--do_sample \ +--dataset_name emozilla/pg19-test \ +--batch_size 1 \ +--max_input_tokens 11200 \ +--column_name "text" \ +--dataset_max_samples 1 \ +--warmup 0 \ +--n_iterations 1 \ +--max_new_tokens 5000 \ +--kv_cache_on_host +``` + +> [!NOTE] +> 1. `--kv_cache_on_host` only supports llama model for now. And it can not work with `--use_hpu_grapgs` and FP8 data type. +> 2. Try to use it when you only meet HPU workspace allocation error (`OOM`) since it will increase latency. ## Language Model Evaluation Harness diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py index 0a16543c2..713bef08a 100755 --- a/examples/text-generation/run_generation.py +++ b/examples/text-generation/run_generation.py @@ -231,6 +231,11 @@ def setup_parser(parser): action="store_true", help="Whether to reuse key/value cache for decoding. It should save memory.", ) + parser.add_argument( + "--kv_cache_on_host", + action="store_true", + help="Store key/value cache on CPU instead of HPU device (only support llama now). It should save vram on long context scenario.", + ) parser.add_argument("--verbose_workers", action="store_true", help="Enable output from non-master workers") parser.add_argument( "--simulate_dyn_prompt", @@ -328,6 +333,16 @@ def setup_parser(parser): logger.warning( "`--disk_offload` was tested only with fp8, it may not work with full precision. If error raises try to remove the --disk_offload flag." ) + if args.quant_config != "" and args.kv_cache_on_host: + logger.warning( + "`--kv_cache_on_host` is not supported with FP8 quantization. Set this flag to False." + ) + args.kv_cache_on_host = False + if args.kv_cache_on_host and args.use_hpu_graphs: + logger.warning( + "`--kv_cache_on_host` is not supported with HPU graphs. Set this flag to False." + ) + args.kv_cache_on_host = False return args diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py index 5898b2667..51f91798e 100644 --- a/examples/text-generation/utils.py +++ b/examples/text-generation/utils.py @@ -573,6 +573,7 @@ def setup_generation_config(args, model, assistant_model, tokenizer): generation_config.flash_attention_causal_mask = args.flash_attention_causal_mask generation_config.flash_attention_fast_softmax = args.flash_attention_fast_softmax generation_config.trust_remote_code = args.trust_remote_code + generation_config.kv_cache_on_host = args.kv_cache_on_host return generation_config diff --git a/optimum/habana/transformers/generation/configuration_utils.py b/optimum/habana/transformers/generation/configuration_utils.py index ce38a07ed..852b3800d 100644 --- a/optimum/habana/transformers/generation/configuration_utils.py +++ b/optimum/habana/transformers/generation/configuration_utils.py @@ -37,6 +37,8 @@ class GaudiGenerationConfig(GenerationConfig): Whether to enable causal_mask if use Habana flash attention. flash_attention_fast_softmax_mode (`bool`, *optional*): Whether to use fast softmax with reduced precision if use Habana flash attention. + kv_cache_on_host (`bool`, *optional*): + Whether to store key/value cache on host (CPU). """ def __init__(self, **kwargs): @@ -55,3 +57,4 @@ def __init__(self, **kwargs): self.flash_attention_causal_mask = kwargs.get("flash_attention_causal_mask", None) self.flash_attention_fast_softmax = kwargs.get("flash_attention_fast_softmax", None) self.use_fused_rope = kwargs.get("use_fused_rope", None) + self.kv_cache_on_host = kwargs.get("kv_cache_on_host", False) diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index d4baf44c0..ee26d54db 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -905,6 +905,10 @@ def generate( ), "please set bucket_internal along with reuse_cache and bucket_size" else: assert generation_config.bucket_size >= 0, "please set valid bucket_size to use bucket_internal" + if generation_config.kv_cache_on_host: + assert self.config.model_type in [ + "llama", + ], "kv_cache_on_host only supported by llama at the moment" if generation_config.static_shapes: # Pad inputs to have static shapes during generation, this gives better performance than dynamic shapes on HPUs @@ -1090,9 +1094,17 @@ def generate( if generation_config.use_cache and generation_config.reuse_cache: bs, _ = input_ids.shape if not is_greedy_or_beam_and_bucket: - unwrap_deepspeed_model(self).allocate_kv_cache( - bs * generation_config.num_beams, calculated_max_length, token_idx + num_virtual_tokens + if generation_config.kv_cache_on_host and self.config.model_type in ["llama"]: + print("Allocate KV Cache on CPU...") + cache_device = "cpu" + unwrap_deepspeed_model(self).allocate_kv_cache( + bs * generation_config.num_beams, calculated_max_length, token_idx + num_virtual_tokens, + device=cache_device ) + else: + unwrap_deepspeed_model(self).allocate_kv_cache( + bs * generation_config.num_beams, calculated_max_length, token_idx + num_virtual_tokens + ) if generation_config.use_cache: model_kwargs["kv_cache_len"] = calculated_max_length model_kwargs["kv_cache_pad_len"] = generation_config.max_new_tokens diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index 1abbfab12..78871174d 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -344,6 +344,41 @@ def gaudi_llama_repeat_kv( return query_states, key_states, value_states, attention_mask +def gaudi_llama_repeat_kv_cpu( + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + attention_mask: torch.Tensor, + n_rep: int, +): + """ + PyTorch SDPA CPU (flash-atten) kernel does not support GQA/MQA for now. + So, expand k and v to num_query_heads + """ + query_states = query_states.to("cpu") + key_states = key_states.to("cpu") + value_states = value_states.to("cpu") + if attention_mask is not None: + attention_mask = attention_mask.to("cpu") + + batch, num_key_value_heads, kv_len, head_dim = key_states.shape + if n_rep == 1 or num_key_value_heads == 1: + return query_states, key_states, value_states, attention_mask + + key_states = key_states[:, :, None, :, :].expand(batch, + num_key_value_heads, + n_rep, + kv_len, + head_dim) + value_states = value_states[:, :, None, :, :].expand(batch, + num_key_value_heads, + n_rep, + kv_len, + head_dim) + key_states = key_states.reshape(batch, num_key_value_heads * n_rep, kv_len, head_dim) + value_states = value_states.reshape(batch, num_key_value_heads * n_rep, kv_len, head_dim) + + return query_states, key_states, value_states, attention_mask # FusedScaledDotProductAttention class ModuleFusedSDPA(torch.nn.Module): @@ -381,6 +416,8 @@ def allocate(self, inp_seq_len, dtype, device, shape): @staticmethod def update(prev, cur, dim, idx, inp_seq_len): + cur = cur.to(prev.device) + idx = idx.to(prev.device) orig_cur = cur if prev.shape == cur.shape: prev.copy_(cur) @@ -442,9 +479,8 @@ def get_k_proj_weight_dtype(self): return self.k_proj.scales.dtype return self.k_proj.weight.dtype - def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len, device="hpu"): cache_shape = (batch_size, self.num_key_value_heads, max_seq_len, self.head_dim) - device = self.get_k_proj_weight().device dtype = self.config.torch_dtype self.k_cache.allocate(inp_seq_len, dtype, device, cache_shape) self.v_cache.allocate(inp_seq_len, dtype, device, cache_shape) @@ -617,55 +653,76 @@ def pre_attn_forward( else: past_key_value = None - if use_flash_attention and FusedSDPA is not None: - import habana_frameworks.torch.hpu as ht - - softmax_mode = "fast" if flash_attention_fast_softmax else "None" + kv_cache_on_host = (key_states.device == torch.device("cpu") and value_states.device == torch.device("cpu")) + # CPU SDPA fot next token + if kv_cache_on_host and q_len == 1 and not self.training: + query_states, key_states, value_states, attention_mask = gaudi_llama_repeat_kv_cpu( + query_states, key_states, value_states, attention_mask, self.num_key_value_groups + ) + # pytorch https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html + # dispatch to flash attention implementation + attn_output = F.scaled_dot_product_attention(query_states, + key_states, + value_states, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + scale=self.norm_factor) + attn_output = attn_output.to("hpu", non_blocking=True) - if q_len == 1: - # next token - use_recompute = True if os.getenv("QUANT_CONFIG", "") else False - with ht.sdp_kernel(enable_recompute=use_recompute): - attn_output = self.fused_scaled_dot_product_attention( - query_states, key_states, value_states, attention_mask, 0.0, False, None, "None" - ) - else: - # first token - if flash_attention_causal_mask: - # causal masking on first token requires inputs to be of the same length - with ht.sdp_kernel(enable_recompute=flash_attention_recompute): - attn_output = self.fused_scaled_dot_product_attention( - query_states, key_states, value_states, None, 0.0, True, None, softmax_mode - ) - else: - with ht.sdp_kernel(enable_recompute=flash_attention_recompute): + else: + if kv_cache_on_host: + key_states = key_states.to("hpu", non_blocking=True) + value_states = value_states.to("hpu", non_blocking=True) + if use_flash_attention and FusedSDPA is not None: + import habana_frameworks.torch.hpu as ht + + softmax_mode = "fast" if flash_attention_fast_softmax else "None" + + if q_len == 1: + # next token + use_recompute = True if os.getenv("QUANT_CONFIG", "") else False + with ht.sdp_kernel(enable_recompute=use_recompute): attn_output = self.fused_scaled_dot_product_attention( query_states, key_states, value_states, attention_mask, 0.0, False, None, softmax_mode ) + else: + # first token + if flash_attention_causal_mask: + # causal masking on first token requires inputs to be of the same length + with ht.sdp_kernel(enable_recompute=flash_attention_recompute): + attn_output = self.fused_scaled_dot_product_attention( + query_states, key_states, value_states, None, 0.0, True, None, softmax_mode + ) + else: + with ht.sdp_kernel(enable_recompute=flash_attention_recompute): + attn_output = self.fused_scaled_dot_product_attention( + query_states, key_states, value_states, attention_mask, 0.0, False, None, softmax_mode + ) - else: - query_states, key_states, value_states, attention_mask = gaudi_llama_repeat_kv( - query_states, key_states, value_states, attention_mask, self.num_key_value_groups - ) + else: + query_states, key_states, value_states, attention_mask = gaudi_llama_repeat_kv( + query_states, key_states, value_states, attention_mask, self.num_key_value_groups + ) - attn_weights = self.matmul_qk(query_states, key_states.transpose(-2, -1)) * self.norm_factor + attn_weights = self.matmul_qk(query_states, key_states.transpose(-2, -1)) * self.norm_factor - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask - if cache_position is not None: - causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask + if cache_position is not None: + causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask - if attn_softmax_bf16: - attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=query_states.dtype) - else: - # upcast attention to fp32 - attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( - query_states.dtype - ) - attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = self.matmul_av(attn_weights, value_states) - attn_output = attn_output.reshape(bsz, -1, q_len, self.head_dim) + if attn_softmax_bf16: + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=query_states.dtype) + else: + # upcast attention to fp32 + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( + query_states.dtype + ) + attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = self.matmul_av(attn_weights, value_states) + attn_output = attn_output.reshape(bsz, -1, q_len, self.head_dim) if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): raise ValueError( @@ -811,8 +868,8 @@ def __init__(self, config: LlamaConfig, layer_idx: int): self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): - self.self_attn.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len, device="hpu"): + self.self_attn.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len, device=device) def reorder_kv_cache(self, beam_idx: torch.LongTensor): return self.self_attn.reorder_kv_cache(beam_idx) @@ -988,9 +1045,9 @@ def __init__(self, config: LlamaConfig): # Initialize weights and apply final processing self.post_init() - def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len, device="hpu"): for layer in self.layers: - layer.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) + layer.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len, device=device) def reorder_kv_cache(self, beam_idx: torch.LongTensor): return tuple(layer.reorder_kv_cache(beam_idx) for layer in self.layers) @@ -1223,8 +1280,8 @@ def __init__(self, config, parallel_strategy: DistributedStrategy = NoOpStrategy config.parallel_strategy = parallel_strategy super().__init__(config) - def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): - self.model.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len, device="hpu"): + self.model.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len, device=device) def reorder_kv_cache(self, beam_idx: torch.LongTensor): return self.model.reorder_kv_cache(beam_idx)