diff --git a/turbo_alignment/cherry_picks/chat.py b/turbo_alignment/cherry_picks/chat.py index eec477e..f8f864c 100755 --- a/turbo_alignment/cherry_picks/chat.py +++ b/turbo_alignment/cherry_picks/chat.py @@ -36,6 +36,9 @@ def _get_dataset_metrics( ref_model: dict = kwargs.get('ref_model', None) sft_model: dict = kwargs.get('sft_model', None) + if model.is_gradient_checkpointing: + model.config.use_cache = False + generator = ChatGenerator( model=model, tokenizer=tokenizer, @@ -45,6 +48,8 @@ def _get_dataset_metrics( return_logits=True, ) + model.config.use_cache = True + batch_size = self._generator_transformers_settings.num_return_sequences generations = generator.generate_from_dataset(dataset) diff --git a/turbo_alignment/pipelines/train/classification.py b/turbo_alignment/pipelines/train/classification.py index 1bc09fe..86b8857 100755 --- a/turbo_alignment/pipelines/train/classification.py +++ b/turbo_alignment/pipelines/train/classification.py @@ -68,7 +68,9 @@ def _get_training_args(experiment_settings: ClassificationTrainExperimentSetting output_dir=str(experiment_settings.log_path / TRAINER_LOGS_FOLDER), label_names=['labels'], remove_unused_columns=False, - **experiment_settings.trainer_settings.dict(exclude={'loss_settings'}), + **experiment_settings.trainer_settings.dict( + exclude={'loss_settings', 'gradient_checkpointing_kwargs', 'gradient_checkpointing'}, + ), ) @staticmethod @@ -81,6 +83,9 @@ def _get_trainer( val_dataset: Dataset, data_collator: DataCollatorMixin, ): + if experiment_settings.trainer_settings.gradient_checkpointing: + model.gradient_checkpointing_enable({'use_reentrant': True}) + if experiment_settings.trainer_settings.loss_settings.alpha == 'auto': experiment_settings.trainer_settings.loss_settings.alpha = auto_class_weights(train_dataset) logger.info(f'Auto computed class weights: {experiment_settings.trainer_settings.loss_settings.alpha}') diff --git a/turbo_alignment/pipelines/train/dpo.py b/turbo_alignment/pipelines/train/dpo.py index 155f6d3..d6d0fb8 100755 --- a/turbo_alignment/pipelines/train/dpo.py +++ b/turbo_alignment/pipelines/train/dpo.py @@ -59,7 +59,9 @@ def _get_training_args(experiment_settings: DPOTrainExperimentSettings) -> DPOTr output_dir=str(experiment_settings.log_path / TRAINER_LOGS_FOLDER), label_names=[], remove_unused_columns=False, - **experiment_settings.trainer_settings.dict(), + **experiment_settings.trainer_settings.dict( + exclude={'gradient_checkpointing_kwargs', 'gradient_checkpointing'}, + ), ) @staticmethod @@ -72,7 +74,8 @@ def _get_trainer( val_dataset: Dataset, data_collator: Callable, ): - model.config.use_cache = not training_args.gradient_checkpointing + if experiment_settings.trainer_settings.gradient_checkpointing: + model.gradient_checkpointing_enable({'use_reentrant': True}) extra_args = {} if experiment_settings.trainer_settings.use_ref_model: diff --git a/turbo_alignment/pipelines/train/kto.py b/turbo_alignment/pipelines/train/kto.py index b34d813..0d93cf3 100755 --- a/turbo_alignment/pipelines/train/kto.py +++ b/turbo_alignment/pipelines/train/kto.py @@ -59,7 +59,9 @@ def _get_training_args(experiment_settings: KTOTrainExperimentSettings) -> KTOTr output_dir=str(experiment_settings.log_path / TRAINER_LOGS_FOLDER), label_names=[], remove_unused_columns=False, - **experiment_settings.trainer_settings.dict(), + **experiment_settings.trainer_settings.dict( + exclude={'gradient_checkpointing_kwargs', 'gradient_checkpointing'}, + ), ) @staticmethod @@ -72,7 +74,8 @@ def _get_trainer( val_dataset: Dataset, data_collator: Callable, ): - model.config.use_cache = not experiment_settings.trainer_settings.gradient_checkpointing + if experiment_settings.trainer_settings.gradient_checkpointing: + model.gradient_checkpointing_enable({'use_reentrant': True}) extra_args = {} if experiment_settings.trainer_settings.use_ref_model: diff --git a/turbo_alignment/pipelines/train/rm.py b/turbo_alignment/pipelines/train/rm.py index 66ecac9..f1e212c 100755 --- a/turbo_alignment/pipelines/train/rm.py +++ b/turbo_alignment/pipelines/train/rm.py @@ -61,7 +61,9 @@ def _get_training_args(experiment_settings: RMTrainExperimentSettings) -> Traini output_dir=str(experiment_settings.log_path / TRAINER_LOGS_FOLDER), label_names=[], remove_unused_columns=False, - **experiment_settings.trainer_settings.dict(), + **experiment_settings.trainer_settings.dict( + exclude={'gradient_checkpointing_kwargs', 'gradient_checkpointing'}, + ), ) @staticmethod @@ -74,7 +76,10 @@ def _get_trainer( val_dataset: Dataset, data_collator: DataCollatorMixin, **_kwargs, - ): + ) -> RMTrainer: + if experiment_settings.trainer_settings.gradient_checkpointing: + model.gradient_checkpointing_enable({'use_reentrant': True}) + return RMTrainer( model=model, tokenizer=tokenizer, diff --git a/turbo_alignment/pipelines/train/sft.py b/turbo_alignment/pipelines/train/sft.py index a1bddec..7f83a9d 100755 --- a/turbo_alignment/pipelines/train/sft.py +++ b/turbo_alignment/pipelines/train/sft.py @@ -58,7 +58,9 @@ def _get_cherry_pick_callback( def _get_training_args(experiment_settings: SftTrainExperimentSettings) -> TrainingArguments: return TrainingArguments( output_dir=str(experiment_settings.log_path / TRAINER_LOGS_FOLDER), - **experiment_settings.trainer_settings.dict(), + **experiment_settings.trainer_settings.dict( + exclude={'gradient_checkpointing_kwargs', 'gradient_checkpointing'}, + ), ) @staticmethod @@ -72,7 +74,8 @@ def _get_trainer( data_collator: DataCollatorMixin, **_kwargs, ) -> MultiGPUCherryPicksTrainer: - model.config.use_cache = not experiment_settings.trainer_settings.gradient_checkpointing + if experiment_settings.trainer_settings.gradient_checkpointing: + model.gradient_checkpointing_enable({'use_reentrant': True}) return MultiGPUCherryPicksTrainer( model=model, diff --git a/turbo_alignment/trainers/dpo.py b/turbo_alignment/trainers/dpo.py index e49231b..e3f80e9 100755 --- a/turbo_alignment/trainers/dpo.py +++ b/turbo_alignment/trainers/dpo.py @@ -749,9 +749,6 @@ def _compute_metrics( metrics[f'{prefix_name}grad_term'] = ( (self.dpo_loss_registry.beta * F.sigmoid(rejected_rewards - chosen_rewards)).detach().cpu().mean().item() ) - metrics[f'{prefix_name}grad_term_std'] = ( - (self.dpo_loss_registry.beta * F.sigmoid(rejected_rewards - chosen_rewards)).detach().cpu().std().item() - ) return metrics