fix kv_cache_on_host if statement and add non_blocking copy
Signed-off-by: Yu Zhentao <[email protected]>
zhentaoyu committed Sep 12, 2024
1 parent cd58c34 commit 4b0fa1a
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions optimum/habana/transformers/models/llama/modeling_llama.py
@@ -653,7 +653,7 @@ def pre_attn_forward(
         else:
             past_key_value = None
 
-        kv_cache_on_host = (key_states.device == "cpu" and value_states.device == "cpu")
+        kv_cache_on_host = (key_states.device == torch.device("cpu") and value_states.device == torch.device("cpu"))
         # CPU SDPA for next token
         if kv_cache_on_host and q_len == 1 and not self.training:
             query_states, key_states, value_states, attention_mask = gaudi_llama_repeat_kv_cpu(
@@ -668,12 +668,12 @@ def pre_attn_forward(
                 dropout_p=0.0,
                 is_causal=False,
                 scale=self.norm_factor)
-            attn_output = attn_output.to("hpu")
+            attn_output = attn_output.to("hpu", non_blocking=True)
 
         else:
             if kv_cache_on_host:
-                key_states = key_states.to("hpu")
-                value_states = value_states.to("hpu")
+                key_states = key_states.to("hpu", non_blocking=True)
+                value_states = value_states.to("hpu", non_blocking=True)
             if use_flash_attention and FusedSDPA is not None:
                 import habana_frameworks.torch.hpu as ht

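The second hunk adds non_blocking=True to every host-to-HPU copy. A non-blocking Tensor.to() queues the transfer asynchronously instead of stalling the host thread, which matters when the copied KV cache is large; on CUDA-like backends the copy only truly overlaps with other work when the source tensor sits in pinned host memory. The snippet below is a hedged, generic sketch of the same pattern: it substitutes CUDA (or plain CPU) for "hpu" so it runs without habana_frameworks, and the shapes are illustrative only.

import torch

# Stand-in device; the real code targets "hpu" (Habana Gaudi).
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
pin = torch.cuda.is_available()  # pinned host memory lets non_blocking copies overlap with compute

# Illustrative KV-cache-shaped tensors allocated on the host.
key_states = torch.randn(1, 8, 1024, 64, pin_memory=pin)
value_states = torch.randn(1, 8, 1024, 64, pin_memory=pin)

# Same pattern as the diff: queue the host-to-device copies without blocking the host thread.
key_states = key_states.to(device, non_blocking=True)
value_states = value_states.to(device, non_blocking=True)

print(key_states.device, value_states.device)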
0 comments on commit 4b0fa1a
