From 43ffb785c0b2f24948d2011883d40dccb609d341 Mon Sep 17 00:00:00 2001 From: Billy Cao Date: Fri, 5 Jul 2024 01:20:49 +0800 Subject: [PATCH] Add torch_empty_cache_steps to TrainingArguments (#31546) * Add torch_empty_cache_steps to TrainingArguments * Fix formatting * Add torch_empty_cache_steps to docs on single gpu training * Remove check for torch_empty_cache_steps <= max_steps * Captalize Tip * Be device agnostic * Fix linting --- docs/source/en/perf_train_gpu_one.md | 25 +++++++++++++------------ src/transformers/trainer.py | 19 +++++++++++++++++++ src/transformers/training_args.py | 24 ++++++++++++++++++++++++ 3 files changed, 56 insertions(+), 12 deletions(-) diff --git a/docs/source/en/perf_train_gpu_one.md b/docs/source/en/perf_train_gpu_one.md index 990df0340bf..5a72bba768d 100644 --- a/docs/source/en/perf_train_gpu_one.md +++ b/docs/source/en/perf_train_gpu_one.md @@ -41,21 +41,22 @@ hyperparameter tuning, you should determine which batch size yields the best res The methods and tools covered in this guide can be classified based on the effect they have on the training process: -| Method/tool | Improves training speed | Optimizes memory utilization | -|:-----------------------------------------------------------|:------------------------|:-----------------------------| -| [Batch size choice](#batch-size-choice) | Yes | Yes | -| [Gradient accumulation](#gradient-accumulation) | No | Yes | -| [Gradient checkpointing](#gradient-checkpointing) | No | Yes | -| [Mixed precision training](#mixed-precision-training) | Yes | (No) | -| [Optimizer choice](#optimizer-choice) | Yes | Yes | -| [Data preloading](#data-preloading) | Yes | No | -| [DeepSpeed Zero](#deepspeed-zero) | No | Yes | -| [torch.compile](#using-torchcompile) | Yes | No | -| [Parameter-Efficient Fine Tuning (PEFT)](#using--peft) | No | Yes | +| Method/tool | Improves training speed | Optimizes memory utilization | +|:--------------------------------------------------------------------------------------------------------------------------------------------------------|:------------------------|:-----------------------------| +| [Batch size choice](#batch-size-choice) | Yes | Yes | +| [Gradient accumulation](#gradient-accumulation) | No | Yes | +| [Gradient checkpointing](#gradient-checkpointing) | No | Yes | +| [Mixed precision training](#mixed-precision-training) | Yes | Maybe* | +| [torch_empty_cache_steps](https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.TrainingArguments.torch_empty_cache_steps) | No | Yes | +| [Optimizer choice](#optimizer-choice) | Yes | Yes | +| [Data preloading](#data-preloading) | Yes | No | +| [DeepSpeed Zero](#deepspeed-zero) | No | Yes | +| [torch.compile](#using-torchcompile) | Yes | No | +| [Parameter-Efficient Fine Tuning (PEFT)](#using--peft) | No | Yes | -Note: when using mixed precision with a small model and a large batch size, there will be some memory savings but with a +*Note: when using mixed precision with a small model and a large batch size, there will be some memory savings but with a large model and a small batch size, the memory use will be larger. diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index a0257151c65..8cd8858312d 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -221,6 +221,11 @@ DistributedDataParallelKwargs, DistributedType, GradientAccumulationPlugin, + is_mlu_available, + is_mps_available, + is_npu_available, + is_torch_version, + is_xpu_available, load_fsdp_model, load_fsdp_optimizer, save_fsdp_model, @@ -3307,6 +3312,20 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, loss = self.compute_loss(model, inputs) del inputs + if ( + self.args.torch_empty_cache_steps is not None + and self.state.global_step % self.args.torch_empty_cache_steps == 0 + ): + if is_xpu_available(): + torch.xpu.empty_cache() + elif is_mlu_available(): + torch.mlu.empty_cache() + elif is_npu_available(): + torch.npu.empty_cache() + elif is_torch_version(">=", "2.0") and is_mps_available(): + torch.mps.empty_cache() + else: + torch.cuda.empty_cache() kwargs = {} diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 5eff032774e..a782f4bf7f9 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -267,6 +267,15 @@ class TrainingArguments: eval_delay (`float`, *optional*): Number of epochs or steps to wait for before the first evaluation can be performed, depending on the eval_strategy. + torch_empty_cache_steps (`int`, *optional*): + Number of steps to wait before calling `torch..empty_cache()`. If left unset or set to None, cache will not be emptied. + + + + This can help avoid CUDA out-of-memory errors by lowering peak VRAM usage at a cost of about [10% slower performance](https://github.com/huggingface/transformers/issues/31372). + + + learning_rate (`float`, *optional*, defaults to 5e-5): The initial learning rate for [`AdamW`] optimizer. weight_decay (`float`, *optional*, defaults to 0): @@ -851,6 +860,15 @@ class TrainingArguments: }, ) + torch_empty_cache_steps: Optional[int] = field( + default=None, + metadata={ + "help": "Number of steps to wait before calling `torch..empty_cache()`." + "This can help avoid CUDA out-of-memory errors by lowering peak VRAM usage at a cost of about [10% slower performance](https://github.com/huggingface/transformers/issues/31372)." + "If left unset or set to None, cache will not be emptied." + }, + ) + learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."}) weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."}) adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"}) @@ -1532,6 +1550,12 @@ def __post_init__(self): if self.do_eval is False and self.eval_strategy != IntervalStrategy.NO: self.do_eval = True + if self.torch_empty_cache_steps is not None: + if not (isinstance(self.torch_empty_cache_steps, int) or self.torch_empty_cache_steps > 0): + raise ValueError( + f"`torch_empty_cache_steps` must be an integer bigger than 0, got {self.torch_empty_cache_steps}." + ) + # eval_steps has to be defined and non-zero, fallbacks to logging_steps if the latter is non-zero if self.eval_strategy == IntervalStrategy.STEPS and (self.eval_steps is None or self.eval_steps == 0): if self.logging_steps > 0: