diff --git a/turbo_alignment/trainers/rm.py b/turbo_alignment/trainers/rm.py
index 296d7d9..46d9326 100755
--- a/turbo_alignment/trainers/rm.py
+++ b/turbo_alignment/trainers/rm.py
@@ -6,6 +6,7 @@
 from peft import PeftModel
 from torch import nn
 from transformers import PreTrainedModel
+from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
 from transformers.trainer_pt_utils import nested_detach
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
 from transformers.utils import logging
@@ -70,7 +71,7 @@ def prediction_step(
         return loss, logits, labels
 
     def _save_checkpoint(self, model, trial, metrics=None):
-        if isinstance(model, PeftModel) and self.accelerator.state.deepspeed_plugin.zero_stage == 3:
+        if isinstance(model, PeftModel) and is_deepspeed_zero3_enabled():
             logger.info('Running custom _save_checkpoint')
             checkpoint_folder = f'{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}'
             run_dir = self._get_output_dir(trial=trial)
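
For context, a minimal sketch of what the new guard changes, assuming the motivation is robustness when no DeepSpeed plugin is configured (the diff itself does not state it): `accelerator.state.deepspeed_plugin` is `None` outside DeepSpeed runs, so the old attribute chain raises `AttributeError`, whereas `is_deepspeed_zero3_enabled()` just returns `False`. The helper `_is_zero3_peft_checkpoint` below is hypothetical, not part of this PR.

```python
# Hedged sketch: illustrates the behavioral difference, not code from the PR.
from peft import PeftModel
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled


def _is_zero3_peft_checkpoint(model) -> bool:
    # Hypothetical helper mirroring the guard in _save_checkpoint.
    # Old check: self.accelerator.state.deepspeed_plugin.zero_stage == 3
    #   -> AttributeError when deepspeed_plugin is None (non-DeepSpeed runs).
    # New check: is_deepspeed_zero3_enabled()
    #   -> inspects the DeepSpeed config registered with the Trainer and
    #      returns False when ZeRO-3 (or DeepSpeed altogether) is not in use,
    #      so the custom checkpoint path is skipped cleanly.
    return isinstance(model, PeftModel) and is_deepspeed_zero3_enabled()
```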