diff --git a/turbo_alignment/pipelines/train/sft_rm.py b/turbo_alignment/pipelines/train/sft_rm.py index 58924cc..3010514 100755 --- a/turbo_alignment/pipelines/train/sft_rm.py +++ b/turbo_alignment/pipelines/train/sft_rm.py @@ -42,6 +42,8 @@ def __init__(self, config, model_settings, tokenizer): self.lm_head = model.lm_head self.rm_head = nn.Linear(config.hidden_size, 1, bias=False) + self.gradient_checkpointing_enable = self.decoder.gradient_checkpointing_enable + def forward(self, batch): outputs_w = self.decoder(**batch['inputs_w']).last_hidden_state[0] outputs_l = self.decoder(**batch['inputs_l']).last_hidden_state[0]