
Commit

in progress
Elisei Rykov committed Dec 11, 2024
1 parent bb61428 commit 61ea543
Showing 2 changed files with 41 additions and 30 deletions.
33 changes: 17 additions & 16 deletions turbo_alignment/pipelines/train/lddpo.py
@@ -37,22 +37,23 @@ def _get_cherry_pick_callback(
         tokenizer: PreTrainedTokenizerBase,
         **kwargs,
     ) -> ChatCherryPickCallback:
-        cherry_pick_settings = experiment_settings.cherry_pick_settings
-
-        cherry_pick_datasets = DatasetLoader[InferenceChatDataset](InferenceChatDataset).load_datasets(
-            cherry_pick_settings.dataset_settings, tokenizer=tokenizer, strategy=DatasetStrategy.INFERENCE
-        )
-
-        metrics = [
-            Metric.by_name(metric.type)(MetricSettingsRegistry.by_name(metric.type)(**metric.parameters))
-            for metric in cherry_pick_settings.metric_settings
-        ]
-
-        return ChatCherryPickCallback(
-            cherry_pick_settings=cherry_pick_settings,
-            datasets=cherry_pick_datasets,
-            metrics=metrics,
-        )
+        return None
+        # cherry_pick_settings = experiment_settings.cherry_pick_settings
+
+        # cherry_pick_datasets = DatasetLoader[InferenceChatDataset](InferenceChatDataset).load_datasets(
+        #     cherry_pick_settings.dataset_settings, tokenizer=tokenizer, strategy=DatasetStrategy.INFERENCE
+        # )
+
+        # metrics = [
+        #     Metric.by_name(metric.type)(MetricSettingsRegistry.by_name(metric.type)(**metric.parameters))
+        #     for metric in cherry_pick_settings.metric_settings
+        # ]
+
+        # return ChatCherryPickCallback(
+        #     cherry_pick_settings=cherry_pick_settings,
+        #     datasets=cherry_pick_datasets,
+        #     metrics=metrics,
+        # )
 
     def _dataset_and_collator_sanity_check(self, dataset: Dataset, collator: DataCollatorMixin) -> None:
         logger.info(f'Train sample example:\n{dataset[0]}')
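The cherry-pick factory is temporarily stubbed out to return None, so any caller that registers the callback has to tolerate the stub. A minimal sketch of such a guard, assuming a hypothetical registration helper (the function and argument names below are illustrative, not part of this commit):

def register_callbacks(trainer_callbacks: list, cherry_pick_callback) -> list:
    """Append the cherry-pick callback only when it exists; while
    _get_cherry_pick_callback returns None, the guard keeps trainer
    setup from registering a null callback."""
    if cherry_pick_callback is not None:
        trainer_callbacks.append(cherry_pick_callback)
    return trainer_callbacks

# Usage with the stubbed-out factory: nothing is registered.
callbacks = register_callbacks([], None)
assert callbacks == []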
38 changes: 24 additions & 14 deletions turbo_alignment/trainers/dpo.py
@@ -606,12 +606,18 @@ def _get_batch_logps(
         return (per_token_logps * loss_mask).sum(-1)
 
     def concatenated_forward(
-        self, model: nn.Module, batch: dict[str, Any]
+        self,
+        model: nn.Module,
+        batch: dict[str, Any],
+        get_from_dataset=False,
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None]:
         concatenated_batch = concatenated_inputs(batch, device=self.accelerator.device)
 
         precomputed_margins: torch.Tensor | None = concatenated_batch.pop('margin', None)
 
+        if get_from_dataset:
+            return concatenated_batch.pop('ref_chosen_logps'), concatenated_batch.pop('ref_rejected_logps')
+
         all_logits = model(
             concatenated_batch['input_ids'],
             attention_mask=concatenated_batch['attention_mask'],
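The new get_from_dataset flag makes concatenated_forward return reference log-probabilities stored in the batch itself instead of running a forward pass. A minimal sketch of the expected batch contract, with illustrative shapes and values (only the ref_chosen_logps/ref_rejected_logps key names come from this diff):

import torch

# Illustrative batch: alongside the usual token tensors, the pair carries
# precomputed scalar reference log-probs. Everything except the two ref_*
# key names is made up for the demo.
concatenated_batch = {
    'input_ids': torch.randint(0, 32000, (2, 16)),
    'attention_mask': torch.ones(2, 16, dtype=torch.long),
    'ref_chosen_logps': torch.tensor([-42.7]),
    'ref_rejected_logps': torch.tensor([-57.3]),
}

# With get_from_dataset=True the method pops these two tensors and returns
# immediately, so no forward pass (and no reference model) is needed.
ref_chosen = concatenated_batch.pop('ref_chosen_logps')
ref_rejected = concatenated_batch.pop('ref_rejected_logps')

Note that this early-return path yields a 2-tuple, while the function's annotated return type still describes the 5-tuple of the default path.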
@@ -632,16 +638,19 @@ def concatenated_forward(
         return chosen_logps, rejected_logps, chosen_logits, rejected_logits, precomputed_margins
 
     def _get_logps(self, model: nn.Module | None, batch: dict[str, Any]) -> tuple[torch.Tensor, torch.Tensor]:
-        with torch.no_grad():
-            if model is not None:
-                (chosen_logps, rejected_logps, *_) = self.concatenated_forward(model, batch)
-            else:
-                with self.accelerator.unwrap_model(self.model).disable_adapter():
-                    (
-                        chosen_logps,
-                        rejected_logps,
-                        *_,
-                    ) = self.concatenated_forward(self.model, batch)
+        if self.ref_model is None:
+            chosen_logps, rejected_logps = self.concatenated_forward(model, batch, get_from_dataset=True)
+        else:
+            with torch.no_grad():
+                if model is not None:
+                    (chosen_logps, rejected_logps, *_) = self.concatenated_forward(model, batch)
+                else:
+                    with self.accelerator.unwrap_model(self.model).disable_adapter():
+                        (
+                            chosen_logps,
+                            rejected_logps,
+                            *_,
+                        ) = self.concatenated_forward(self.model, batch)
 
         return chosen_logps, rejected_logps
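The new ref_model is None branch assumes each dataset example already carries precomputed reference log-probs. A hedged sketch of how that one-off precomputation might be done while the reference model is still available (the helper, its loop, and the field handling are assumptions, not shown in this commit):

import torch

@torch.no_grad()
def precompute_reference_logps(trainer, dataset):
    """Run the reference model once over the dataset and cache per-example
    chosen/rejected log-probs, so _get_logps can later take the
    ref_model-is-None branch and skip the reference forward entirely."""
    augmented = []
    for example in dataset:
        batch = trainer.data_collator([example])
        chosen_logps, rejected_logps, *_ = trainer.concatenated_forward(trainer.ref_model, batch)
        example['ref_chosen_logps'] = chosen_logps.squeeze().cpu()
        example['ref_rejected_logps'] = rejected_logps.squeeze().cpu()
        augmented.append(example)
    return augmented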

@@ -780,9 +789,9 @@ def _compute_metrics(
         metrics[f'{prefix_name}grad_term'] = (
             (self.dpo_loss_registry.beta * F.sigmoid(rejected_rewards - chosen_rewards)).detach().cpu().mean().item()
         )
-        metrics[f'{prefix_name}grad_term_std'] = (
-            (self.dpo_loss_registry.beta * F.sigmoid(rejected_rewards - chosen_rewards)).detach().cpu().std().item()
-        )
+        # metrics[f'{prefix_name}grad_term_std'] = (
+        #     (self.dpo_loss_registry.beta * F.sigmoid(rejected_rewards - chosen_rewards)).detach().cpu().std().item()
+        # )
 
         return metrics
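For context, grad_term is the per-pair weight of the DPO gradient, beta * sigmoid(rejected_rewards - chosen_rewards); it grows when the policy ranks the rejected completion above the chosen one. A standalone sketch with illustrative values:

import torch
import torch.nn.functional as F

beta = 0.1  # illustrative; the trainer reads it from dpo_loss_registry.beta
chosen_rewards = torch.tensor([1.2, 0.4])
rejected_rewards = torch.tensor([0.7, 0.9])

# Per-pair DPO gradient weight: large when the pair is mis-ranked.
grad_term = beta * F.sigmoid(rejected_rewards - chosen_rewards)
print(grad_term.mean().item())  # logged as '<prefix>grad_term'
print(grad_term.std().item())   # the now-disabled '<prefix>grad_term_std'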

@@ -822,6 +831,7 @@ def compute_loss(
         model: PreTrainedModel | nn.Module,
         inputs: dict[str, torch.Tensor | Any],
         return_outputs=False,
+        num_items_in_batch=None,
     ) -> torch.Tensor | tuple[torch.Tensor, dict[str, float]]:
         loss, metrics = self.get_batch_metrics(model, inputs, train_eval='train')
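The extra num_items_in_batch parameter lines up with the updated Trainer.compute_loss signature in newer transformers releases (around 4.46), which pass a token count for loss normalization; an override must accept the argument even if, as here, the custom loss ignores it. A minimal sketch of a compatible override, assuming such a transformers version; the loss body is generic, not the DPO loss:

from typing import Any

import torch
from torch import nn
from transformers import Trainer

class CompatTrainer(Trainer):
    def compute_loss(
        self,
        model: nn.Module,
        inputs: dict[str, torch.Tensor | Any],
        return_outputs: bool = False,
        num_items_in_batch: int | None = None,  # supplied by newer transformers
    ):
        # Custom losses that normalize internally can simply accept and
        # ignore the extra argument.
        outputs = model(**inputs)
        loss = outputs.loss
        return (loss, outputs) if return_outputs else loss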

