diff --git a/optimum/habana/transformers/trainer.py b/optimum/habana/transformers/trainer.py index 4c1b68e478..01bb3b01b4 100644 --- a/optimum/habana/transformers/trainer.py +++ b/optimum/habana/transformers/trainer.py @@ -682,7 +682,7 @@ def _inner_training_loop( import transformers.modeling_utils - if args.deepspeed and args.use_lazy_mode: + if args.deepspeed: from deepspeed.runtime.activation_checkpointing.checkpointing import ( CheckpointFunction, non_reentrant_checkpoint,