Skip to content

Commit

Permalink
add grad checkpointing
Browse files (browse the repository at this point in the history)
  • Loading branch information
d.taranets committed Dec 3, 2024
1 parent 5cb7708 commit 7607162
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions turbo_alignment/pipelines/train/sft_rm.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -42,6 +42,8 @@ def __init__(self, config, model_settings, tokenizer):
self.lm_head = model.lm_head
self.rm_head = nn.Linear(config.hidden_size, 1, bias=False)

self.gradient_checkpointing_enable = self.decoder.gradient_checkpointing_enable

def forward(self, batch):
outputs_w = self.decoder(**batch['inputs_w']).last_hidden_state[0]
outputs_l = self.decoder(**batch['inputs_l']).last_hidden_state[0]
Expand Down

0 comments on commit 7607162

Please sign in to comment.