From 84c9b7bd058919c0646e11da8dda11e8aa138af5 Mon Sep 17 00:00:00 2001
From: Vidya S Galli
Date: Tue, 2 Apr 2024 00:36:26 +0000
Subject: [PATCH 01/14] initial commit

---
 examples/text-generation/run_generation.py      | 3 +++
 optimum/habana/transformers/generation/utils.py | 1 -
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py
index 0a16543c2a..c948287399 100755
--- a/examples/text-generation/run_generation.py
+++ b/examples/text-generation/run_generation.py
@@ -385,6 +385,9 @@ def assemble_prompt(prompt_size, book_path):
             1342,  # Pride and Prejudice
         ]
         input_sentences = assemble_prompt(prompt_size=args.max_input_tokens, book_path=download_book(book_ids[0]))
+    elif model_config.model_type == "cohere":
+        messages = [{"role": "user", "content": args.prompt}]
+        input_sentences = tokenizer.apply_chat_template(messages, tokenize=False, return_tensors="pt")
     else:
         input_sentences = [
             "DeepSpeed is a machine learning framework",
diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py
index d333986679..cbab696699 100755
--- a/optimum/habana/transformers/generation/utils.py
+++ b/optimum/habana/transformers/generation/utils.py
@@ -1103,7 +1103,6 @@ def generate(
         # 7. determine generation mode
         generation_mode = generation_config.get_generation_mode(assistant_model)
-
         if generation_config.bucket_size > 0:
             assert generation_config.static_shapes, "bucket_size > 0 can be set only when static_shapes is set"
         # if generation_config.bucket_size <= 0, padding is handled by the generating fn (like greedy_search)

From e68cadc3cbb2e8f317a09ae00319760cce5841e5 Mon Sep 17 00:00:00 2001
From: Soila Kavulya
Date: Tue, 2 Apr 2024 13:07:29 -0700
Subject: [PATCH 02/14] Add StoppingCriteriaList for C4AI Command-R support

* Add StoppingCriteriaList for C4AI Command-R support

* Revert deletion of MaxNewTokensCriteria
---
 optimum/habana/transformers/modeling_utils.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py
index 2b7bb32bce..0b04015f89 100644
--- a/optimum/habana/transformers/modeling_utils.py
+++ b/optimum/habana/transformers/modeling_utils.py
@@ -26,6 +26,7 @@
     gaudi_MaxNewTokensCriteria_call,
     gaudi_MaxTimeCriteria_call,
     gaudi_StoppingCriteriaList_call,
+    gaudi_StoppingCriteriaList_call,
 )
 from .models import (
     DeciLMConfig,
@@ -271,6 +272,7 @@ def adapt_transformers_to_gaudi():
     transformers.generation.MaxTimeCriteria.__call__ = gaudi_MaxTimeCriteria_call
     transformers.generation.EosTokenCriteria.__call__ = gaudi_EosTokenCriteria_call
     transformers.generation.StoppingCriteriaList.__call__ = gaudi_StoppingCriteriaList_call
+    transformers.generation.StoppingCriteriaList.__call__ = gaudi_StoppingCriteriaList_call

     # Optimization for BLOOM generation on Gaudi
     transformers.models.bloom.modeling_bloom.BloomAttention.forward = gaudi_bloom_attention_forward

From 51e69252b7872e1bb1b876539e266cd8996a00ad Mon Sep 17 00:00:00 2001
From: Soila Kavulya
Date: Wed, 3 Apr 2024 10:20:40 -0700
Subject: [PATCH 03/14] Fix inputs for Cohere Command-R

---
 examples/text-generation/run_generation.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py
index c948287399..4ccc52184f 100755
--- a/examples/text-generation/run_generation.py
+++ b/examples/text-generation/run_generation.py
@@ -385,9 +385,6 @@ def assemble_prompt(prompt_size, book_path):
             1342,  # Pride and Prejudice
         ]
         input_sentences = assemble_prompt(prompt_size=args.max_input_tokens, book_path=download_book(book_ids[0]))
-    elif model_config.model_type == "cohere":
-        messages = [{"role": "user", "content": args.prompt}]
-        input_sentences = tokenizer.apply_chat_template(messages, tokenize=False, return_tensors="pt")
     else:
         input_sentences = [
             "DeepSpeed is a machine learning framework",
@@ -400,6 +397,11 @@ def assemble_prompt(prompt_size, book_path):
             "Peace is the only way",
         ]

+    if model.config.model_type == "cohere":
+        for i, sentence in enumerate(input_sentences):
+            message = [{"role": "user", "content": sentence}]
+            input_sentences[i] = tokenizer.apply_chat_template(message, tokenize=False)
+
     if args.batch_size > len(input_sentences):
         # Dynamically extends to support larger batch sizes
         num_sentences_to_add = args.batch_size - len(input_sentences)

From f8c749417e79b6585ba4b73f8b060f17cf6f23da Mon Sep 17 00:00:00 2001
From: Soila Kavulya
Date: Thu, 4 Apr 2024 10:28:43 -0700
Subject: [PATCH 04/14] Add temperature to text generation config

---
 examples/text-generation/utils.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py
index c535acba0a..d40b946c46 100644
--- a/examples/text-generation/utils.py
+++ b/examples/text-generation/utils.py
@@ -577,6 +577,7 @@ def setup_generation_config(args, model, assistant_model, tokenizer):
     generation_config.flash_attention_causal_mask = args.flash_attention_causal_mask
     generation_config.flash_attention_fast_softmax = args.flash_attention_fast_softmax
     generation_config.trust_remote_code = args.trust_remote_code
+    generation_config.temperature = args.temperature

     return generation_config

From 1e291871cd3e012f2995c76dc0ba17687afcd053 Mon Sep 17 00:00:00 2001
From: Vidya S Galli
Date: Fri, 31 May 2024 21:11:10 +0000
Subject: [PATCH 05/14] fixes

---
 optimum/habana/transformers/modeling_utils.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py
index 0b04015f89..2b7bb32bce 100644
--- a/optimum/habana/transformers/modeling_utils.py
+++ b/optimum/habana/transformers/modeling_utils.py
@@ -26,7 +26,6 @@
     gaudi_MaxNewTokensCriteria_call,
     gaudi_MaxTimeCriteria_call,
     gaudi_StoppingCriteriaList_call,
-    gaudi_StoppingCriteriaList_call,
 )
 from .models import (
     DeciLMConfig,
@@ -272,7 +271,6 @@ def adapt_transformers_to_gaudi():
     transformers.generation.MaxTimeCriteria.__call__ = gaudi_MaxTimeCriteria_call
     transformers.generation.EosTokenCriteria.__call__ = gaudi_EosTokenCriteria_call
     transformers.generation.StoppingCriteriaList.__call__ = gaudi_StoppingCriteriaList_call
-    transformers.generation.StoppingCriteriaList.__call__ = gaudi_StoppingCriteriaList_call

     # Optimization for BLOOM generation on Gaudi
     transformers.models.bloom.modeling_bloom.BloomAttention.forward = gaudi_bloom_attention_forward

From cd077719ad93a2f123e08f909901b7ff7438f0a2 Mon Sep 17 00:00:00 2001
From: Vidya S Galli
Date: Fri, 31 May 2024 21:53:16 +0000
Subject: [PATCH 06/14] Added test and documentation

Co-authored-by: Soila Kavulya
---
 optimum/habana/transformers/generation/utils.py | 1 +
 tests/test_text_generation_example.py           | 1 +
 2 files changed, 2 insertions(+)

diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py
index cbab696699..d333986679 100755
--- a/optimum/habana/transformers/generation/utils.py
+++ b/optimum/habana/transformers/generation/utils.py
@@ -1103,6 +1103,7 @@ def generate(
         # 7. determine generation mode
         generation_mode = generation_config.get_generation_mode(assistant_model)
+
         if generation_config.bucket_size > 0:
             assert generation_config.static_shapes, "bucket_size > 0 can be set only when static_shapes is set"
         # if generation_config.bucket_size <= 0, padding is handled by the generating fn (like greedy_search)
diff --git a/tests/test_text_generation_example.py b/tests/test_text_generation_example.py
index 9c4e983576..01460a32be 100644
--- a/tests/test_text_generation_example.py
+++ b/tests/test_text_generation_example.py
@@ -40,6 +40,7 @@
             ("adept/persimmon-8b-base", 4, False, 366.73968820698406),
             ("Qwen/Qwen1.5-7B", 4, False, 518.894516133132),
             ("google/gemma-7b", 1, False, 109.70751574382221),
+            ("CohereForAI/c4ai-command-r-v01", 1, False, 30.472430202916325),
             ("state-spaces/mamba-130m-hf", 1536, False, 8600),
             ("Deci/DeciLM-7B", 1, False, 120),
         ],

From 9c98f74a263c54bce89ded80da17bdc0c69d4af8 Mon Sep 17 00:00:00 2001
From: Vidya S Galli
Date: Tue, 9 Jul 2024 20:23:48 +0000
Subject: [PATCH 07/14] added comment

---
 examples/text-generation/run_generation.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py
index 4ccc52184f..997c2ad555 100755
--- a/examples/text-generation/run_generation.py
+++ b/examples/text-generation/run_generation.py
@@ -397,6 +397,7 @@ def assemble_prompt(prompt_size, book_path):
             "Peace is the only way",
         ]

+    # Format message with the command-r chat template
     if model.config.model_type == "cohere":
         for i, sentence in enumerate(input_sentences):
             message = [{"role": "user", "content": sentence}]

From eb0319032b43a680462098c79396f048931ae8d1 Mon Sep 17 00:00:00 2001
From: Vidya S Galli
Date: Wed, 10 Jul 2024 19:56:25 +0000
Subject: [PATCH 08/14] added chat template option

---
 examples/text-generation/README.md          | 26 +++++++++++++++++++
 examples/text-generation/run_generation.py  | 16 ++++++++----
 .../sample_command_r_template.json          |  1 +
 tests/test_text_generation_example.py       |  5 ++++
 4 files changed, 43 insertions(+), 5 deletions(-)
 create mode 100644 examples/text-generation/sample_command_r_template.json

diff --git a/examples/text-generation/README.md b/examples/text-generation/README.md
index b720936ff4..ab83b6f1eb 100755
--- a/examples/text-generation/README.md
+++ b/examples/text-generation/README.md
@@ -589,3 +589,29 @@ deepspeed --num_gpus 8 run_lm_eval.py \
 ## Text-Generation Pipeline

 A Transformers-like pipeline is defined and provided [here](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation/text-generation-pipeline). It is optimized for Gaudi and can be called to generate text in your scripts.
+
+## Conversation generation
+
+For models that support chat like `CohereForAI/c4ai-command-r-v01` you can provide `--chat_template <path to JSON file>` that is applied to the tokenizer.
+
+### Examples
+
+Sample chat template `sample_command_r_template.json` for [CohereForAI/c4ai-command-r-v01](https://huggingface.co/CohereForAI/c4ai-command-r-v01) is shown below:
+
+```json
+[{"role": "user", "content": "Hello, how are you?"}]
+```
+
+Command to run chat generation:
+
+```
+python run_generation.py \
+    --model_name_or_path CohereForAI/c4ai-command-r-v01 \
+    --use_hpu_graphs \
+    --use_kv_cache \
+    --max_new_tokens 100 \
+    --do_sample \
+    --chat_template sample_command_r_template.json \
+    --bf16 \
+    --batch_size 2
+```
diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py
index 997c2ad555..034209aa24 100755
--- a/examples/text-generation/run_generation.py
+++ b/examples/text-generation/run_generation.py
@@ -312,6 +312,12 @@ def setup_parser(parser):
         help="Run multi card with the specified parallel strategy. Choices are 'tp' for Tensor Parallel Strategy or 'none'.",
     )

+    parser.add_argument(
+        "--chat_template",
+        default=None,
+        type=str,
+        help='Optional JSON input file containing chat template for tokenizer.',
+    )
     args = parser.parse_args()

     if args.torch_compile:
@@ -397,11 +403,11 @@ def assemble_prompt(prompt_size, book_path):
             "Peace is the only way",
         ]

-    # Format message with the command-r chat template
-    if model.config.model_type == "cohere":
-        for i, sentence in enumerate(input_sentences):
-            message = [{"role": "user", "content": sentence}]
-            input_sentences[i] = tokenizer.apply_chat_template(message, tokenize=False)
+    # Apply tokenizer chat template
+    if args.chat_template and hasattr(tokenizer, 'chat_template'):
+        with open(args.chat_template, 'r') as fh:
+            messages = json.load(fh)
+        input_sentences = [tokenizer.apply_chat_template(messages, tokenize=False)]

     if args.batch_size > len(input_sentences):
         # Dynamically extends to support larger batch sizes
diff --git a/examples/text-generation/sample_command_r_template.json b/examples/text-generation/sample_command_r_template.json
new file mode 100644
index 0000000000..ddfd802b2d
--- /dev/null
+++ b/examples/text-generation/sample_command_r_template.json
@@ -0,0 +1 @@
+[{"role": "user", "content": "Hello, how are you?"}]
\ No newline at end of file
diff --git a/tests/test_text_generation_example.py b/tests/test_text_generation_example.py
index 01460a32be..76ac0111e3 100644
--- a/tests/test_text_generation_example.py
+++ b/tests/test_text_generation_example.py
@@ -236,6 +236,11 @@ def _test_text_generation(
             f"--parallel_strategy={parallel_strategy}",
         ]

+    if "command_r" in model_name.lower():
+        path_to_template = os.path.join(
+            path_to_example_dir,"text-generation/sample_command_r_template.json")
+        command += [f"--chat_template {path_to_template}"]
+
     with TemporaryDirectory() as tmp_dir:
         command.append(f"--output_dir {tmp_dir}")
         command.append(f"--token {token.value}")

From 73d1d7ade0e88e20da6c80406a92d8a87eb9b573 Mon Sep 17 00:00:00 2001
From: Vidya S Galli
Date: Wed, 10 Jul 2024 20:00:51 +0000
Subject: [PATCH 09/14] formatting

---
 examples/text-generation/run_generation.py | 8 ++++----
 tests/test_text_generation_example.py      | 3 +--
 2 files changed, 5 insertions(+), 6 deletions(-)

diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py
index 034209aa24..af5f9bd244 100755
--- a/examples/text-generation/run_generation.py
+++ b/examples/text-generation/run_generation.py
@@ -316,7 +316,7 @@ def setup_parser(parser):
         "--chat_template",
         default=None,
         type=str,
-        help='Optional JSON input file containing chat template for tokenizer.',
+        help="Optional JSON input file containing chat template for tokenizer.",
     )

     args = parser.parse_args()
@@ -403,9 +403,9 @@ def assemble_prompt(prompt_size, book_path):
             "Peace is the only way",
         ]

-    # Apply tokenizer chat template
-    if args.chat_template and hasattr(tokenizer, 'chat_template'):
-        with open(args.chat_template, 'r') as fh:
+    # Apply tokenizer chat template if supported
+    if args.chat_template and hasattr(tokenizer, "chat_template"):
+        with open(args.chat_template, "r") as fh:
             messages = json.load(fh)
         input_sentences = [tokenizer.apply_chat_template(messages, tokenize=False)]
diff --git a/tests/test_text_generation_example.py b/tests/test_text_generation_example.py
index 76ac0111e3..018021886c 100644
--- a/tests/test_text_generation_example.py
+++ b/tests/test_text_generation_example.py
@@ -237,8 +237,7 @@ def _test_text_generation(
         ]

     if "command_r" in model_name.lower():
-        path_to_template = os.path.join(
-            path_to_example_dir,"text-generation/sample_command_r_template.json")
+        path_to_template = os.path.join(path_to_example_dir, "text-generation/sample_command_r_template.json")
         command += [f"--chat_template {path_to_template}"]

From d5d846d673cc82c93985a7b08aa17f1aeed7edf5 Mon Sep 17 00:00:00 2001
From: Vidya S Galli
Date: Thu, 11 Jul 2024 20:23:16 +0000
Subject: [PATCH 10/14] additional checks

---
 examples/text-generation/run_generation.py | 15 +++++++++++----
 1 file changed, 11 insertions(+), 4 deletions(-)

diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py
index af5f9bd244..f70b692f70 100755
--- a/examples/text-generation/run_generation.py
+++ b/examples/text-generation/run_generation.py
@@ -23,6 +23,7 @@
 import logging
 import math
 import os
+import sys
 import time
 from itertools import cycle
 from pathlib import Path
@@ -370,8 +371,6 @@ def download_book(book_id):
         return save_path
     else:
         print("Failed to download book! Exiting...")
-        import sys
-
         sys.exit()

 def assemble_prompt(prompt_size, book_path):
@@ -406,8 +405,16 @@ def assemble_prompt(prompt_size, book_path):
     # Apply tokenizer chat template if supported
     if args.chat_template and hasattr(tokenizer, "chat_template"):
         with open(args.chat_template, "r") as fh:
-            messages = json.load(fh)
-        input_sentences = [tokenizer.apply_chat_template(messages, tokenize=False)]
+            try:
+                messages = json.load(fh)
+            except json.JSONDecodeError as e:
+                logger.error(f"Error loading {args.chat_template}: {e}")
+                sys.exit()
+        try:
+            input_sentences = [tokenizer.apply_chat_template(messages, tokenize=False)]
+        except Exception as e:
+            logger.error(f"Error applying chat template to tokenizer: {e}")
+            sys.exit()

     if args.batch_size > len(input_sentences):
         # Dynamically extends to support larger batch sizes

From dbcfb16192eb1ea66b4ef1de768bae04ddcef585 Mon Sep 17 00:00:00 2001
From: Vidya S Galli
Date: Wed, 24 Jul 2024 21:48:53 +0000
Subject: [PATCH 11/14] Fix for option name and usage

---
 examples/text-generation/README.md            |  6 ++---
 examples/text-generation/run_generation.py    | 25 ++++++++++++++-----
 ...son => sample_command_r_conversation.json} |  0
 tests/test_text_generation_example.py         |  4 +--
 4 files changed, 24 insertions(+), 11 deletions(-)
 rename examples/text-generation/{sample_command_r_template.json => sample_command_r_conversation.json} (100%)

diff --git a/examples/text-generation/README.md b/examples/text-generation/README.md
index ab83b6f1eb..a83bb91bfb 100755
--- a/examples/text-generation/README.md
+++ b/examples/text-generation/README.md
@@ -592,11 +592,11 @@ A Transformers-like pipeline is defined and provided [here](https://github.com/h
 ## Conversation generation

-For models that support chat like `CohereForAI/c4ai-command-r-v01` you can provide `--chat_template <path to JSON file>` that is applied to the tokenizer.
+For models that support chat like `CohereForAI/c4ai-command-r-v01` you can provide `--conversation_input <path to JSON file>` that is applied to the tokenizer.

 ### Examples

-Sample chat template `sample_command_r_template.json` for [CohereForAI/c4ai-command-r-v01](https://huggingface.co/CohereForAI/c4ai-command-r-v01) is shown below:
+Sample conversation `sample_command_r_conversation.json` for [CohereForAI/c4ai-command-r-v01](https://huggingface.co/CohereForAI/c4ai-command-r-v01) is shown below:

 ```json
 [{"role": "user", "content": "Hello, how are you?"}]
 ```
@@ -611,7 +611,7 @@ python run_generation.py \
     --use_kv_cache \
     --max_new_tokens 100 \
     --do_sample \
-    --chat_template sample_command_r_template.json \
+    --conversation_input sample_command_r_conversation.json \
     --bf16 \
     --batch_size 2
 ```
diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py
index f70b692f70..596e079ef3 100755
--- a/examples/text-generation/run_generation.py
+++ b/examples/text-generation/run_generation.py
@@ -314,7 +314,20 @@ def setup_parser(parser):
     )
     parser.add_argument(
-        "--chat_template",
+        "--load_quantized_model",
+        action="store_true",
+        help="Whether to load model from hugging face checkpoint.",
+    )
+    parser.add_argument(
+        "--parallel_strategy",
+        type=str,
+        choices=["tp", "none"],  # Add other strategies as needed
+        default="none",
+        help="Run multi card with the specified parallel strategy. Choices are 'tp' for Tensor Parallel Strategy or 'none'.",
+    )
+
+    parser.add_argument(
+        "--conversation_input",
         default=None,
         type=str,
         help="Optional JSON input file containing chat template for tokenizer.",
     )
@@ -402,16 +415,16 @@ def assemble_prompt(prompt_size, book_path):
             "Peace is the only way",
         ]

-    # Apply tokenizer chat template if supported
-    if args.chat_template and hasattr(tokenizer, "chat_template"):
-        with open(args.chat_template, "r") as fh:
+    # Apply input as conversation if tokenizer has a chat template
+    if args.conversation_input and hasattr(tokenizer, "chat_template"):
+        with open(args.conversation_input, "r") as fh:
             try:
                 messages = json.load(fh)
             except json.JSONDecodeError as e:
-                logger.error(f"Error loading {args.chat_template}: {e}")
+                logger.error(f"Error loading {args.conversation_input}: {e}")
                 sys.exit()
         try:
-            input_sentences = [tokenizer.apply_chat_template(messages, tokenize=False)]
+            input_sentences = [tokenizer.apply_chat_template(conversation=messages, tokenize=False)]
         except Exception as e:
             logger.error(f"Error applying chat template to tokenizer: {e}")
             sys.exit()
diff --git a/examples/text-generation/sample_command_r_template.json b/examples/text-generation/sample_command_r_conversation.json
similarity index 100%
rename from examples/text-generation/sample_command_r_template.json
rename to examples/text-generation/sample_command_r_conversation.json
diff --git a/tests/test_text_generation_example.py b/tests/test_text_generation_example.py
index 018021886c..bbd38e67e5 100644
--- a/tests/test_text_generation_example.py
+++ b/tests/test_text_generation_example.py
@@ -237,8 +237,8 @@ def _test_text_generation(
         ]

     if "command_r" in model_name.lower():
-        path_to_template = os.path.join(path_to_example_dir, "text-generation/sample_command_r_template.json")
-        command += [f"--chat_template {path_to_template}"]
+        path_to_conv = os.path.join(path_to_example_dir, "text-generation/sample_command_r_conversation.json")
+        command += [f"--conversation_input {path_to_conv}"]

From 22ee28478815e9add1d721cf830f74a9a53a27a7 Mon Sep 17 00:00:00 2001
From: Vidya S Galli
Date: Thu, 29 Aug 2024 11:06:06 -0700
Subject: [PATCH 12/14] rebase

---
 examples/text-generation/run_generation.py | 13 -------------
 1 file changed, 13 deletions(-)

diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py
index 596e079ef3..a2c6f62acf 100755
--- a/examples/text-generation/run_generation.py
+++ b/examples/text-generation/run_generation.py
@@ -313,19 +313,6 @@ def setup_parser(parser):
         help="Run multi card with the specified parallel strategy. Choices are 'tp' for Tensor Parallel Strategy or 'none'.",
     )
-    parser.add_argument(
-        "--load_quantized_model",
-        action="store_true",
-        help="Whether to load model from hugging face checkpoint.",
-    )
-    parser.add_argument(
-        "--parallel_strategy",
-        type=str,
-        choices=["tp", "none"],  # Add other strategies as needed
-        default="none",
-        help="Run multi card with the specified parallel strategy. Choices are 'tp' for Tensor Parallel Strategy or 'none'.",
-    )
-
     parser.add_argument(
         "--conversation_input",
         default=None,
         type=str,

From 438f13724971cbf390ad1ff408b39dc68313cc72 Mon Sep 17 00:00:00 2001
From: Vidya Galli
Date: Fri, 6 Sep 2024 10:32:02 -0700
Subject: [PATCH 13/14] Update examples/text-generation/run_generation.py

Co-authored-by: Yaser Afshar
---
 examples/text-generation/run_generation.py | 6 +-----
 1 file changed, 1 insertion(+), 5 deletions(-)

diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py
index a2c6f62acf..3004a61f83 100755
--- a/examples/text-generation/run_generation.py
+++ b/examples/text-generation/run_generation.py
@@ -405,11 +405,7 @@ def assemble_prompt(prompt_size, book_path):
     # Apply input as conversation if tokenizer has a chat template
     if args.conversation_input and hasattr(tokenizer, "chat_template"):
         with open(args.conversation_input, "r") as fh:
-            try:
-                messages = json.load(fh)
-            except json.JSONDecodeError as e:
-                logger.error(f"Error loading {args.conversation_input}: {e}")
-                sys.exit()
+            messages = json.load(fh)
         try:
             input_sentences = [tokenizer.apply_chat_template(conversation=messages, tokenize=False)]
         except Exception as e:
             logger.error(f"Error applying chat template to tokenizer: {e}")
             sys.exit()

From ef3fba8559155294b1cd0b99e77c35b4a58af5c5 Mon Sep 17 00:00:00 2001
From: Vidya S Galli
Date: Fri, 6 Sep 2024 13:34:14 -0700
Subject: [PATCH 14/14] moved input check

---
 examples/text-generation/run_generation.py | 19 +++++++++----------
 1 file changed, 9 insertions(+), 10 deletions(-)

diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py
index 3004a61f83..8bae8351c7 100755
--- a/examples/text-generation/run_generation.py
+++ b/examples/text-generation/run_generation.py
@@ -317,7 +317,7 @@ def setup_parser(parser):
         "--conversation_input",
         default=None,
         type=str,
-        help="Optional JSON input file containing chat template for tokenizer.",
+        help="Optional JSON input file containing conversation input.",
     )

     args = parser.parse_args()
@@ -390,6 +390,14 @@ def assemble_prompt(prompt_size, book_path):
             1342,  # Pride and Prejudice
         ]
         input_sentences = assemble_prompt(prompt_size=args.max_input_tokens, book_path=download_book(book_ids[0]))
+    elif args.conversation_input and hasattr(tokenizer, "chat_template"):
+        with open(args.conversation_input, "r") as fh:
+            messages = json.load(fh)
+        try:
+            input_sentences = [tokenizer.apply_chat_template(conversation=messages, tokenize=False)]
+        except Exception as e:
+            logger.error(f"Error applying chat template to tokenizer: {e}")
+            sys.exit()
     else:
         input_sentences = [
             "DeepSpeed is a machine learning framework",
@@ -402,15 +410,6 @@ def assemble_prompt(prompt_size, book_path):
             "Peace is the only way",
         ]

-    # Apply input as conversation if tokenizer has a chat template
-    if args.conversation_input and hasattr(tokenizer, "chat_template"):
-        with open(args.conversation_input, "r") as fh:
-            messages = json.load(fh)
-        try:
-            input_sentences = [tokenizer.apply_chat_template(conversation=messages, tokenize=False)]
-        except Exception as e:
-            logger.error(f"Error applying chat template to tokenizer: {e}")
-            sys.exit()

     if args.batch_size > len(input_sentences):
         # Dynamically extends to support larger batch sizes
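
Taken together, the series lands on this flow: `--conversation_input` points at a JSON file holding a conversation (a list of `{"role": ..., "content": ...}` dicts), and the example script renders it into a prompt string via the tokenizer's chat template. The snippet below is a minimal standalone sketch of that flow, not code from the patches themselves; it assumes a `transformers` release whose Command-R tokenizer ships a chat template (older releases may additionally need `trust_remote_code=True`), and the file name simply mirrors the sample shipped with the example.

```python
import json
import sys

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01")

# Same shape as sample_command_r_conversation.json:
# [{"role": "user", "content": "Hello, how are you?"}]
with open("sample_command_r_conversation.json", "r") as fh:
    messages = json.load(fh)

try:
    # tokenize=False returns the rendered prompt text instead of token ids,
    # matching what run_generation.py feeds into its batching logic.
    input_sentences = [tokenizer.apply_chat_template(conversation=messages, tokenize=False)]
except Exception as e:
    print(f"Error applying chat template to tokenizer: {e}")
    sys.exit()

print(input_sentences[0])
```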