Set inference_parallelize_kwargs in IPUConfig
callumm-graphcore committed Jul 17, 2023
1 parent 53742af commit ee5282c
Showing 2 changed files with 21 additions and 22 deletions.
10 changes: 5 additions & 5 deletions examples/speech-recognition/run_speech_recognition_seq2seq.py
@@ -411,6 +411,11 @@ def main():
training_args.ipu_config_name if training_args.ipu_config_name else model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
use_auth_token=True if model_args.use_auth_token else None,
+inference_parallelize_kwargs={
+    "use_cache": True,
+    "use_cross_cache": True,
+    "max_length": training_args.generation_max_length,
+},
)

if model.config.decoder_start_token_id is None:
@@ -546,11 +551,6 @@ def compute_metrics(pred):
eval_dataset=vectorized_datasets["eval"] if training_args.do_eval else None,
data_collator=data_collator,
compute_metrics=compute_metrics if training_args.predict_with_generate else None,
-inference_parallelize_kwargs={
-    "use_cache": True,
-    "use_cross_cache": True,
-    "max_length": training_args.generation_max_length,
-},
)

# 12. Training
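
Net effect of the change above: the generation-time options (use_cache, use_cross_cache, max_length) are now passed to IPUConfig.from_pretrained via inference_parallelize_kwargs instead of to the IPUSeq2SeqTrainer constructor. Below is a minimal sketch of the resulting call pattern, assuming the optimum-graphcore imports used by the example; the config name and the max_length value are placeholders, not taken from this commit:

from optimum.graphcore import IPUConfig

ipu_config = IPUConfig.from_pretrained(
    "path-or-name-of-ipu-config",  # placeholder for training_args.ipu_config_name or the model name
    inference_parallelize_kwargs={
        "use_cache": True,         # values as set in the example
        "use_cross_cache": True,
        "max_length": 225,         # stand-in for training_args.generation_max_length
    },
)
# The IPUSeq2SeqTrainer is then built with ipu_config=ipu_config and no longer
# receives an inference_parallelize_kwargs argument of its own.
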
33 changes: 16 additions & 17 deletions tests/examples/run_speech_recognition_seq2seq.txt
@@ -72,26 +72,31 @@
> # Whisper does not have a layer_norm_eps option, remains to be seen if this is a problem
> # config.update({"layer_norm_eps": 0.0001})
>
-401a410,414
+401a410,419
> ipu_config = IPUConfig.from_pretrained(
> training_args.ipu_config_name if training_args.ipu_config_name else model_args.model_name_or_path,
> cache_dir=model_args.cache_dir,
> use_auth_token=True if model_args.use_auth_token else None,
+> inference_parallelize_kwargs={
+> "use_cache": True,
+> "use_cross_cache": True,
+> "max_length": training_args.generation_max_length,
+> },
> )
-446c459
+446c464
< def prepare_dataset(batch):
---
> def prepare_dataset(batch, feature_extractor, tokenizer):
-452c465
+452c470
< # process audio length
---
>
-457a471,474
+457a476,479
> if not training_args.fp32:
> # Cast audio inputs to FP16
> batch[model_input_name] = batch[model_input_name].astype(np.float16)
>
-463,469c480,485
+463,469c485,490
< with training_args.main_process_first(desc="dataset map pre-processing"):
< vectorized_datasets = raw_datasets.map(
< prepare_dataset,
@@ -106,7 +111,7 @@
> num_proc=data_args.preprocessing_num_workers,
> desc="preprocess train dataset",
> )
-509,516c525,528
+509,516c530,533
< # make sure all processes wait until data is saved
< with training_args.main_process_first():
< # only the main process saves them
@@ -120,23 +125,17 @@
> feature_extractor.save_pretrained(training_args.output_dir)
> tokenizer.save_pretrained(training_args.output_dir)
> config.save_pretrained(training_args.output_dir)
-518c530
+518c535
< processor = AutoProcessor.from_pretrained(training_args.output_dir)
---
> processor = WhisperProcessor(feature_extractor, tokenizer)
-524a537
+524a542
> pad_to_multiple_of_labels=training_args.generation_max_length,
-528c541
+528c546
< trainer = Seq2SeqTrainer(
---
> trainer = IPUSeq2SeqTrainer(
-529a543
+529a548
> ipu_config=ipu_config,
-533d546
+533d551
< tokenizer=feature_extractor,
-535a549,553
-> inference_parallelize_kwargs={
-> "use_cache": True,
-> "use_cross_cache": True,
-> "max_length": training_args.generation_max_length,
-> },

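A note on the second file (an assumption about the test layout, not stated in this commit): tests/examples/run_speech_recognition_seq2seq.txt appears to be a recorded diff between the upstream transformers example and the Graphcore version, so moving inference_parallelize_kwargs shifts the recorded hunk line numbers above. A hypothetical sketch of how such an expectation file could be checked, purely for illustration and not the repository's actual test code:

import subprocess

def check_recorded_diff(upstream_script, ipu_script, expected_diff_file):
    # "diff" in its default (normal) format produces output like the .txt above;
    # it exits with status 1 when the files differ, so check=True is not used.
    result = subprocess.run(
        ["diff", upstream_script, ipu_script],
        capture_output=True,
        text=True,
    )
    with open(expected_diff_file) as f:
        expected = f.read()
    assert result.stdout == expected, "example script drifted from the recorded diff"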