[llama] Store KV Cache on CPU and Use PyTorch SDPA for Next Token Generation #1182

Open. Wants to merge 4 commits into `main`.
28 changes: 28 additions & 0 deletions examples/text-generation/README.md
@@ -539,6 +539,34 @@ python ../gaudi_spawn.py --use_deepspeed --world_size 8 run_generation.py \

For more details see [documentation](https://docs.habana.ai/en/latest/PyTorch/Model_Optimization_PyTorch/Optimization_in_PyTorch_Models.html#using-fused-sdpa).

### Store KV Cache on CPU
Keeping the key/value cache on the CPU (host) side can reduce HPU VRAM usage, although it may hurt generation latency. It is a practical option for serving long contexts with a large LLM on a single card.

Add the `--kv_cache_on_host` flag to enable it. The [PyTorch SDPA operator](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) is then used automatically for next-token generation to save data-transfer time. First-token generation is not affected.

For example:
```bash
python run_generation.py \
--model_name_or_path 01-ai/Yi-34B-Chat \
--use_kv_cache \
--bf16 \
--attn_softmax_bf16 \
--reuse_cache \
--do_sample \
--dataset_name emozilla/pg19-test \
--batch_size 1 \
--max_input_tokens 11200 \
--column_name "text" \
--dataset_max_samples 1 \
--warmup 0 \
--n_iterations 1 \
--max_new_tokens 5000 \
--kv_cache_on_host
```

> [!NOTE]
> 1. `--kv_cache_on_host` currently supports only the Llama model, and it cannot be combined with `--use_hpu_graphs` or the FP8 data type.
> 2. Use it only when you hit an HPU workspace allocation error (`OOM`), since it increases latency.

## Language Model Evaluation Harness

15 changes: 15 additions & 0 deletions examples/text-generation/run_generation.py
@@ -231,6 +231,11 @@ def setup_parser(parser):
action="store_true",
help="Whether to reuse key/value cache for decoding. It should save memory.",
)
parser.add_argument(
"--kv_cache_on_host",
action="store_true",
help="Store key/value cache on CPU instead of HPU device (only support llama now). It should save vram on long context scenario.",
)
parser.add_argument("--verbose_workers", action="store_true", help="Enable output from non-master workers")
parser.add_argument(
"--simulate_dyn_prompt",
@@ -328,6 +333,16 @@ def setup_parser(parser):
logger.warning(
"`--disk_offload` was tested only with fp8, it may not work with full precision. If error raises try to remove the --disk_offload flag."
)
if args.quant_config != "" and args.kv_cache_on_host:
logger.warning(
"`--kv_cache_on_host` is not supported with FP8 quantization. Set this flag to False."
)
args.kv_cache_on_host = False
if args.kv_cache_on_host and args.use_hpu_graphs:
logger.warning(
"`--kv_cache_on_host` is not supported with HPU graphs. Set this flag to False."
)
args.kv_cache_on_host = False
return args


1 change: 1 addition & 0 deletions examples/text-generation/utils.py
@@ -573,6 +573,7 @@ def setup_generation_config(args, model, assistant_model, tokenizer):
generation_config.flash_attention_causal_mask = args.flash_attention_causal_mask
generation_config.flash_attention_fast_softmax = args.flash_attention_fast_softmax
generation_config.trust_remote_code = args.trust_remote_code
generation_config.kv_cache_on_host = args.kv_cache_on_host

return generation_config

@@ -37,6 +37,8 @@ class GaudiGenerationConfig(GenerationConfig):
Whether to enable causal_mask if use Habana flash attention.
flash_attention_fast_softmax_mode (`bool`, *optional*):
Whether to use fast softmax with reduced precision if use Habana flash attention.
kv_cache_on_host (`bool`, *optional*):
Whether to store key/value cache on host (CPU).
"""

def __init__(self, **kwargs):
@@ -55,3 +57,4 @@ def __init__(self, **kwargs):
self.flash_attention_causal_mask = kwargs.get("flash_attention_causal_mask", None)
self.flash_attention_fast_softmax = kwargs.get("flash_attention_fast_softmax", None)
self.use_fused_rope = kwargs.get("use_fused_rope", None)
self.kv_cache_on_host = kwargs.get("kv_cache_on_host", False)
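
For context, here is a minimal sketch of how the new option could be set programmatically rather than through `run_generation.py`. It is illustrative only: it assumes `GaudiGenerationConfig` is importable from `optimum.habana.transformers.generation`, and that `model` and `inputs` are a Llama checkpoint and tokenized prompt already prepared for HPU inference.

```python
from optimum.habana.transformers.generation import GaudiGenerationConfig

# Sketch: enable the host-side KV cache together with reuse_cache,
# mirroring what run_generation.py does when --kv_cache_on_host is passed.
generation_config = GaudiGenerationConfig(
    use_cache=True,
    reuse_cache=True,
    kv_cache_on_host=True,  # new flag added in this PR
    max_new_tokens=5000,
)
# outputs = model.generate(**inputs, generation_config=generation_config, lazy_mode=True)
```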
16 changes: 14 additions & 2 deletions optimum/habana/transformers/generation/utils.py
@@ -905,6 +905,10 @@ def generate(
), "please set bucket_internal along with reuse_cache and bucket_size"
else:
assert generation_config.bucket_size >= 0, "please set valid bucket_size to use bucket_internal"
if generation_config.kv_cache_on_host:
assert self.config.model_type in [
"llama",
], "kv_cache_on_host only supported by llama at the moment"

if generation_config.static_shapes:
# Pad inputs to have static shapes during generation, this gives better performance than dynamic shapes on HPUs
@@ -1090,9 +1094,17 @@ def generate(
if generation_config.use_cache and generation_config.reuse_cache:
bs, _ = input_ids.shape
if not is_greedy_or_beam_and_bucket:
unwrap_deepspeed_model(self).allocate_kv_cache(
bs * generation_config.num_beams, calculated_max_length, token_idx + num_virtual_tokens
if generation_config.kv_cache_on_host and self.config.model_type in ["llama"]:
print("Allocate KV Cache on CPU...")
cache_device = "cpu"
unwrap_deepspeed_model(self).allocate_kv_cache(
bs * generation_config.num_beams, calculated_max_length, token_idx + num_virtual_tokens,
device=cache_device
)
else:
unwrap_deepspeed_model(self).allocate_kv_cache(
bs * generation_config.num_beams, calculated_max_length, token_idx + num_virtual_tokens
)
Collaborator:

For lines 1096 to 1107, I would suggest changing the code like this:

if not is_greedy_or_beam_and_bucket:
    cache_device = "hpu"
    if generation_config.kv_cache_on_host and self.config.model_type in ["llama"]:
        print("Allocate KV Cache on CPU...")
        cache_device = "cpu"
    unwrap_deepspeed_model(self).allocate_kv_cache(
        bs * generation_config.num_beams, calculated_max_length, token_idx + num_virtual_tokens,
        device=cache_device
    )

Author:

Thanks, I have updated it in 74e94ff. However, I cannot remove the `else` branch, because I only modified modeling_llama.py for this experimental feature.

if generation_config.use_cache:
model_kwargs["kv_cache_len"] = calculated_max_length
model_kwargs["kv_cache_pad_len"] = generation_config.max_new_tokens
Expand Down
155 changes: 106 additions & 49 deletions optimum/habana/transformers/models/llama/modeling_llama.py
@@ -344,6 +344,41 @@ def gaudi_llama_repeat_kv(

return query_states, key_states, value_states, attention_mask

def gaudi_llama_repeat_kv_cpu(
query_states: torch.Tensor,
key_states: torch.Tensor,
value_states: torch.Tensor,
attention_mask: torch.Tensor,
n_rep: int,
):
"""
The PyTorch SDPA CPU (flash-attention) kernel does not support GQA/MQA for now,
so expand k and v to the number of query heads.
"""
query_states = query_states.to("cpu")
key_states = key_states.to("cpu")
value_states = value_states.to("cpu")
if attention_mask is not None:
attention_mask = attention_mask.to("cpu")

batch, num_key_value_heads, kv_len, head_dim = key_states.shape
if n_rep == 1 or num_key_value_heads == 1:
return query_states, key_states, value_states, attention_mask

key_states = key_states[:, :, None, :, :].expand(batch,
num_key_value_heads,
n_rep,
kv_len,
head_dim)
value_states = value_states[:, :, None, :, :].expand(batch,
num_key_value_heads,
n_rep,
kv_len,
head_dim)
key_states = key_states.reshape(batch, num_key_value_heads * n_rep, kv_len, head_dim)
value_states = value_states.reshape(batch, num_key_value_heads * n_rep, kv_len, head_dim)

return query_states, key_states, value_states, attention_mask
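
To make the shape handling above concrete, here is a small standalone sketch of the same GQA expansion followed by a CPU SDPA call. It is illustrative only and uses dummy tensors with toy dimensions instead of the module's real query/key/value states.

```python
import torch
import torch.nn.functional as F

# Toy dimensions: 8 query heads sharing 2 KV heads (n_rep = 4), decoding one token.
batch, num_q_heads, num_kv_heads, kv_len, head_dim = 1, 8, 2, 16, 64
n_rep = num_q_heads // num_kv_heads

query = torch.randn(batch, num_q_heads, 1, head_dim)      # single next-token query
key = torch.randn(batch, num_kv_heads, kv_len, head_dim)  # KV cache kept on the CPU
value = torch.randn(batch, num_kv_heads, kv_len, head_dim)

# Expand K/V so that every query head has a matching KV head, since the
# CPU SDPA kernel does not handle GQA/MQA directly.
key = key[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, kv_len, head_dim)
key = key.reshape(batch, num_kv_heads * n_rep, kv_len, head_dim)
value = value[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, kv_len, head_dim)
value = value.reshape(batch, num_kv_heads * n_rep, kv_len, head_dim)

attn_output = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
print(attn_output.shape)  # torch.Size([1, 8, 1, 64])
```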

# FusedScaledDotProductAttention
class ModuleFusedSDPA(torch.nn.Module):
@@ -381,6 +416,8 @@ def allocate(self, inp_seq_len, dtype, device, shape):

@staticmethod
def update(prev, cur, dim, idx, inp_seq_len):
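# Move the incoming states and index to the cache's device (e.g. the CPU when the KV cache is kept on the host).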
cur = cur.to(prev.device)
idx = idx.to(prev.device)
orig_cur = cur
if prev.shape == cur.shape:
prev.copy_(cur)
@@ -442,9 +479,8 @@ def get_k_proj_weight_dtype(self):
return self.k_proj.scales.dtype
return self.k_proj.weight.dtype

def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len):
def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len, device="hpu"):
cache_shape = (batch_size, self.num_key_value_heads, max_seq_len, self.head_dim)
device = self.get_k_proj_weight().device
dtype = self.config.torch_dtype
self.k_cache.allocate(inp_seq_len, dtype, device, cache_shape)
self.v_cache.allocate(inp_seq_len, dtype, device, cache_shape)
@@ -617,55 +653,76 @@ def pre_attn_forward(
else:
past_key_value = None

if use_flash_attention and FusedSDPA is not None:
import habana_frameworks.torch.hpu as ht

softmax_mode = "fast" if flash_attention_fast_softmax else "None"
kv_cache_on_host = (key_states.device == torch.device("cpu") and value_states.device == torch.device("cpu"))
# CPU SDPA for next-token generation
if kv_cache_on_host and q_len == 1 and not self.training:
query_states, key_states, value_states, attention_mask = gaudi_llama_repeat_kv_cpu(
query_states, key_states, value_states, attention_mask, self.num_key_value_groups
)
# PyTorch SDPA (https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
# dispatches to the flash-attention CPU implementation
attn_output = F.scaled_dot_product_attention(query_states,
key_states,
value_states,
attn_mask=attention_mask,
dropout_p=0.0,
is_causal=False,
scale=self.norm_factor)
attn_output = attn_output.to("hpu", non_blocking=True)

if q_len == 1:
# next token
use_recompute = True if os.getenv("QUANT_CONFIG", "") else False
with ht.sdp_kernel(enable_recompute=use_recompute):
attn_output = self.fused_scaled_dot_product_attention(
query_states, key_states, value_states, attention_mask, 0.0, False, None, "None"
)
else:
# first token
if flash_attention_causal_mask:
# causal masking on first token requires inputs to be of the same length
with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
attn_output = self.fused_scaled_dot_product_attention(
query_states, key_states, value_states, None, 0.0, True, None, softmax_mode
)
else:
with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
else:
if kv_cache_on_host:
Collaborator:

Can you please explain in which case the kv_cache device is switched? I thought line 656 is only the case when line 658 applies.

Author:

In this PR, the KV cache is stored on the CPU and CPU SDPA is used only when generating the next token. The first token (the prefill stage) is still computed on the HPU because of its compute power in long-context scenarios (usually a long prompt). The full pipeline diagram is in the PR description.
So line 658 means PyTorch CPU SDPA (flash-attention) is used only when kv_cache_on_host is set, we are generating the next token, and we are at inference time. Otherwise, the KV cache is transferred back to the HPU device when needed for the original operations.
Please let me know if you need more explanation or have suggestions. Thanks.
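
To summarize that dispatch in code form, here is a hypothetical helper (not part of the PR) that mirrors the branch described above:

```python
def choose_attention_path(kv_cache_on_host: bool, q_len: int, training: bool) -> str:
    """Sketch of the dispatch in pre_attn_forward; illustration only."""
    if kv_cache_on_host and q_len == 1 and not training:
        # Decode step: K/V already live on the CPU, so run PyTorch SDPA there
        # and copy only the small attention output back to the HPU.
        return "cpu_sdpa"
    # Prefill (q_len > 1) or training: compute on the HPU; if the cache is on
    # the CPU, the K/V tensors are first moved back to the HPU.
    return "hpu_attention"
```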

key_states = key_states.to("hpu", non_blocking=True)
value_states = value_states.to("hpu", non_blocking=True)
if use_flash_attention and FusedSDPA is not None:
import habana_frameworks.torch.hpu as ht

softmax_mode = "fast" if flash_attention_fast_softmax else "None"

if q_len == 1:
# next token
use_recompute = True if os.getenv("QUANT_CONFIG", "") else False
with ht.sdp_kernel(enable_recompute=use_recompute):
attn_output = self.fused_scaled_dot_product_attention(
query_states, key_states, value_states, attention_mask, 0.0, False, None, softmax_mode
)
else:
# first token
if flash_attention_causal_mask:
# causal masking on first token requires inputs to be of the same length
with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
attn_output = self.fused_scaled_dot_product_attention(
query_states, key_states, value_states, None, 0.0, True, None, softmax_mode
)
else:
with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
attn_output = self.fused_scaled_dot_product_attention(
query_states, key_states, value_states, attention_mask, 0.0, False, None, softmax_mode
)

else:
query_states, key_states, value_states, attention_mask = gaudi_llama_repeat_kv(
query_states, key_states, value_states, attention_mask, self.num_key_value_groups
)

attn_weights = self.matmul_qk(query_states, key_states.transpose(-2, -1)) * self.norm_factor

if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask
if cache_position is not None:
causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask

if attn_softmax_bf16:
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=query_states.dtype)
else:
# upcast attention to fp32
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
query_states.dtype
)
attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
attn_output = self.matmul_av(attn_weights, value_states)
attn_output = attn_output.reshape(bsz, -1, q_len, self.head_dim)

if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
@@ -811,8 +868,8 @@ def __init__(self, config: LlamaConfig, layer_idx: int):
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len):
self.self_attn.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len)
def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len, device="hpu"):
self.self_attn.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len, device=device)

def reorder_kv_cache(self, beam_idx: torch.LongTensor):
return self.self_attn.reorder_kv_cache(beam_idx)
@@ -988,9 +1045,9 @@ def __init__(self, config: LlamaConfig):
# Initialize weights and apply final processing
self.post_init()

def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len):
def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len, device="hpu"):
for layer in self.layers:
layer.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len)
layer.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len, device=device)

def reorder_kv_cache(self, beam_idx: torch.LongTensor):
return tuple(layer.reorder_kv_cache(beam_idx) for layer in self.layers)
@@ -1223,8 +1280,8 @@ def __init__(self, config, parallel_strategy: DistributedStrategy = NoOpStrategy
config.parallel_strategy = parallel_strategy
super().__init__(config)

def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len):
self.model.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len)
def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len, device="hpu"):
self.model.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len, device=device)

def reorder_kv_cache(self, beam_idx: torch.LongTensor):
return self.model.reorder_kv_cache(beam_idx)