Commit

added chat template option
vidyasiv committed Jul 10, 2024
1 parent 78d78b2 commit 3aac1fb
Showing 4 changed files with 43 additions and 5 deletions.
26 changes: 26 additions & 0 deletions examples/text-generation/README.md
@@ -544,3 +544,29 @@ deepspeed --num_gpus 8 run_lm_eval.py \
## Text-Generation Pipeline

A Transformers-like pipeline is defined and provided [here](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation/text-generation-pipeline). It is optimized for Gaudi and can be called to generate text in your scripts.

## Conversation generation

For chat models such as `CohereForAI/c4ai-command-r-v01`, you can pass `--chat_template <JSON_FILE>`, where the JSON file contains a list of chat messages that are formatted with the tokenizer's chat template.

### Examples

The sample messages file `sample_command_r_template.json` for [CohereForAI/c4ai-command-r-v01](https://huggingface.co/CohereForAI/c4ai-command-r-v01) is shown below:

```json
[{"role": "user", "content": "Hello, how are you?"}]
```

Command to run chat generation:

```bash
python run_generation.py \
--model_name_or_path CohereForAI/c4ai-command-r-v01 \
--use_hpu_graphs \
--use_kv_cache \
--max_new_tokens 100 \
--do_sample \
--chat_template sample_command_r_template.json \
--bf16 \
--batch_size 2
```
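The flow behind the flag can be sketched in plain Python. This is a hypothetical stand-in: `load_messages` and `render_prompt` are illustrative names, and the turn-token format only approximates what `tokenizer.apply_chat_template(messages, tokenize=False)` produces for command-r.

```python
import json


def load_messages(raw: str) -> list:
    """Parse the --chat_template JSON file: a list of {"role", "content"} dicts."""
    messages = json.loads(raw)
    if not isinstance(messages, list):
        raise ValueError("chat template file must contain a JSON list of messages")
    for m in messages:
        if "role" not in m or "content" not in m:
            raise ValueError("each message needs 'role' and 'content' keys")
    return messages


def render_prompt(messages: list) -> str:
    """Simplified stand-in for tokenizer.apply_chat_template(messages, tokenize=False).

    The turn tokens below merely approximate the command-r style for illustration.
    """
    turns = [
        f"<|START_OF_TURN_TOKEN|><|{m['role'].upper()}_TOKEN|>{m['content']}<|END_OF_TURN_TOKEN|>"
        for m in messages
    ]
    return "".join(turns)


raw = '[{"role": "user", "content": "Hello, how are you?"}]'
print(render_prompt(load_messages(raw)))
```

In the real script, the rendered string replaces the default input sentences before batching, so the model receives a prompt already wrapped in its expected conversation markers.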
16 changes: 11 additions & 5 deletions examples/text-generation/run_generation.py
@@ -287,6 +287,12 @@ def setup_parser(parser):
action="store_true",
help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
)
parser.add_argument(
"--chat_template",
default=None,
type=str,
help="Optional JSON input file containing chat messages to format with the tokenizer's chat template.",
)
args = parser.parse_args()

if args.torch_compile:
@@ -369,11 +375,11 @@ def assemble_prompt(prompt_size, book_path):
"Peace is the only way",
]

# Format message with the command-r chat template
if model.config.model_type == "cohere":
for i, sentence in enumerate(input_sentences):
message = [{"role": "user", "content": sentence}]
input_sentences[i] = tokenizer.apply_chat_template(message, tokenize=False)
# Apply tokenizer chat template
if args.chat_template and hasattr(tokenizer, "chat_template"):
with open(args.chat_template, "r") as fh:
messages = json.load(fh)
input_sentences = [tokenizer.apply_chat_template(messages, tokenize=False)]

if args.batch_size > len(input_sentences):
# Dynamically extends to support larger batch sizes
1 change: 1 addition & 0 deletions examples/text-generation/sample_command_r_template.json
@@ -0,0 +1 @@
[{"role": "user", "content": "Hello, how are you?"}]
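Since the script passes the full list to `apply_chat_template`, a multi-turn conversation uses the same file format. The example below is hypothetical and assumes the model's chat template defines the `system` and `assistant` roles:

```json
[
  {"role": "system", "content": "You are a concise assistant."},
  {"role": "user", "content": "Hello, how are you?"},
  {"role": "assistant", "content": "I'm doing well, thanks!"},
  {"role": "user", "content": "Great, tell me a joke."}
]
```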
5 changes: 5 additions & 0 deletions tests/test_text_generation_example.py
@@ -190,6 +190,11 @@ def _test_text_generation(
"--limit_hpu_graphs",
]

if "command_r" in model_name.lower():
path_to_template = os.path.join(path_to_example_dir, "text-generation/sample_command_r_template.json")
command += [f"--chat_template {path_to_template}"]

with TemporaryDirectory() as tmp_dir:
command.append(f"--output_dir {tmp_dir}")
command.append(f"--token {token.value}")