Skip to content

Commit

Permalink
Fix gradient checkpointing warnings
Browse files Browse the repository at this point in the history
  • Loading branch information
Малахов Алексей Павлович committed Oct 27, 2024
1 parent 21c3693 commit 392185d
Show file tree
Hide file tree
Showing 7 changed files with 33 additions and 12 deletions.
5 changes: 5 additions & 0 deletions turbo_alignment/cherry_picks/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ def _get_dataset_metrics(
ref_model: dict = kwargs.get('ref_model', None)
sft_model: dict = kwargs.get('sft_model', None)

if model.is_gradient_checkpointing:
model.config.use_cache = False

generator = ChatGenerator(
model=model,
tokenizer=tokenizer,
Expand All @@ -45,6 +48,8 @@ def _get_dataset_metrics(
return_logits=True,
)

model.config.use_cache = True

batch_size = self._generator_transformers_settings.num_return_sequences

generations = generator.generate_from_dataset(dataset)
Expand Down
7 changes: 6 additions & 1 deletion turbo_alignment/pipelines/train/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,9 @@ def _get_training_args(experiment_settings: ClassificationTrainExperimentSetting
output_dir=str(experiment_settings.log_path / TRAINER_LOGS_FOLDER),
label_names=['labels'],
remove_unused_columns=False,
**experiment_settings.trainer_settings.dict(exclude={'loss_settings'}),
**experiment_settings.trainer_settings.dict(
exclude={'loss_settings', 'gradient_checkpointing_kwargs', 'gradient_checkpointing'},
),
)

@staticmethod
Expand All @@ -81,6 +83,9 @@ def _get_trainer(
val_dataset: Dataset,
data_collator: DataCollatorMixin,
):
if experiment_settings.trainer_settings.gradient_checkpointing:
model.gradient_checkpointing_enable({'use_reentrant': True})

if experiment_settings.trainer_settings.loss_settings.alpha == 'auto':
experiment_settings.trainer_settings.loss_settings.alpha = auto_class_weights(train_dataset)
logger.info(f'Auto computed class weights: {experiment_settings.trainer_settings.loss_settings.alpha}')
Expand Down
7 changes: 5 additions & 2 deletions turbo_alignment/pipelines/train/dpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@ def _get_training_args(experiment_settings: DPOTrainExperimentSettings) -> DPOTr
output_dir=str(experiment_settings.log_path / TRAINER_LOGS_FOLDER),
label_names=[],
remove_unused_columns=False,
**experiment_settings.trainer_settings.dict(),
**experiment_settings.trainer_settings.dict(
exclude={'gradient_checkpointing_kwargs', 'gradient_checkpointing'},
),
)

@staticmethod
Expand All @@ -72,7 +74,8 @@ def _get_trainer(
val_dataset: Dataset,
data_collator: Callable,
):
model.config.use_cache = not training_args.gradient_checkpointing
if experiment_settings.trainer_settings.gradient_checkpointing:
model.gradient_checkpointing_enable({'use_reentrant': True})

extra_args = {}
if experiment_settings.trainer_settings.use_ref_model:
Expand Down
7 changes: 5 additions & 2 deletions turbo_alignment/pipelines/train/kto.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@ def _get_training_args(experiment_settings: KTOTrainExperimentSettings) -> KTOTr
output_dir=str(experiment_settings.log_path / TRAINER_LOGS_FOLDER),
label_names=[],
remove_unused_columns=False,
**experiment_settings.trainer_settings.dict(),
**experiment_settings.trainer_settings.dict(
exclude={'gradient_checkpointing_kwargs', 'gradient_checkpointing'},
),
)

@staticmethod
Expand All @@ -72,7 +74,8 @@ def _get_trainer(
val_dataset: Dataset,
data_collator: Callable,
):
model.config.use_cache = not experiment_settings.trainer_settings.gradient_checkpointing
if experiment_settings.trainer_settings.gradient_checkpointing:
model.gradient_checkpointing_enable({'use_reentrant': True})

extra_args = {}
if experiment_settings.trainer_settings.use_ref_model:
Expand Down
9 changes: 7 additions & 2 deletions turbo_alignment/pipelines/train/rm.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@ def _get_training_args(experiment_settings: RMTrainExperimentSettings) -> Traini
output_dir=str(experiment_settings.log_path / TRAINER_LOGS_FOLDER),
label_names=[],
remove_unused_columns=False,
**experiment_settings.trainer_settings.dict(),
**experiment_settings.trainer_settings.dict(
exclude={'gradient_checkpointing_kwargs', 'gradient_checkpointing'},
),
)

@staticmethod
Expand All @@ -74,7 +76,10 @@ def _get_trainer(
val_dataset: Dataset,
data_collator: DataCollatorMixin,
**_kwargs,
):
) -> RMTrainer:
if experiment_settings.trainer_settings.gradient_checkpointing:
model.gradient_checkpointing_enable({'use_reentrant': True})

return RMTrainer(
model=model,
tokenizer=tokenizer,
Expand Down
7 changes: 5 additions & 2 deletions turbo_alignment/pipelines/train/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,9 @@ def _get_cherry_pick_callback(
def _get_training_args(experiment_settings: SftTrainExperimentSettings) -> TrainingArguments:
return TrainingArguments(
output_dir=str(experiment_settings.log_path / TRAINER_LOGS_FOLDER),
**experiment_settings.trainer_settings.dict(),
**experiment_settings.trainer_settings.dict(
exclude={'gradient_checkpointing_kwargs', 'gradient_checkpointing'},
),
)

@staticmethod
Expand All @@ -72,7 +74,8 @@ def _get_trainer(
data_collator: DataCollatorMixin,
**_kwargs,
) -> MultiGPUCherryPicksTrainer:
model.config.use_cache = not experiment_settings.trainer_settings.gradient_checkpointing
if experiment_settings.trainer_settings.gradient_checkpointing:
model.gradient_checkpointing_enable({'use_reentrant': True})

return MultiGPUCherryPicksTrainer(
model=model,
Expand Down
3 changes: 0 additions & 3 deletions turbo_alignment/trainers/dpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,9 +749,6 @@ def _compute_metrics(
metrics[f'{prefix_name}grad_term'] = (
(self.dpo_loss_registry.beta * F.sigmoid(rejected_rewards - chosen_rewards)).detach().cpu().mean().item()
)
metrics[f'{prefix_name}grad_term_std'] = (
(self.dpo_loss_registry.beta * F.sigmoid(rejected_rewards - chosen_rewards)).detach().cpu().std().item()
)

return metrics

Expand Down

0 comments on commit 392185d

Please sign in to comment.