Support for cohere command-r and chat models #1031

Open · wants to merge 16 commits into main
26 changes: 26 additions & 0 deletions examples/text-generation/README.md
@@ -670,3 +670,29 @@ deepspeed --num_gpus 8 run_lm_eval.py \
## Text-Generation Pipeline

A Transformers-like pipeline is defined and provided [here](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation/text-generation-pipeline). It is optimized for Gaudi and can be called to generate text in your scripts.

## Conversation generation

For models with chat support, such as `CohereForAI/c4ai-command-r-v01`, you can pass `--conversation_input <JSON FILE>` containing a conversation that is formatted with the tokenizer's chat template.

### Examples

A sample conversation file, `sample_command_r_conversation.json`, for [CohereForAI/c4ai-command-r-v01](https://huggingface.co/CohereForAI/c4ai-command-r-v01) is shown below:

```json
[{"role": "user", "content": "Hello, how are you?"}]
```
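
Under the hood, the flag is handled by formatting the JSON messages with the tokenizer's chat template. Below is a minimal standalone sketch of that step, mirroring the logic added to `run_generation.py`:

```python
import json

from transformers import AutoTokenizer

# Load the tokenizer for the chat model; its chat template turns a list of
# {"role": ..., "content": ...} messages into a single prompt string.
tokenizer = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01")

with open("sample_command_r_conversation.json", "r") as fh:
    messages = json.load(fh)

# tokenize=False returns the formatted prompt text rather than token ids
prompt = tokenizer.apply_chat_template(conversation=messages, tokenize=False)
print(prompt)
```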

Command to run chat generation:

```bash
python run_generation.py \
--model_name_or_path CohereForAI/c4ai-command-r-v01 \
--use_hpu_graphs \
--use_kv_cache \
--max_new_tokens 100 \
--do_sample \
--conversation_input sample_command_r_conversation.json \
--bf16 \
--batch_size 2
```
18 changes: 16 additions & 2 deletions examples/text-generation/run_generation.py
@@ -23,6 +23,7 @@
import logging
import math
import os
import sys
import time
from itertools import cycle
from pathlib import Path
@@ -339,6 +340,12 @@ def setup_parser(parser):
help="Path to neural-compressor quantized model, if set, the checkpoint will be loaded.",
)

parser.add_argument(
"--conversation_input",
default=None,
type=str,
help="Optional JSON input file containing conversation input.",
)
args = parser.parse_args()

if args.torch_compile:
@@ -406,8 +413,6 @@ def download_book(book_id):
return save_path
else:
print("Failed to download book! Exiting...")
import sys

sys.exit(1)

def assemble_prompt(prompt_size, book_path):
Expand All @@ -427,6 +432,14 @@ def assemble_prompt(prompt_size, book_path):
1342, # Pride and Prejudice
]
input_sentences = assemble_prompt(prompt_size=args.max_input_tokens, book_path=download_book(book_ids[0]))
elif args.conversation_input and hasattr(tokenizer, "chat_template"):
# Render the JSON conversation into a prompt string using the model's chat template
with open(args.conversation_input, "r") as fh:
messages = json.load(fh)
try:
input_sentences = [tokenizer.apply_chat_template(conversation=messages, tokenize=False)]
except Exception as e:
logger.error(f"Error applying chat template: {e}")
sys.exit(1)
else:
input_sentences = [
"DeepSpeed is a machine learning framework",
Expand All @@ -439,6 +452,7 @@ def assemble_prompt(prompt_size, book_path):
"Peace is the only way",
]


if args.batch_size > len(input_sentences):
# Dynamically extends to support larger batch sizes
num_sentences_to_add = args.batch_size - len(input_sentences)
1 change: 1 addition & 0 deletions examples/text-generation/sample_command_r_conversation.json
@@ -0,0 +1 @@
[{"role": "user", "content": "Hello, how are you?"}]
1 change: 1 addition & 0 deletions examples/text-generation/utils.py
@@ -596,6 +596,7 @@ def setup_generation_config(args, model, assistant_model, tokenizer):
generation_config.flash_attention_causal_mask = args.flash_attention_causal_mask
generation_config.flash_attention_fast_softmax = args.flash_attention_fast_softmax
generation_config.trust_remote_code = args.trust_remote_code
generation_config.temperature = args.temperature
generation_config.valid_sequence_lengths = None

return generation_config
7 changes: 6 additions & 1 deletion tests/test_text_generation_example.py
@@ -40,7 +40,8 @@
("adept/persimmon-8b-base", 4, False, 366.73968820698406),
("Qwen/Qwen1.5-7B", 4, False, 490.8621617893209),
("google/gemma-7b", 1, False, 109.70751574382221),
("state-spaces/mamba-130m-hf", 1536, False, 5385.511100161605),
("CohereForAI/c4ai-command-r-v01", 1, False, 30.472430202916325),
("state-spaces/mamba-130m-hf", 1536, False, 8600),
("Deci/DeciLM-7B", 1, False, 120),
("Qwen/Qwen2-7B", 512, False, 9669.45787),
("Qwen/Qwen1.5-MoE-A2.7B", 1, True, 44.25834541569395),
@@ -255,6 +256,10 @@ def _test_text_generation(
f"--parallel_strategy={parallel_strategy}",
]

if "command_r" in model_name.lower():
path_to_conv = os.path.join(path_to_example_dir, "text-generation/sample_command_r_conversation.json")
command += [f"--conversation_input {path_to_conv}"]

with TemporaryDirectory() as tmp_dir:
command.append(f"--output_dir {tmp_dir}")
command.append(f"--token {token.value}")