add long-context example in README
Signed-off-by: Yu Zhentao <[email protected]>
zhentaoyu committed Sep 18, 2024
1 parent fd29d4e commit 74e94ff
Showing 2 changed files with 12 additions and 5 deletions.
15 changes: 11 additions & 4 deletions examples/text-generation/README.md
@@ -547,19 +547,26 @@ You can add `--kv_cache_on_host` arg to enable it. [Pytorch SDPA operator](https
For example:
```bash
python run_generation.py \
---model_name_or_path meta-llama/Llama-2-7b-hf \
+--model_name_or_path 01-ai/Yi-34B-Chat \
--use_kv_cache \
--bf16 \
--attn_softmax_bf16 \
---max_new_tokens 128 \
--reuse_cache \
--do_sample \
--prompt "Here is my prompt"
--dataset_name emozilla/pg19-test \
--batch_size 1 \
--max_input_tokens 11200 \
--column_name "text" \
--dataset_max_samples 1 \
--warmup 0 \
--n_iterations 1 \
--max_new_tokens 5000 \
--kv_cache_on_host
```

> [!NOTE]
-> `--kv_cache_on_host` only supports llama models for now, and it cannot work with `--use_hpu_graphs` or the FP8 data type.
+> 1. `--kv_cache_on_host` only supports llama models for now, and it cannot work with `--use_hpu_graphs` or the FP8 data type.
+> 2. Use it only when you hit an HPU workspace allocation error (`OOM`), since it increases latency.
## Language Model Evaluation Harness

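As a rough illustration of what a host-resident KV cache buys (not code from this repository; the shapes and helper below are hypothetical), one way this can work, in line with the README's pointer to the PyTorch SDPA operator, is to preallocate the key/value tensors in CPU memory and compute attention there, so HPU memory is not tied up by a cache sized for very long contexts:

```python
# Hypothetical sketch, not the optimum-habana implementation: the KV cache is
# preallocated in host memory and attention runs on the CPU via PyTorch SDPA.
import torch
import torch.nn.functional as F

bs, heads, max_len, head_dim = 1, 8, 16200, 128  # illustrative sizes

k_cache = torch.zeros(bs, heads, max_len, head_dim, dtype=torch.bfloat16, device="cpu")
v_cache = torch.zeros_like(k_cache)

def attend_with_host_cache(query: torch.Tensor, cur_len: int) -> torch.Tensor:
    """Attend over the first `cur_len` cached positions on the CPU."""
    q_cpu = query.to("cpu")  # in this sketch only the small query tensor crosses to the host
    out = F.scaled_dot_product_attention(
        q_cpu, k_cache[:, :, :cur_len], v_cache[:, :, :cur_len]
    )
    return out.to(query.device)  # and only the attention output goes back to the device
```

This also matches the note above: keeping the cache on the host and computing attention on the CPU trades HPU memory for extra latency.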
2 changes: 1 addition & 1 deletion optimum/habana/transformers/generation/utils.py
@@ -1093,10 +1093,10 @@ def generate(
calculated_max_length = input_ids.shape[-1] + generation_config.max_new_tokens + num_virtual_tokens
if generation_config.use_cache and generation_config.reuse_cache:
    bs, _ = input_ids.shape
+    cache_device = "cpu" if generation_config.kv_cache_on_host else "hpu"
    if not is_greedy_or_beam_and_bucket:
        if generation_config.kv_cache_on_host and self.config.model_type in ["llama"]:
            print("Allocate KV Cache on CPU...")
-            cache_device = "cpu"
        unwrap_deepspeed_model(self).allocate_kv_cache(
            bs * generation_config.num_beams, calculated_max_length, token_idx + num_virtual_tokens,
            device=cache_device
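In essence, the change above sets `cache_device` up front from `generation_config.kv_cache_on_host` instead of switching it to `"cpu"` inside the llama-only branch. A simplified, standalone sketch of the selection logic (the function form is mine; the real code sits inside the `generate` method shown in the hunk):

```python
# Simplified sketch of the cache-device selection after this commit; in the
# repository the values come from generation_config and self.config.
def select_cache_device(kv_cache_on_host: bool, model_type: str) -> str:
    # Chosen up front: host memory when requested, otherwise the HPU.
    cache_device = "cpu" if kv_cache_on_host else "hpu"
    if kv_cache_on_host and model_type in ["llama"]:
        print("Allocate KV Cache on CPU...")
    return cache_device

assert select_cache_device(False, "llama") == "hpu"
assert select_cache_device(True, "llama") == "cpu"
```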
