[feature] add pytorch profiling #2182

Merged 4 commits on Dec 16, 2024
4 changes: 4 additions & 0 deletions docs/config.qmd
@@ -363,6 +363,10 @@ eval_table_size: # Approximate number of predictions sent to wandb depending on
eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", "chrf", "perplexity"]

profiler_steps: # Enable the PyTorch profiler to capture the first N steps of training, writing a memory snapshot to output_dir.
# See https://pytorch.org/blog/understanding-gpu-memory-1/ for more information.
# Snapshots can be visualized at https://pytorch.org/memory_viz

loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)
loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3)

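As an illustration of the new option (the values below are assumptions, not from this PR), profiling slots into an axolotl YAML config as an ordinary top-level key:

profiler_steps: 10                    # capture the first 10 training steps
output_dir: ./outputs/profiled-run    # snapshot.pickle is written here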
8 changes: 8 additions & 0 deletions src/axolotl/core/trainer_builder.py
@@ -65,6 +65,7 @@
log_prediction_callback_factory,
)
from axolotl.utils.callbacks.lisa import lisa_callback_factory
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
from axolotl.utils.chat_templates import get_chat_template
from axolotl.utils.collators import (
BatchSamplerDataCollatorForSeq2Seq,
@@ -1363,6 +1364,13 @@ def get_callbacks(self) -> List[TrainerCallback]:
plugin_manager.add_callbacks_pre_trainer(cfg=self.cfg, model=self.model)
)

if self.cfg.profiler_steps:
callbacks.append(
PytorchProfilerCallback(
steps_to_profile=self.cfg.profiler_steps,
)
)

if self.cfg.use_wandb:
callbacks.append(
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
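Outside of axolotl's builder, the same callback could in principle be attached to a plain Hugging Face Trainer, since Trainer accepts a callbacks list. This is a hypothetical sketch only; `model` and `train_ds` are placeholders, not part of the PR:

from transformers import Trainer, TrainingArguments
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback

training_args = TrainingArguments(output_dir="./outputs/profiled-run", max_steps=10)
trainer = Trainer(
    model=model,                 # placeholder: any torch.nn.Module on CUDA
    args=training_args,
    train_dataset=train_ds,      # placeholder: any HF-compatible dataset
    callbacks=[PytorchProfilerCallback(steps_to_profile=5)],
)
trainer.train()                  # snapshot.pickle appears in output_dir after step 5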
43 changes: 43 additions & 0 deletions src/axolotl/utils/callbacks/profiler.py
@@ -0,0 +1,43 @@
"""
HF Trainer callback for creating pytorch profiling snapshots
"""
from pathlib import Path
from pickle import dump # nosec B403

import torch
from transformers import (
TrainerCallback,
TrainerControl,
TrainerState,
TrainingArguments,
)


class PytorchProfilerCallback(TrainerCallback):
"""
PyTorch Profiler callback to create snapshots of GPU memory usage at specified steps.
"""

def __init__(self, steps_to_profile: int = 5):
self.steps_to_profile = steps_to_profile
if self.steps_to_profile:
torch.cuda.memory._record_memory_history( # pylint: disable=protected-access
enabled="all"
)

def on_step_end( # pylint: disable=unused-argument
self,
args: TrainingArguments, # pylint: disable=unused-argument
state: TrainerState,
control: TrainerControl, # pylint: disable=unused-argument
**kwargs, # pylint: disable=unused-argument
):
if state.global_step == self.steps_to_profile:
snapshot = torch.cuda.memory._snapshot() # pylint: disable=protected-access
with open(Path(args.output_dir) / "snapshot.pickle", "wb") as fout:
Review thread on this line —
Collaborator: Can we maybe output to a non-pickle file for safety reasons (shared with others, tampered with, etc.)? How would a person also read this file and display it?
Collaborator (author): I don't believe there are any other options that the visualizer supports as input. I added docs but forgot to push them earlier.
dump(snapshot, fout)

# tell CUDA to stop recording memory allocations now
torch.cuda.memory._record_memory_history( # pylint: disable=protected-access
enabled=None
)
1 change: 1 addition & 0 deletions src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -762,6 +762,7 @@ class Config:
load_best_model_at_end: Optional[bool] = False
save_only_model: Optional[bool] = False
use_tensorboard: Optional[bool] = None
profiler_steps: Optional[int] = None

neftune_noise_alpha: Optional[float] = None
