From ee5282cd1697511a518f9fce4c23f045eed6578b Mon Sep 17 00:00:00 2001 From: Callum McLean Date: Mon, 17 Jul 2023 17:12:25 +0100 Subject: [PATCH] Set `inference_parallelize_kwargs` in `IPUConfig` --- .../run_speech_recognition_seq2seq.py | 10 +++--- .../run_speech_recognition_seq2seq.txt | 33 +++++++++---------- 2 files changed, 21 insertions(+), 22 deletions(-) diff --git a/examples/speech-recognition/run_speech_recognition_seq2seq.py b/examples/speech-recognition/run_speech_recognition_seq2seq.py index b3fec1ae7..244a4a201 100755 --- a/examples/speech-recognition/run_speech_recognition_seq2seq.py +++ b/examples/speech-recognition/run_speech_recognition_seq2seq.py @@ -411,6 +411,11 @@ def main(): training_args.ipu_config_name if training_args.ipu_config_name else model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_auth_token=True if model_args.use_auth_token else None, + inference_parallelize_kwargs={ + "use_cache": True, + "use_cross_cache": True, + "max_length": training_args.generation_max_length, + }, ) if model.config.decoder_start_token_id is None: @@ -546,11 +551,6 @@ def compute_metrics(pred): eval_dataset=vectorized_datasets["eval"] if training_args.do_eval else None, data_collator=data_collator, compute_metrics=compute_metrics if training_args.predict_with_generate else None, - inference_parallelize_kwargs={ - "use_cache": True, - "use_cross_cache": True, - "max_length": training_args.generation_max_length, - }, ) # 12. Training diff --git a/tests/examples/run_speech_recognition_seq2seq.txt b/tests/examples/run_speech_recognition_seq2seq.txt index d11b90719..e259c424a 100644 --- a/tests/examples/run_speech_recognition_seq2seq.txt +++ b/tests/examples/run_speech_recognition_seq2seq.txt @@ -72,26 +72,31 @@ > # Whisper does not have a layer_norm_eps option, remains to be seen if this is a problem > # config.update({"layer_norm_eps": 0.0001}) > -401a410,414 +401a410,419 > ipu_config = IPUConfig.from_pretrained( > training_args.ipu_config_name if training_args.ipu_config_name else model_args.model_name_or_path, > cache_dir=model_args.cache_dir, > use_auth_token=True if model_args.use_auth_token else None, +> inference_parallelize_kwargs={ +> "use_cache": True, +> "use_cross_cache": True, +> "max_length": training_args.generation_max_length, +> }, > ) -446c459 +446c464 < def prepare_dataset(batch): --- > def prepare_dataset(batch, feature_extractor, tokenizer): -452c465 +452c470 < # process audio length --- > -457a471,474 +457a476,479 > if not training_args.fp32: > # Cast audio inputs to FP16 > batch[model_input_name] = batch[model_input_name].astype(np.float16) > -463,469c480,485 +463,469c485,490 < with training_args.main_process_first(desc="dataset map pre-processing"): < vectorized_datasets = raw_datasets.map( < prepare_dataset, @@ -106,7 +111,7 @@ > num_proc=data_args.preprocessing_num_workers, > desc="preprocess train dataset", > ) -509,516c525,528 +509,516c530,533 < # make sure all processes wait until data is saved < with training_args.main_process_first(): < # only the main process saves them @@ -120,23 +125,17 @@ > feature_extractor.save_pretrained(training_args.output_dir) > tokenizer.save_pretrained(training_args.output_dir) > config.save_pretrained(training_args.output_dir) -518c530 +518c535 < processor = AutoProcessor.from_pretrained(training_args.output_dir) --- > processor = WhisperProcessor(feature_extractor, tokenizer) -524a537 +524a542 > pad_to_multiple_of_labels=training_args.generation_max_length, -528c541 +528c546 < trainer = Seq2SeqTrainer( --- > trainer = IPUSeq2SeqTrainer( -529a543 +529a548 > ipu_config=ipu_config, -533d546 +533d551 < tokenizer=feature_extractor, -535a549,553 -> inference_parallelize_kwargs={ -> "use_cache": True, -> "use_cross_cache": True, -> "max_length": training_args.generation_max_length, -> },