debug for deepspeed training
d.taranets committed Nov 29, 2024
1 parent 1ea49e4 commit 9fef70f
Showing 3 changed files with 18 additions and 19 deletions.
7 changes: 4 additions & 3 deletions turbo_alignment/pipelines/train/base.py
@@ -112,8 +112,9 @@ def _add_trainer_callbacks(self, experiment_settings: ExperimentSettingsT, **kwa
                 )
             )
 
-        cherry_pick_callback = self._get_cherry_pick_callback(experiment_settings, self.tokenizer, **kwargs)
-
+        # cherry_pick_callback = self._get_cherry_pick_callback(experiment_settings, self.tokenizer, **kwargs)
+        cherry_pick_callback = None
+
         if cherry_pick_callback is not None:
             self.trainer.add_callback(cherry_pick_callback)
 
@@ -185,7 +186,7 @@ def run(self, experiment_settings: ExperimentSettingsT) -> None:
         if self.trainer.accelerator.is_main_process:
             self._dataset_and_collator_sanity_check(train_dataset, data_collator)
 
-        # self._add_trainer_callbacks(experiment_settings)
+        self._add_trainer_callbacks(experiment_settings)
 
         os.makedirs(self.trainer.args.output_dir, exist_ok=True)
         self._save_experiment_config(
10 changes: 6 additions & 4 deletions turbo_alignment/pipelines/train/sft_rm.py
@@ -1,5 +1,6 @@
 from typing import Callable
 
+import torch
 from torch import nn
 from torch.utils.data import Dataset
 from transformers import (
@@ -39,15 +40,15 @@ def __init__(self, config, model_settings, tokenizer):
         self.decoder = model.model
 
         self.lm_head = model.lm_head
-        self.rm_head = nn.Linear(self.decoder.norm.weight.shape[0], 1, bias=False)
+        self.rm_head = nn.Linear(config.hidden_size, 1, bias=False)
 
     def forward(self, batch):
         outputs_w = self.decoder(**batch['inputs_w']).last_hidden_state[0]
         outputs_l = self.decoder(**batch['inputs_l']).last_hidden_state[0]
 
         lm_logits = self.lm_head(outputs_w)
-        rm_logits_w = self.rm_head(outputs_w[-1])
-        rm_logits_l = self.rm_head(outputs_l[-1])
+        rm_logits_w = self.rm_head(outputs_w[-1:,])[0]
+        rm_logits_l = self.rm_head(outputs_l[-1:,])[0]
 
         return lm_logits, rm_logits_w, rm_logits_l
 
@@ -99,7 +100,8 @@ def _load_model(
         tokenizer: PreTrainedTokenizerBase,
     ) -> nn.Module | PreTrainedModel:
         config = AutoConfig.from_pretrained(experiment_settings.model_settings.model_path)
-        return MultiheadModel(config, experiment_settings.model_settings, tokenizer)
+        model = MultiheadModel(config, experiment_settings.model_settings, tokenizer)
+        return model
 
     @staticmethod
     def _get_trainer(
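Context for the MultiheadModel change above: the reward head is now sized from config.hidden_size instead of reading the shape of the decoder's norm weight, and the reward is taken from the last position's hidden state through a one-row slice, so the linear head always receives a 2-D input. The snippet below is a minimal, self-contained sketch of that two-head pattern; TinyDecoder, TwoHeadModel, and the toy dimensions are illustrative stand-ins, not turbo_alignment code.

import torch
from torch import nn


class TinyDecoder(nn.Module):
    """Toy stand-in for the HuggingFace decoder backbone used by the pipeline."""

    def __init__(self, vocab_size: int, hidden_size: int):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)

    def forward(self, input_ids):
        return self.embed(input_ids)  # (batch, seq_len, hidden_size)


class TwoHeadModel(nn.Module):
    """Shared backbone with a language-modelling head and a scalar reward head."""

    def __init__(self, vocab_size: int, hidden_size: int):
        super().__init__()
        self.decoder = TinyDecoder(vocab_size, hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        self.rm_head = nn.Linear(hidden_size, 1, bias=False)  # sized from hidden_size, as in the diff

    def forward(self, ids_w, ids_l):
        hidden_w = self.decoder(ids_w)[0]  # first sequence of the batch: (seq_len, hidden_size)
        hidden_l = self.decoder(ids_l)[0]
        lm_logits = self.lm_head(hidden_w)            # (seq_len, vocab_size)
        reward_w = self.rm_head(hidden_w[-1:, :])[0]  # keep the input 2-D, score the last token
        reward_l = self.rm_head(hidden_l[-1:, :])[0]
        return lm_logits, reward_w, reward_l


model = TwoHeadModel(vocab_size=32, hidden_size=16)
lm_logits, r_w, r_l = model(torch.randint(0, 32, (1, 10)), torch.randint(0, 32, (1, 12)))
print(lm_logits.shape, r_w.shape, r_l.shape)  # torch.Size([10, 32]) torch.Size([1]) torch.Size([1])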
20 changes: 8 additions & 12 deletions turbo_alignment/trainers/sft_with_rm.py
@@ -15,13 +15,15 @@
 
 
 class SFTwithRMTrainer(MultiGPUCherryPicksTrainer):
-    def compute_loss(self, model, inputs, return_outputs=False) -> tuple[torch.Tensor, dict[str, Any]] | torch.Tensor:
+    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None) -> tuple[torch.Tensor, dict[str, Any]] | torch.Tensor:
         sft_logits, rewards_w, rewards_l = model.forward(inputs)
         sft_labels = inputs['inputs_w']['input_ids'][0]
 
-        loss = -torch.nn.functional.logsigmoid(rewards_w - rewards_l).mean() + torch.nn.functional.cross_entropy(
-            sft_logits, sft_labels
-        )
+        loss_rm = -torch.nn.functional.logsigmoid(rewards_w - rewards_l).mean()
+        loss_sft = torch.nn.functional.cross_entropy(sft_logits, sft_labels)
+        alpha = 0.001
+        loss = (1 - alpha) * loss_rm + alpha * loss_sft
+
         if return_outputs:
             return loss, {'rewards_w': rewards_w, 'rewards_l': rewards_l}
         return loss
@@ -78,11 +80,5 @@ def _save_checkpoint(self, model, trial, metrics=None):
         ):
             self.state.best_metric = metric_value
             self.state.best_model_checkpoint = output_dir
-
-        (output_dir / 'decoder').mkdir(parents=True, exist_ok=True)
-
-        torch.save(model.module.lm_head.state_dict(), output_dir / 'lm_head.pt')
-        torch.save(model.module.rm_head.state_dict(), output_dir / 'rm_head.pt')
-
-        model.module.decoder.save_pretrained(output_dir / 'decoder')
-        self.tokenizer.save_pretrained(output_dir / 'decoder')
+
+        return super()._save_checkpoint(model, trial, metrics)
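The rewritten compute_loss splits the objective into an explicit pairwise reward-modelling term (push the chosen reward above the rejected one) and an SFT cross-entropy term, mixed with a fixed weight alpha = 0.001; the added num_items_in_batch parameter matches the extra argument newer transformers releases pass into compute_loss. Below is a minimal sketch of that computation with random stand-in tensors; the shapes and values are illustrative only, not taken from the repository.

import torch
import torch.nn.functional as F

vocab_size, seq_len = 32, 10
sft_logits = torch.randn(seq_len, vocab_size)            # LM logits for the chosen sequence
sft_labels = torch.randint(0, vocab_size, (seq_len,))    # its token ids
rewards_w = torch.randn(1)                                # reward of the chosen answer
rewards_l = torch.randn(1)                                # reward of the rejected answer

loss_rm = -F.logsigmoid(rewards_w - rewards_l).mean()    # pairwise (Bradley-Terry style) preference loss
loss_sft = F.cross_entropy(sft_logits, sft_labels)       # keeps the language-modelling signal
alpha = 0.001
loss = (1 - alpha) * loss_rm + alpha * loss_sft           # weighted heavily toward the reward term
print(float(loss))

With alpha this small, the gradient is dominated by the preference term and the SFT term effectively acts as a light regularizer.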
