diff --git a/turbo_alignment/pipelines/train/lddpo.py b/turbo_alignment/pipelines/train/lddpo.py
index 611477f..7eb3d51 100644
--- a/turbo_alignment/pipelines/train/lddpo.py
+++ b/turbo_alignment/pipelines/train/lddpo.py
@@ -37,22 +37,23 @@ def _get_cherry_pick_callback(
         tokenizer: PreTrainedTokenizerBase,
         **kwargs,
     ) -> ChatCherryPickCallback:
-        cherry_pick_settings = experiment_settings.cherry_pick_settings
-
-        cherry_pick_datasets = DatasetLoader[InferenceChatDataset](InferenceChatDataset).load_datasets(
-            cherry_pick_settings.dataset_settings, tokenizer=tokenizer, strategy=DatasetStrategy.INFERENCE
-        )
-
-        metrics = [
-            Metric.by_name(metric.type)(MetricSettingsRegistry.by_name(metric.type)(**metric.parameters))
-            for metric in cherry_pick_settings.metric_settings
-        ]
-
-        return ChatCherryPickCallback(
-            cherry_pick_settings=cherry_pick_settings,
-            datasets=cherry_pick_datasets,
-            metrics=metrics,
-        )
+        return None
+        # cherry_pick_settings = experiment_settings.cherry_pick_settings
+
+        # cherry_pick_datasets = DatasetLoader[InferenceChatDataset](InferenceChatDataset).load_datasets(
+        #     cherry_pick_settings.dataset_settings, tokenizer=tokenizer, strategy=DatasetStrategy.INFERENCE
+        # )
+
+        # metrics = [
+        #     Metric.by_name(metric.type)(MetricSettingsRegistry.by_name(metric.type)(**metric.parameters))
+        #     for metric in cherry_pick_settings.metric_settings
+        # ]
+
+        # return ChatCherryPickCallback(
+        #     cherry_pick_settings=cherry_pick_settings,
+        #     datasets=cherry_pick_datasets,
+        #     metrics=metrics,
+        # )

     def _dataset_and_collator_sanity_check(self, dataset: Dataset, collator: DataCollatorMixin) -> None:
         logger.info(f'Train sample example:\n{dataset[0]}')
diff --git a/turbo_alignment/trainers/dpo.py b/turbo_alignment/trainers/dpo.py
index ef3288f..94dd7b6 100755
--- a/turbo_alignment/trainers/dpo.py
+++ b/turbo_alignment/trainers/dpo.py
@@ -606,12 +606,18 @@ def _get_batch_logps(
         return (per_token_logps * loss_mask).sum(-1)

     def concatenated_forward(
-        self, model: nn.Module, batch: dict[str, Any]
+        self,
+        model: nn.Module,
+        batch: dict[str, Any],
+        get_from_dataset=False,
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None]:
         concatenated_batch = concatenated_inputs(batch, device=self.accelerator.device)

         precomputed_margins: torch.Tensor | None = concatenated_batch.pop('margin', None)

+        if get_from_dataset:
+            return concatenated_batch.pop('ref_chosen_logps'), concatenated_batch.pop('ref_rejected_logps')
+
         all_logits = model(
             concatenated_batch['input_ids'],
             attention_mask=concatenated_batch['attention_mask'],
@@ -632,16 +638,19 @@ def concatenated_forward(
         return chosen_logps, rejected_logps, chosen_logits, rejected_logits, precomputed_margins

     def _get_logps(self, model: nn.Module | None, batch: dict[str, Any]) -> tuple[torch.Tensor, torch.Tensor]:
-        with torch.no_grad():
-            if model is not None:
-                (chosen_logps, rejected_logps, *_) = self.concatenated_forward(model, batch)
-            else:
-                with self.accelerator.unwrap_model(self.model).disable_adapter():
-                    (
-                        chosen_logps,
-                        rejected_logps,
-                        *_,
-                    ) = self.concatenated_forward(self.model, batch)
+        if self.ref_model is None:
+            chosen_logps, rejected_logps = self.concatenated_forward(model, batch, get_from_dataset=True)
+        else:
+            with torch.no_grad():
+                if model is not None:
+                    (chosen_logps, rejected_logps, *_) = self.concatenated_forward(model, batch)
+                else:
+                    with self.accelerator.unwrap_model(self.model).disable_adapter():
+                        (
+                            chosen_logps,
+                            rejected_logps,
+                            *_,
+                        ) = self.concatenated_forward(self.model, batch)

         return chosen_logps, rejected_logps

@@ -780,9 +789,9 @@ def _compute_metrics(
         metrics[f'{prefix_name}grad_term'] = (
             (self.dpo_loss_registry.beta * F.sigmoid(rejected_rewards - chosen_rewards)).detach().cpu().mean().item()
         )
-        metrics[f'{prefix_name}grad_term_std'] = (
-            (self.dpo_loss_registry.beta * F.sigmoid(rejected_rewards - chosen_rewards)).detach().cpu().std().item()
-        )
+        # metrics[f'{prefix_name}grad_term_std'] = (
+        #     (self.dpo_loss_registry.beta * F.sigmoid(rejected_rewards - chosen_rewards)).detach().cpu().std().item()
+        # )

         return metrics

@@ -822,6 +831,7 @@ def compute_loss(
         model: PreTrainedModel | nn.Module,
         inputs: dict[str, torch.Tensor | Any],
         return_outputs=False,
+        num_items_in_batch=None,
     ) -> torch.Tensor | tuple[torch.Tensor, dict[str, float]]:
         loss, metrics = self.get_batch_metrics(model, inputs, train_eval='train')
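For context (not part of the patch): the `_get_logps` change above makes the trainer fall back to reference log-probabilities that were precomputed and stored with the dataset whenever `self.ref_model is None`, instead of running a second forward pass through a reference model. Below is a minimal sketch of that control flow, assuming the collator already packs `ref_chosen_logps` / `ref_rejected_logps` tensors into the batch; the helper name `get_reference_logps` is hypothetical and not taken from the repository.

```python
# Hypothetical sketch of the fallback used in _get_logps above:
# read reference log-probs from the batch when no reference model is loaded.
import torch


def get_reference_logps(
    batch: dict[str, torch.Tensor],
    ref_model: torch.nn.Module | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    if ref_model is None:
        # Precomputed offline (e.g. by a separate inference run) and shipped with each preference pair.
        return batch['ref_chosen_logps'], batch['ref_rejected_logps']
    with torch.no_grad():
        # Otherwise the frozen reference model would be run here (omitted in this sketch).
        raise NotImplementedError('forward pass through ref_model')


# Usage with dummy precomputed values:
batch = {
    'ref_chosen_logps': torch.tensor([-12.3, -8.7]),
    'ref_rejected_logps': torch.tensor([-15.1, -11.0]),
}
chosen, rejected = get_reference_logps(batch)
print(chosen, rejected)
```

The apparent motivation is avoiding the memory and latency cost of keeping a reference model resident during training, but the exact dataset keys and collator behaviour here are assumptions based on the diff alone.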