From ae7dc0e57541fe9c40cb0acfc27a0669c87111fe Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Thu, 12 Dec 2024 20:42:18 -0500
Subject: [PATCH 1/4] add pytorch profiling

---
 src/axolotl/core/trainer_builder.py        |  8 +++
 src/axolotl/utils/callbacks/profiler.py    | 50 +++++++++++++++++++
 .../config/models/input/v0_4_1/__init__.py |  1 +
 3 files changed, 59 insertions(+)
 create mode 100644 src/axolotl/utils/callbacks/profiler.py

diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 73d9e0e655..0f30f511c2 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -65,6 +65,7 @@
     log_prediction_callback_factory,
 )
 from axolotl.utils.callbacks.lisa import lisa_callback_factory
+from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
 from axolotl.utils.chat_templates import get_chat_template
 from axolotl.utils.collators import (
     BatchSamplerDataCollatorForSeq2Seq,
@@ -1363,6 +1364,13 @@ def get_callbacks(self) -> List[TrainerCallback]:
                 plugin_manager.add_callbacks_pre_trainer(cfg=self.cfg, model=self.model)
             )
 
+        if self.cfg.profiler_steps:
+            callbacks.append(
+                PytorchProfilerCallback(
+                    steps_to_profile=self.cfg.profiler_steps,
+                )
+            )
+
         if self.cfg.use_wandb:
             callbacks.append(
                 SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
diff --git a/src/axolotl/utils/callbacks/profiler.py b/src/axolotl/utils/callbacks/profiler.py
new file mode 100644
index 0000000000..56b0a392c2
--- /dev/null
+++ b/src/axolotl/utils/callbacks/profiler.py
@@ -0,0 +1,50 @@
+"""
+HF Trainer callback for creating pytorch profiling snapshots
+"""
+from pathlib import Path
+from pickle import dump  # nosec B403
+
+import torch
+from transformers import (
+    TrainerCallback,
+    TrainerControl,
+    TrainerState,
+    TrainingArguments,
+)
+
+
+class PytorchProfilerCallback(TrainerCallback):
+    """
+    PyTorch Profiler callback to create snapshots of GPU memory usage at specified steps.
+    """
+
+    def __init__(self, steps_to_profile: int = 5):
+        self.steps_to_profile = steps_to_profile
+
+    def on_train_begin(
+        self,
+        args: TrainingArguments,  # pylint: disable=unused-argument
+        state: TrainerState,  # pylint: disable=unused-argument
+        control: TrainerControl,
+        **kwargs,  # pylint: disable=unused-argument
+    ):
+        torch.cuda.memory._record_memory_history(  # pylint: disable=protected-access
+            enabled="all"
+        )
+
+    def on_step_end(  # pylint: disable=unused-argument
+        self,
+        args: TrainingArguments,  # pylint: disable=unused-argument
+        state: TrainerState,
+        control: TrainerControl,  # pylint: disable=unused-argument
+        **kwargs,  # pylint: disable=unused-argument
+    ):
+        if state.global_step == self.steps_to_profile:
+            snapshot = torch.cuda.memory._snapshot()  # pylint: disable=protected-access
+            with open(Path(args.output_dir) / "snapshot.pickle", "wb") as fout:
+                dump(snapshot, fout)
+
+            # tell CUDA to stop recording memory allocations now
+            torch.cuda.memory._record_memory_history(  # pylint: disable=protected-access
+                enabled=None
+            )
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index 3671e1bb93..d05de2330d 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -762,6 +762,7 @@ class Config:
     load_best_model_at_end: Optional[bool] = False
     save_only_model: Optional[bool] = False
     use_tensorboard: Optional[bool] = None
+    profiler_steps: Optional[int] = None
 
     neftune_noise_alpha: Optional[float] = None

From 79a6ed3bc51c791f043a7fbc7a2acaeff3db0030 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Thu, 12 Dec 2024 20:55:46 -0500
Subject: [PATCH 2/4] kick off the profiler asap since things may get allocated before train start

---
 src/axolotl/utils/callbacks/profiler.py | 15 ++++-----------
 1 file changed, 4 insertions(+), 11 deletions(-)

diff --git a/src/axolotl/utils/callbacks/profiler.py b/src/axolotl/utils/callbacks/profiler.py
index 56b0a392c2..8616963323 100644
--- a/src/axolotl/utils/callbacks/profiler.py
+++ b/src/axolotl/utils/callbacks/profiler.py
@@ -20,17 +20,10 @@ class PytorchProfilerCallback(TrainerCallback):
 
     def __init__(self, steps_to_profile: int = 5):
         self.steps_to_profile = steps_to_profile
-
-    def on_train_begin(
-        self,
-        args: TrainingArguments,  # pylint: disable=unused-argument
-        state: TrainerState,  # pylint: disable=unused-argument
-        control: TrainerControl,
-        **kwargs,  # pylint: disable=unused-argument
-    ):
-        torch.cuda.memory._record_memory_history(  # pylint: disable=protected-access
-            enabled="all"
-        )
+        if self.steps_to_profile:
+            torch.cuda.memory._record_memory_history(  # pylint: disable=protected-access
+                enabled="all"
+            )
 
     def on_step_end(  # pylint: disable=unused-argument
         self,

From b255c45107636e5311829e84d6b6ec3d0ec48683 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Thu, 12 Dec 2024 21:57:46 -0500
Subject: [PATCH 3/4] document feature

---
 docs/config.qmd | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/docs/config.qmd b/docs/config.qmd
index f01a2ce267..d9d259af08 100644
--- a/docs/config.qmd
+++ b/docs/config.qmd
@@ -363,6 +363,9 @@ eval_table_size: # Approximate number of predictions sent to wandb depending on
 eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
 eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", "chrf", "perplexity"]
Default is ["sacrebleu", "comet", "ter", "chrf", "perplexity"] +profiler_steps: # enable the pytorch profiler to capture the first N steps of training to the output_dir. + # see https://pytorch.org/blog/understanding-gpu-memory-1/ for more information + loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training) loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3) From a68bd1a36d2002a46fe095d2eefeedc3f1749ca6 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 13 Dec 2024 10:47:55 -0500 Subject: [PATCH 4/4] add url for visualizer [skip ci] --- docs/config.qmd | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/config.qmd b/docs/config.qmd index d9d259af08..120aec8933 100644 --- a/docs/config.qmd +++ b/docs/config.qmd @@ -365,6 +365,7 @@ eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is profiler_steps: # enable the pytorch profiler to capture the first N steps of training to the output_dir. # see https://pytorch.org/blog/understanding-gpu-memory-1/ for more information + # snapshots can be visualized @ https://pytorch.org/memory_viz loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training) loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3)