Add --use_kv_cache to image-to-text pipeline (#1292)
KimBioInfoStudio authored Sep 14, 2024
1 parent f87f0fb commit 520c875
Showing 2 changed files with 44 additions and 8 deletions.
20 changes: 20 additions & 0 deletions examples/image-to-text/README.md
@@ -28,6 +28,8 @@ Models that have been validated:
 - [llava-hf/llava-v1.6-mistral-7b-hf](https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf)
 - [llava-hf/llava-v1.6-vicuna-7b-hf](https://huggingface.co/llava-hf/llava-v1.6-vicuna-7b-hf)
 - [llava-hf/llava-v1.6-vicuna-13b-hf](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf)
+- [llava-hf/llava-v1.6-34b-hf](https://huggingface.co/llava-hf/llava-v1.6-34b-hf)
+- [llava-hf/llama3-llava-next-8b-hf](https://huggingface.co/llava-hf/llama3-llava-next-8b-hf)

 ### Inference with BF16

@@ -72,6 +74,24 @@ python3 run_pipeline.py \
     --bf16
 ```

+To run Llava-hf/llava-v1.6-34b-hf inference, use the following command:
+
+```bash
+python3 run_pipeline.py \
+    --model_name_or_path llava-hf/llava-v1.6-34b-hf \
+    --use_hpu_graphs \
+    --bf16
+```
+
+To run Llava-hf/llama3-llava-next-8b-hf inference, use the following command:
+
+```bash
+python3 run_pipeline.py \
+    --model_name_or_path llava-hf/llama3-llava-next-8b-hf \
+    --use_hpu_graphs \
+    --bf16
+```
+
 ### Inference with FP8

 Inference for Llava-1.5-7b, Llava-1.5-13b, Llava-v1.6-mistral-7b and Llava-v1.6-vicuna-13b in FP8 precision are enabled using [Intel Neural Compressor (INC)](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_FP8.html), which provides model measurement and quantization capabilities in PyTorch. INC is used by default for measuring and quantization. Habana Quantization Toolkit (HQT), which was used earlier, will be removed in future releases. To use HQT, disable INC by setting the following environment variable: `USE_INC=0`.

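The FP8 paragraph above says INC is the default backend and that the older HQT path can still be selected by exporting `USE_INC=0`. As a minimal sketch of what that looks like on the command line (the model id and remaining flags are borrowed from the BF16 examples and are illustrative; the full FP8 flow also requires the measurement and quantization steps described elsewhere in the README):

```bash
# Sketch: force the deprecated HQT path by disabling INC.
# Model id and flags are illustrative; see the README's FP8 section for the
# complete measurement + quantization procedure.
USE_INC=0 python3 run_pipeline.py \
    --model_name_or_path llava-hf/llava-1.5-7b-hf \
    --use_hpu_graphs \
    --bf16
```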
32 changes: 24 additions & 8 deletions examples/image-to-text/run_pipeline.py
@@ -23,7 +23,7 @@
 import PIL.Image
 import requests
 import torch
-from transformers import AutoConfig, pipeline
+from transformers import AutoConfig, LlavaNextProcessor, LlavaProcessor, pipeline

 from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

@@ -141,6 +141,11 @@ def main():
         action="store_true",
         help="Whether to enable Habana Flash Attention in recompute mode on first token generation. This gives an opportunity of splitting graph internally which helps reduce memory consumption.",
     )
+    parser.add_argument(
+        "--use_kv_cache",
+        action="store_true",
+        help="Whether to use the key/value cache for decoding. It should speed up generation.",
+    )

     args = parser.parse_args()

@@ -156,12 +161,21 @@ def main():
         args.image_path = [
             "https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true"
         ]
-    if args.prompt is None and model_type == "llava":
-        args.prompt = "<image>\nUSER: What's the content of the image?\nASSISTANT:"
-    elif args.prompt is None and model_type == "llava_next":
-        args.prompt = "[INST] <image>\nWhat is shown in this image? [/INST]"
-        if args.model_name_or_path in ["llava-hf/llava-v1.6-vicuna-13b-hf", "llava-hf/llava-v1.6-vicuna-7b-hf"]:
-            args.prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>\nWhat is shown in this image? ASSISTANT:"
+    if args.prompt is None:
+        if model_type == "llava":
+            processor = LlavaProcessor.from_pretrained(args.model_name_or_path)
+        elif model_type == "llava_next":
+            processor = LlavaNextProcessor.from_pretrained(args.model_name_or_path)
+        conversation = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": "What is shown in this image?"},
+                    {"type": "image"},
+                ],
+            }
+        ]
+        args.prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)

     image_paths = args.image_path
     image_paths_len = len(image_paths)
@@ -197,6 +211,7 @@ def main():
     )
     generate_kwargs = {
         "lazy_mode": True,
+        "use_cache": args.use_kv_cache,
         "hpu_graphs": args.use_hpu_graphs,
         "max_new_tokens": args.max_new_tokens,
         "ignore_eos": args.ignore_eos,
@@ -233,8 +248,9 @@ def main():

     total_new_tokens_generated = args.n_iterations * n_output_tokens
     throughput = total_new_tokens_generated / duration
+    logger.info(f"result = {result}")
     logger.info(
-        f"result = {result}, time = {(end-start) * 1000 / args.n_iterations }ms, Throughput (including tokenization) = {throughput} tokens/second"
+        f"time = {(end-start) * 1000 / args.n_iterations }ms, Throughput (including tokenization) = {throughput} tokens/second"
     )

     # Store results if necessary
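With this change, the key/value cache is opt-in from the command line: `--use_kv_cache` is forwarded as `generate_kwargs["use_cache"]`, which the pipeline passes through to generation. A minimal sketch of combining the new flag with the README's BF16 example (the model id and other flags mirror the examples above and are illustrative):

```bash
# Sketch: enable the KV cache on top of a BF16 run; omitting --use_kv_cache
# leaves use_cache=False, which is typically slower for long generations.
python3 run_pipeline.py \
    --model_name_or_path llava-hf/llava-1.5-7b-hf \
    --use_hpu_graphs \
    --use_kv_cache \
    --bf16
```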
