-
Notifications
You must be signed in to change notification settings - Fork 22
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
d.taranets
committed
Oct 14, 2024
1 parent
2b9f206
commit 4f85480
Showing
4 changed files
with
242 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,144 @@ | ||
from typing import Callable | ||
import torch | ||
|
||
import numpy as np | ||
from torch import nn | ||
from torch.utils.data import Dataset | ||
from transformers import PreTrainedModel, PreTrainedTokenizerBase, TrainingArguments, AutoModel, AutoModelForCausalLM, GenerationMixin, AutoConfig | ||
from transformers.data.data_collator import DataCollatorMixin | ||
|
||
from turbo_alignment.cherry_picks.rm import MultiHeadCherryPickCallback | ||
from turbo_alignment.common.logging import get_project_logger | ||
from turbo_alignment.constants import TRAINER_LOGS_FOLDER | ||
from turbo_alignment.dataset.loader import DatasetLoader | ||
from turbo_alignment.dataset.pair_preferences import ( | ||
PairPreferenceDataCollator, | ||
PairPreferenceDataset, | ||
) | ||
from turbo_alignment.common.tf.loaders.model.model import load_model | ||
from turbo_alignment.metrics.metric import Metric | ||
from turbo_alignment.metrics.registry import MetricSettingsRegistry | ||
from turbo_alignment.metrics.reward import compute_metrics as compute_rm_metrics | ||
from turbo_alignment.pipelines.train.base import BaseTrainStrategy | ||
from turbo_alignment.settings.datasets import DatasetStrategy | ||
from turbo_alignment.settings.pipelines import RMTrainExperimentSettings | ||
from turbo_alignment.trainers.sft_with_rm import SFTwithRMTrainer | ||
from turbo_alignment.trainers.utils import concatenated_inputs | ||
|
||
logger = get_project_logger() | ||
|
||
|
||
class MultiheadModel(PreTrainedModel, GenerationMixin):
    """Causal decoder with two heads: an LM head over the vocabulary and a scalar
    reward head that is read off at the position of the special '<reward>' token.

    Raises:
        ValueError: if the tokenizer does not encode '<reward>' as exactly one token.
    """

    def __init__(self, config, model_settings, tokenizer):
        super().__init__(config)

        self.decoder = load_model(model_settings=model_settings, tokenizer=tokenizer)

        # Both heads project from the decoder hidden size, taken from the final
        # norm weight so it matches whatever model load_model produced.
        hidden_size = self.decoder.norm.weight.shape[0]
        self.lm_head = nn.Linear(hidden_size, len(tokenizer), bias=False)
        self.rm_head = nn.Linear(hidden_size, 1, bias=False)

        reward_token_ids = tokenizer.encode('<reward>', add_special_tokens=False)
        if len(reward_token_ids) != 1:
            # FIX: typo ("us" -> "is") in the original message.
            raise ValueError('<reward> token is not found in the tokenizer')

        self.reward_token_ids = reward_token_ids[0]

    def forward(self, batch):
        """Run the decoder on the chosen ('inputs_w') and rejected ('inputs_l') inputs.

        Returns:
            lm_logits: LM-head logits over the chosen sequence with the <reward>
                position removed.
            rm_logits_w / rm_logits_l: scalar rewards read at the <reward> position.
            reward_token_pos_w: index of the <reward> token in the chosen input.

        Raises:
            ValueError: if either input does not contain exactly one <reward> token.

        NOTE(review): only index [0] of each input is used — assumes batch size 1;
        confirm against the collator.
        """
        outputs_w = self.decoder(**batch['inputs_w']).last_hidden_state[0]
        outputs_l = self.decoder(**batch['inputs_l']).last_hidden_state[0]

        reward_token_pos_w = np.where(batch['inputs_w']['input_ids'][0].cpu() == self.reward_token_ids)[0]
        reward_token_pos_l = np.where(batch['inputs_l']['input_ids'][0].cpu() == self.reward_token_ids)[0]

        if len(reward_token_pos_w) != 1 or len(reward_token_pos_l) != 1:
            # FIX: the condition also fires on zero tokens, so the original
            # "More than one <reward> token" message was inaccurate.
            raise ValueError('Expected exactly one <reward> token in each replica')

        # Drop the hidden state at the <reward> position before the LM head, so
        # the LM loss is computed only over ordinary tokens.
        outputs_w_cat = torch.cat(
            (outputs_w[:reward_token_pos_w[0]], outputs_w[reward_token_pos_w[0] + 1:]),
            dim=0,
        )

        lm_logits = self.lm_head(outputs_w_cat)
        rm_logits_w = self.rm_head(outputs_w[reward_token_pos_w[0]])
        rm_logits_l = self.rm_head(outputs_l[reward_token_pos_l[0]])

        return lm_logits, rm_logits_w, rm_logits_l, reward_token_pos_w
|
||
|
||
class TrainMultiheadStrategy(BaseTrainStrategy[RMTrainExperimentSettings]):
    """Training pipeline wiring for the multi-head (joint SFT + RM) model."""

    @staticmethod
    def _get_data_collator(
        experiment_settings: RMTrainExperimentSettings,
        tokenizer: PreTrainedTokenizerBase,
        **_kwargs,
    ) -> Callable:
        """Collate pairwise-preference samples; no labels are attached here."""
        return PairPreferenceDataCollator(tokenizer=tokenizer, add_labels=False)

    @staticmethod
    def _get_cherry_pick_callback(
        experiment_settings: RMTrainExperimentSettings,
        tokenizer: PreTrainedTokenizerBase,
        **_kwargs,
    ) -> MultiHeadCherryPickCallback:
        """Assemble the cherry-pick callback from the experiment settings."""
        settings = experiment_settings.cherry_pick_settings

        loader = DatasetLoader[PairPreferenceDataset](PairPreferenceDataset)
        datasets = loader.load_datasets(
            settings.dataset_settings, tokenizer=tokenizer, strategy=DatasetStrategy.TRAIN
        )

        metrics = []
        for metric in settings.metric_settings:
            metric_settings = MetricSettingsRegistry.by_name(metric.type)(**metric.parameters)
            metrics.append(Metric.by_name(metric.type)(metric_settings))

        return MultiHeadCherryPickCallback(
            cherry_pick_settings=settings,
            datasets=datasets,
            metrics=metrics,
        )

    @staticmethod
    def _get_training_args(experiment_settings: RMTrainExperimentSettings) -> TrainingArguments:
        """Translate experiment settings into HF TrainingArguments."""
        logs_dir = experiment_settings.log_path / TRAINER_LOGS_FOLDER
        return TrainingArguments(
            output_dir=str(logs_dir),
            label_names=[],
            remove_unused_columns=False,
            **experiment_settings.trainer_settings.dict(),
        )

    @staticmethod
    def _load_model(
        experiment_settings: RMTrainExperimentSettings,
        tokenizer: PreTrainedTokenizerBase,
    ) -> nn.Module | PreTrainedModel:
        """Instantiate the multi-head model from the base model's config."""
        model_config = AutoConfig.from_pretrained(experiment_settings.model_settings.model_path)
        return MultiheadModel(model_config, experiment_settings.model_settings, tokenizer)

    @staticmethod
    def _get_trainer(
        training_args: TrainingArguments,
        experiment_settings: RMTrainExperimentSettings,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizerBase,
        train_dataset: Dataset,
        val_dataset: Dataset,
        data_collator: DataCollatorMixin,
        **_kwargs,
    ):
        """Construct the joint SFT+RM trainer over the prepared datasets."""
        return SFTwithRMTrainer(
            model=model,
            tokenizer=tokenizer,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
            data_collator=data_collator,
            compute_metrics=compute_rm_metrics,
            callbacks=[],
        )

    def _dataset_and_collator_sanity_check(self, dataset: Dataset, collator: DataCollatorMixin) -> None:
        """Log one raw sample plus its decoded chosen/rejected texts."""
        sample = dataset[0]
        logger.info(f'Train sample input_ids:\n{sample}')
        logger.info(f'Train sample example:\n{self.tokenizer.decode(sample["inputs_w"]["input_ids"])}')
        logger.info(f'Train sample example:\n{self.tokenizer.decode(sample["inputs_l"]["input_ids"])}')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
from typing import Any | ||
import os | ||
|
||
import torch | ||
from torch import nn | ||
from pathlib import Path | ||
from transformers import PreTrainedModel, Trainer | ||
from transformers.trainer_pt_utils import nested_detach | ||
from transformers.utils import logging | ||
|
||
from turbo_alignment.trainers.multigpu import MultiGPUCherryPicksTrainer | ||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR | ||
|
||
logger = logging.get_logger(__name__) | ||
|
||
|
||
class SFTwithRMTrainer(MultiGPUCherryPicksTrainer):
    """Trainer that jointly optimizes a pairwise reward loss and an SFT
    cross-entropy loss on the chosen sequence."""

    def compute_loss(self, model, inputs, return_outputs=False) -> tuple[torch.Tensor, dict[str, Any]] | torch.Tensor:
        """Pairwise -logsigmoid(reward_w - reward_l) loss plus LM cross-entropy
        over the chosen input with the <reward> position removed.

        Args:
            model: multi-head model whose forward returns
                (lm_logits, reward_w, reward_l, reward_token_pos_w).
            inputs: dict with 'inputs_w' (chosen) and 'inputs_l' (rejected) batches.
            return_outputs: when True, also return the reward logits for metrics.
        """
        sft_logits, rewards_w, rewards_l, reward_token_pos_w = model.forward(inputs)

        sft_logits = sft_logits.view(-1, sft_logits.size(-1))
        sft_labels = inputs['inputs_w']['input_ids']

        # Drop the <reward> position from the labels to mirror what the model
        # did to the hidden states before the LM head.
        flat_labels = sft_labels.view(-1)
        sft_labels_cat = torch.cat(
            (flat_labels[:reward_token_pos_w[0]], flat_labels[reward_token_pos_w[0] + 1:]),
            dim=0,
        )

        # NOTE(review): logits and labels are aligned at the same positions; for
        # a causal LM the labels are normally shifted one step to the right —
        # confirm this alignment is intended.
        loss = (
            -torch.nn.functional.logsigmoid(rewards_w - rewards_l).mean()
            + torch.nn.functional.cross_entropy(sft_logits, sft_labels_cat)
        )
        if return_outputs:
            return loss, {'rewards_w': rewards_w, 'rewards_l': rewards_l}
        return loss

    def prediction_step(
        self,
        model: PreTrainedModel | nn.Module,
        inputs: dict[str, dict[str, torch.Tensor]],
        prediction_loss_only: bool,
        ignore_keys: list[str] | None,
    ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
        """Evaluation step: returns (loss, stacked reward logits, labels) where
        labels mark whether the chosen reward beats the rejected one."""
        inputs = self._prepare_inputs(inputs)

        if ignore_keys is None:
            if hasattr(self.model, 'config'):
                ignore_keys = getattr(self.model.config, 'keys_to_ignore_at_inference', [])
            else:
                ignore_keys = []

        with torch.no_grad():
            loss, logits_dict = self.compute_loss(model, inputs, return_outputs=True)

        if prediction_loss_only:
            return (loss, None, None)

        loss = loss.detach()
        logits = tuple(v for k, v in logits_dict.items() if k not in ignore_keys)
        logits = nested_detach(logits)

        # Shape (batch, 2): column 0 is the chosen reward, column 1 the rejected.
        logits = torch.stack(logits).T

        # 1 where the chosen response out-scores the rejected one.
        labels = logits[:, 0] > logits[:, 1]
        labels = labels.long()

        return loss, logits, labels

    def _save_checkpoint(self, model, trial, metrics=None):
        """Save the decoder, both heads, and the tokenizer under checkpoint-<step>."""
        logger.info('Running custom _save_checkpoint')
        checkpoint_folder = f'{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}'
        run_dir = self._get_output_dir(trial=trial)
        output_dir = Path(os.path.join(run_dir, checkpoint_folder))

        (output_dir / 'decoder').mkdir(parents=True, exist_ok=True)

        # NOTE(review): accesses model.module — assumes the model is wrapped
        # (e.g. DDP/DeepSpeed); confirm this holds for single-process runs.
        torch.save(model.module.lm_head.state_dict(), output_dir / 'lm_head.pt')
        torch.save(model.module.rm_head.state_dict(), output_dir / 'rm_head.pt')

        model.module.decoder.save_pretrained(output_dir / 'decoder')
        self.tokenizer.save_pretrained(output_dir / 'decoder')