🎨 Fix gradient checkpointing warnings #49

Open · wants to merge 2 commits into base: main
5 changes: 5 additions & 0 deletions turbo_alignment/cherry_picks/chat.py
@@ -36,6 +36,9 @@ def _get_dataset_metrics(
ref_model: dict = kwargs.get('ref_model', None)
sft_model: dict = kwargs.get('sft_model', None)

if model.is_gradient_checkpointing:
model.config.use_cache = False

generator = ChatGenerator(
model=model,
tokenizer=tokenizer,
@@ -45,6 +48,8 @@ def _get_dataset_metrics(
return_logits=True,
)

model.config.use_cache = True

batch_size = self._generator_transformers_settings.num_return_sequences

generations = generator.generate_from_dataset(dataset)
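For context, the chat.py change follows the usual Transformers pattern: `use_cache=True` combined with gradient checkpointing triggers the incompatibility warning this PR targets, so the KV cache is switched off around cherry-pick generation and restored afterwards. A minimal sketch of the same idea as a reusable context manager (the helper is illustrative only, not part of this PR):

```python
from contextlib import contextmanager

from transformers import PreTrainedModel


@contextmanager
def cache_disabled_if_checkpointing(model: PreTrainedModel):
    # Hypothetical helper mirroring the chat.py hunk above: disable the KV cache
    # while gradient checkpointing is active, then restore the previous setting.
    previous = model.config.use_cache
    if model.is_gradient_checkpointing:
        model.config.use_cache = False
    try:
        yield model
    finally:
        model.config.use_cache = previous


# Usage sketch:
# with cache_disabled_if_checkpointing(model):
#     generations = generator.generate_from_dataset(dataset)
```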
5 changes: 1 addition & 4 deletions turbo_alignment/common/tf/callbacks/logging.py
@@ -70,10 +70,7 @@ def _log(self, logs: dict[str, Any], state: TrainerState) -> None:
self._wandb_run.log({**rewritten_logs, 'train/global_step': state.global_step}, step=state.global_step)

def _fix_table_type(self, logs: dict[str, Any]) -> dict[str, Any]:
return {
k: wandb.Table(dataframe=v) if isinstance(v, pd.DataFrame) else v
for k, v in logs.items()
}
return {k: wandb.Table(dataframe=v) if isinstance(v, pd.DataFrame) else v for k, v in logs.items()}


class ClearMLLoggingCallback(LoggingCallback):
1 change: 1 addition & 0 deletions turbo_alignment/metrics/distinctness.py
@@ -1,4 +1,5 @@
from collections import defaultdict

from transformers.tokenization_utils_base import PreTrainedTokenizerBase

from turbo_alignment.metrics.metric import Metric
1 change: 1 addition & 0 deletions turbo_alignment/metrics/registry.py
@@ -1,4 +1,5 @@
from enum import Enum

from pydantic import field_validator

from turbo_alignment.common.registry import Registrable
7 changes: 6 additions & 1 deletion turbo_alignment/pipelines/train/classification.py
@@ -68,7 +68,9 @@ def _get_training_args(experiment_settings: ClassificationTrainExperimentSetting
output_dir=str(experiment_settings.log_path / TRAINER_LOGS_FOLDER),
label_names=['labels'],
remove_unused_columns=False,
**experiment_settings.trainer_settings.dict(exclude={'loss_settings'}),
**experiment_settings.trainer_settings.dict(
exclude={'loss_settings', 'gradient_checkpointing_kwargs', 'gradient_checkpointing'},
),
)

@staticmethod
@@ -81,6 +83,9 @@ def _get_trainer(
val_dataset: Dataset,
data_collator: DataCollatorMixin,
):
if experiment_settings.trainer_settings.gradient_checkpointing:
model.gradient_checkpointing_enable({'use_reentrant': True})

if experiment_settings.trainer_settings.loss_settings.alpha == 'auto':
experiment_settings.trainer_settings.loss_settings.alpha = auto_class_weights(train_dataset)
logger.info(f'Auto computed class weights: {experiment_settings.trainer_settings.loss_settings.alpha}')
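The classification pipeline introduces the pattern repeated in the dpo.py, kto.py, rm.py, and sft.py pipelines below: the `gradient_checkpointing` and `gradient_checkpointing_kwargs` fields are excluded from the kwargs forwarded to TrainingArguments, and checkpointing is enabled on the model directly with an explicit `use_reentrant` value, which avoids the torch.utils.checkpoint warning about its unspecified default. A rough sketch of the idea, assuming a pydantic-style `trainer_settings` object (the helper name is an illustration, not code from this PR):

```python
from transformers import TrainingArguments


def configure_training(model, trainer_settings, output_dir: str) -> TrainingArguments:
    # Keep checkpointing out of TrainingArguments so the Trainer does not try to
    # enable it again with the default (unspecified) use_reentrant value.
    settings = trainer_settings.dict(
        exclude={'gradient_checkpointing', 'gradient_checkpointing_kwargs'},
    )

    if trainer_settings.gradient_checkpointing:
        # The positional dict maps to `gradient_checkpointing_kwargs`.
        model.gradient_checkpointing_enable({'use_reentrant': True})

    return TrainingArguments(output_dir=output_dir, **settings)
```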
7 changes: 5 additions & 2 deletions turbo_alignment/pipelines/train/dpo.py
@@ -59,7 +59,9 @@ def _get_training_args(experiment_settings: DPOTrainExperimentSettings) -> DPOTr
output_dir=str(experiment_settings.log_path / TRAINER_LOGS_FOLDER),
label_names=[],
remove_unused_columns=False,
**experiment_settings.trainer_settings.dict(),
**experiment_settings.trainer_settings.dict(
exclude={'gradient_checkpointing_kwargs', 'gradient_checkpointing'},
),
)

@staticmethod
@@ -72,7 +74,8 @@ def _get_trainer(
val_dataset: Dataset,
data_collator: Callable,
):
model.config.use_cache = not training_args.gradient_checkpointing
if experiment_settings.trainer_settings.gradient_checkpointing:
model.gradient_checkpointing_enable({'use_reentrant': True})

extra_args = {}
if experiment_settings.trainer_settings.use_ref_model:
7 changes: 5 additions & 2 deletions turbo_alignment/pipelines/train/kto.py
@@ -59,7 +59,9 @@ def _get_training_args(experiment_settings: KTOTrainExperimentSettings) -> KTOTr
output_dir=str(experiment_settings.log_path / TRAINER_LOGS_FOLDER),
label_names=[],
remove_unused_columns=False,
**experiment_settings.trainer_settings.dict(),
**experiment_settings.trainer_settings.dict(
exclude={'gradient_checkpointing_kwargs', 'gradient_checkpointing'},
),
)

@staticmethod
@@ -72,7 +74,8 @@ def _get_trainer(
val_dataset: Dataset,
data_collator: Callable,
):
model.config.use_cache = not experiment_settings.trainer_settings.gradient_checkpointing
if experiment_settings.trainer_settings.gradient_checkpointing:
model.gradient_checkpointing_enable({'use_reentrant': True})

extra_args = {}
if experiment_settings.trainer_settings.use_ref_model:
9 changes: 7 additions & 2 deletions turbo_alignment/pipelines/train/rm.py
@@ -61,7 +61,9 @@ def _get_training_args(experiment_settings: RMTrainExperimentSettings) -> Traini
output_dir=str(experiment_settings.log_path / TRAINER_LOGS_FOLDER),
label_names=[],
remove_unused_columns=False,
**experiment_settings.trainer_settings.dict(),
**experiment_settings.trainer_settings.dict(
exclude={'gradient_checkpointing_kwargs', 'gradient_checkpointing'},
),
)

@staticmethod
@@ -74,7 +76,10 @@ def _get_trainer(
val_dataset: Dataset,
data_collator: DataCollatorMixin,
**_kwargs,
):
) -> RMTrainer:
if experiment_settings.trainer_settings.gradient_checkpointing:
model.gradient_checkpointing_enable({'use_reentrant': True})

return RMTrainer(
model=model,
tokenizer=tokenizer,
7 changes: 5 additions & 2 deletions turbo_alignment/pipelines/train/sft.py
@@ -58,7 +58,9 @@ def _get_cherry_pick_callback(
def _get_training_args(experiment_settings: SftTrainExperimentSettings) -> TrainingArguments:
return TrainingArguments(
output_dir=str(experiment_settings.log_path / TRAINER_LOGS_FOLDER),
**experiment_settings.trainer_settings.dict(),
**experiment_settings.trainer_settings.dict(
exclude={'gradient_checkpointing_kwargs', 'gradient_checkpointing'},
),
)

@staticmethod
@@ -72,7 +74,8 @@ def _get_trainer(
data_collator: DataCollatorMixin,
**_kwargs,
) -> MultiGPUCherryPicksTrainer:
model.config.use_cache = not experiment_settings.trainer_settings.gradient_checkpointing
if experiment_settings.trainer_settings.gradient_checkpointing:
model.gradient_checkpointing_enable({'use_reentrant': True})

return MultiGPUCherryPicksTrainer(
model=model,
5 changes: 1 addition & 4 deletions turbo_alignment/trainers/dpo.py
@@ -26,9 +26,9 @@
from turbo_alignment.settings.pipelines.train.dpo import (
APODownLossSettings,
APOZeroLossSettings,
ASFTLossSettings,
CPOLossSettings,
DPOLossesType,
ASFTLossSettings,
HingeLossSettings,
IPOLossSettings,
KTOLossSettings,
@@ -749,9 +749,6 @@ def _compute_metrics(
metrics[f'{prefix_name}grad_term'] = (
(self.dpo_loss_registry.beta * F.sigmoid(rejected_rewards - chosen_rewards)).detach().cpu().mean().item()
)
metrics[f'{prefix_name}grad_term_std'] = (
(self.dpo_loss_registry.beta * F.sigmoid(rejected_rewards - chosen_rewards)).detach().cpu().std().item()
)

return metrics
