diff --git a/turbo_alignment/pipelines/train/base.py b/turbo_alignment/pipelines/train/base.py
index a7be9c3..0bcbcd9 100755
--- a/turbo_alignment/pipelines/train/base.py
+++ b/turbo_alignment/pipelines/train/base.py
@@ -112,8 +112,9 @@ def _add_trainer_callbacks(self, experiment_settings: ExperimentSettingsT, **kwa
                 )
             )
 
-        cherry_pick_callback = self._get_cherry_pick_callback(experiment_settings, self.tokenizer, **kwargs)
-
+        # cherry_pick_callback = self._get_cherry_pick_callback(experiment_settings, self.tokenizer, **kwargs)
+        cherry_pick_callback = None
+
         if cherry_pick_callback is not None:
             self.trainer.add_callback(cherry_pick_callback)
@@ -185,7 +186,7 @@ def run(self, experiment_settings: ExperimentSettingsT) -> None:
         if self.trainer.accelerator.is_main_process:
             self._dataset_and_collator_sanity_check(train_dataset, data_collator)
 
-        # self._add_trainer_callbacks(experiment_settings)
+        self._add_trainer_callbacks(experiment_settings)
 
         os.makedirs(self.trainer.args.output_dir, exist_ok=True)
         self._save_experiment_config(
diff --git a/turbo_alignment/pipelines/train/sft_rm.py b/turbo_alignment/pipelines/train/sft_rm.py
index 04ba7bd..58924cc 100755
--- a/turbo_alignment/pipelines/train/sft_rm.py
+++ b/turbo_alignment/pipelines/train/sft_rm.py
@@ -1,5 +1,6 @@
 from typing import Callable
 
+import torch
 from torch import nn
 from torch.utils.data import Dataset
 from transformers import (
@@ -39,15 +40,15 @@ def __init__(self, config, model_settings, tokenizer):
         self.decoder = model.model
         self.lm_head = model.lm_head
-        self.rm_head = nn.Linear(self.decoder.norm.weight.shape[0], 1, bias=False)
+        self.rm_head = nn.Linear(config.hidden_size, 1, bias=False)
 
     def forward(self, batch):
         outputs_w = self.decoder(**batch['inputs_w']).last_hidden_state[0]
         outputs_l = self.decoder(**batch['inputs_l']).last_hidden_state[0]
 
         lm_logits = self.lm_head(outputs_w)
-        rm_logits_w = self.rm_head(outputs_w[-1])
-        rm_logits_l = self.rm_head(outputs_l[-1])
+        rm_logits_w = self.rm_head(outputs_w[-1:,])[0]
+        rm_logits_l = self.rm_head(outputs_l[-1:,])[0]
 
         return lm_logits, rm_logits_w, rm_logits_l
@@ -99,7 +100,8 @@ def _load_model(
         tokenizer: PreTrainedTokenizerBase,
     ) -> nn.Module | PreTrainedModel:
         config = AutoConfig.from_pretrained(experiment_settings.model_settings.model_path)
-        return MultiheadModel(config, experiment_settings.model_settings, tokenizer)
+        model = MultiheadModel(config, experiment_settings.model_settings, tokenizer)
+        return model
 
     @staticmethod
     def _get_trainer(
diff --git a/turbo_alignment/trainers/sft_with_rm.py b/turbo_alignment/trainers/sft_with_rm.py
index a957798..3eaea5a 100644
--- a/turbo_alignment/trainers/sft_with_rm.py
+++ b/turbo_alignment/trainers/sft_with_rm.py
@@ -15,13 +15,15 @@ class SFTwithRMTrainer(MultiGPUCherryPicksTrainer):
-    def compute_loss(self, model, inputs, return_outputs=False) -> tuple[torch.Tensor, dict[str, Any]] | torch.Tensor:
+    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None) -> tuple[torch.Tensor, dict[str, Any]] | torch.Tensor:
         sft_logits, rewards_w, rewards_l = model.forward(inputs)
         sft_labels = inputs['inputs_w']['input_ids'][0]
 
-        loss = -torch.nn.functional.logsigmoid(rewards_w - rewards_l).mean() + torch.nn.functional.cross_entropy(
-            sft_logits, sft_labels
-        )
+        loss_rm = -torch.nn.functional.logsigmoid(rewards_w - rewards_l).mean()
+        loss_sft = torch.nn.functional.cross_entropy(sft_logits, sft_labels)
+        alpha = 0.001
+        loss = (1 - alpha) * loss_rm + alpha * loss_sft
+
         if return_outputs:
             return loss, {'rewards_w': rewards_w, 'rewards_l': rewards_l}
         return loss
@@ -78,11 +80,5 @@ def _save_checkpoint(self, model, trial, metrics=None):
         ):
             self.state.best_metric = metric_value
             self.state.best_model_checkpoint = output_dir
-
-        (output_dir / 'decoder').mkdir(parents=True, exist_ok=True)
-
-        torch.save(model.module.lm_head.state_dict(), output_dir / 'lm_head.pt')
-        torch.save(model.module.rm_head.state_dict(), output_dir / 'rm_head.pt')
-
-        model.module.decoder.save_pretrained(output_dir / 'decoder')
-        self.tokenizer.save_pretrained(output_dir / 'decoder')
+
+        return super()._save_checkpoint(model, trial, metrics)
\ No newline at end of file