📏 LD-DPO Trainer #63

Open · wants to merge 6 commits into main
146 changes: 146 additions & 0 deletions tests/fixtures/configs/train/lddpo/base.json
@@ -0,0 +1,146 @@
{
"train_dataset_settings": {
"sources": [
{
"name": "rm_preferences_test",
"records_path": "tests/fixtures/datasets/rm/train_preferences.jsonl",
"sample_rate": 1
}
],
"chat_settings":{
"prompt_template": {
"role_tag_mapping": {
"bot": "<bot>",
"user": "<user>",
"system": "<system>"
},
"prefix_template": "<RS>{role}",
"suffix_template": "</RS>"
},
"max_tokens_count": 120
},
"add_labels": true,
"dataset_type": "pair_preferences"
},
"val_dataset_settings": {
"sources": [
{
"name": "rm_preferences_test",
"records_path": "tests/fixtures/datasets/rm/val_preferences.jsonl",
"sample_rate": 1
}
],
"chat_settings":{
"prompt_template": {
"role_tag_mapping": {
"bot": "<bot>",
"user": "<user>",
"system": "<system>"
},
"prefix_template": "<RS>{role}",
"suffix_template": "</RS>"
},
"max_tokens_count": 120
},
"add_labels": true,
"dataset_type": "pair_preferences"
},
"model_settings": {
"model_path": "tests/fixtures/models/llama2_tiny",
"model_type": "causal",
"transformers_settings": {},
"adapter_path": "tests/fixtures/models/llama2_tiny_fine_tuned_with_adapters/trainer",
"is_trainable": true
},
"cherry_pick_settings": {
"generator_transformers_settings": {
"num_beams": 1,
"do_sample": false,
"stop_strings": "</RS>",
"max_new_tokens": 8
},
"custom_generation_settings": {
"skip_special_tokens": false
},
"dataset_settings": {
"sources": [
{
"name": "chat_test",
"records_path": "tests/fixtures/datasets/chat/train_chat.jsonl",
"num_samples": 2
}
],
"prompt_template": {
"role_tag_mapping": {
"bot": "<bot>",
"user": "<user>",
"system": "<system>"
},
"prefix_template": "<RS>{role}",
"suffix_template": "</RS>"
},
"dataset_type": "chat",
"max_tokens_count": 150,
"only_answer_loss": true
},
"metric_settings": [
{
"type": "length",
"parameters": {"need_average": [true]}
},
{
"type": "kl",
"parameters": {
"need_average": [true],
"ref_logits_type": "sft"
}
},
{
"type": "kl",
"parameters": {
"need_average": [true],
"ref_logits_type": "reference"
}
}

]
},
"tokenizer_settings": {},
"special_tokens_settings": {
"bos_token": "<s>",
"eos_token": "</s>",
"pad_token": "<pad>"
},
"trainer_settings": {
"lc_alpha": 0.5,
"evaluation_strategy": "steps",
"per_device_train_batch_size": 2,
"per_device_eval_batch_size": 2,
"gradient_accumulation_steps": 2,
"eval_steps": 4,
"save_steps": 4,
"logging_steps": 1,
"learning_rate": 0.0003,
"num_train_epochs": 2,
"lr_scheduler_type": "cosine",
"warmup_steps": 2,
"fp16": false,
"bf16": false,
"optim": "adamw_torch",
"save_total_limit": 1,
"loss_settings": {
"loss_type": "ipo"
},
"sync_ref_settings": {
"sync_ref_model": true
},
"use_sft_model": true,
"no_cuda": true
},
"logging_settings": {
"project_name": "alignment",
"run_name": "dpo",
"entity": "turbo-alignment"
},
"log_path": "test_dpo_llama_train_output"
}
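For review convenience, the LD-DPO-specific part of this fixture is small; a quick sketch (assuming a local checkout where the fixture path above exists) that pulls out the new knob:

```python
import json
from pathlib import Path

# Peek at what distinguishes this fixture from the existing DPO configs:
# the only trainer-level addition is `lc_alpha`, the LD-DPO length-desensitization
# coefficient; the preference loss itself is still chosen via `loss_settings`.
config = json.loads(Path('tests/fixtures/configs/train/lddpo/base.json').read_text())
print(config['trainer_settings']['lc_alpha'])       # 0.5
print(config['trainer_settings']['loss_settings'])  # {'loss_type': 'ipo'}
```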
13 changes: 13 additions & 0 deletions turbo_alignment/cli/train.py
@@ -59,6 +59,19 @@ def train_ddpo_entrypoint(
pipelines.TrainDDPOStrategy().run(experiment_settings)


@app.command(name='train_lddpo', help='Run LDDPO pipeline')
def train_lddpo_entrypoint(
experiment_settings_path: Path = typer.Option(
...,
'--experiment_settings_path',
exists=True,
help='Path to experiment config file',
)
) -> None:
experiment_settings = pipeline_settings.LDDPOTrainExperimentSettings.parse_file(experiment_settings_path)
pipelines.TrainLDDPOStrategy().run(experiment_settings)


@app.command(name='train_rm', help='Run RM pipeline')
def train_rm_entrypoint(
experiment_settings_path: Path = typer.Option(
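The same flow can be driven programmatically against the fixture config above; a minimal sketch using the classes this PR exports (import locations taken from the `__init__.py` diffs below, so treat them as an assumption if those modules move):

```python
from pathlib import Path

from turbo_alignment.pipelines.train import TrainLDDPOStrategy
from turbo_alignment.settings.pipelines import LDDPOTrainExperimentSettings

# Mirror what the new `train_lddpo` command does: parse the experiment settings
# from the LD-DPO fixture config and launch the training pipeline.
experiment_settings = LDDPOTrainExperimentSettings.parse_file(
    Path('tests/fixtures/configs/train/lddpo/base.json')
)
TrainLDDPOStrategy().run(experiment_settings)
```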
1 change: 1 addition & 0 deletions turbo_alignment/pipelines/train/__init__.py
@@ -2,6 +2,7 @@
from .ddpo import TrainDDPOStrategy
from .dpo import TrainDPOStrategy
from .kto import TrainKTOStrategy
from .lddpo import TrainLDDPOStrategy
from .multimodal import TrainMultimodalStrategy
from .rag import TrainRAGStrategy
from .rm import TrainRMStrategy
123 changes: 123 additions & 0 deletions turbo_alignment/pipelines/train/lddpo.py
@@ -0,0 +1,123 @@
from typing import Callable

from torch.utils.data import Dataset
from transformers import PreTrainedModel, PreTrainedTokenizerBase
from transformers.data.data_collator import DataCollatorMixin

from turbo_alignment.cherry_picks.chat import ChatCherryPickCallback
from turbo_alignment.common.logging import get_project_logger
from turbo_alignment.common.tf.loaders.model import load_model
from turbo_alignment.constants import TRAINER_LOGS_FOLDER
from turbo_alignment.dataset.chat.chat import InferenceChatDataset
from turbo_alignment.dataset.loader import DatasetLoader
from turbo_alignment.dataset.pair_preferences import PairPreferenceDataCollator
from turbo_alignment.metrics.metric import Metric
from turbo_alignment.metrics.registry import MetricSettingsRegistry
from turbo_alignment.pipelines.train.base import BaseTrainStrategy
from turbo_alignment.settings.datasets.base import DatasetStrategy
from turbo_alignment.settings.pipelines.train.lddpo import LDDPOTrainExperimentSettings
from turbo_alignment.trainers.lddpo import LDDPOTrainer, LDDPOTrainingArguments

logger = get_project_logger()


class TrainLDDPOStrategy(BaseTrainStrategy[LDDPOTrainExperimentSettings]):
@staticmethod
def _get_data_collator(
experiment_settings: LDDPOTrainExperimentSettings,
tokenizer: PreTrainedTokenizerBase,
**kwargs,
) -> Callable:
return PairPreferenceDataCollator(tokenizer=tokenizer, add_labels=True)

@staticmethod
def _get_cherry_pick_callback(
experiment_settings: LDDPOTrainExperimentSettings,
tokenizer: PreTrainedTokenizerBase,
**kwargs,
) -> ChatCherryPickCallback | None:
return None
# cherry_pick_settings = experiment_settings.cherry_pick_settings

# cherry_pick_datasets = DatasetLoader[InferenceChatDataset](InferenceChatDataset).load_datasets(
# cherry_pick_settings.dataset_settings, tokenizer=tokenizer, strategy=DatasetStrategy.INFERENCE
# )

# metrics = [
# Metric.by_name(metric.type)(MetricSettingsRegistry.by_name(metric.type)(**metric.parameters))
# for metric in cherry_pick_settings.metric_settings
# ]

# return ChatCherryPickCallback(
# cherry_pick_settings=cherry_pick_settings,
# datasets=cherry_pick_datasets,
# metrics=metrics,
# )

def _dataset_and_collator_sanity_check(self, dataset: Dataset, collator: DataCollatorMixin) -> None:
    logger.info(f'Train sample example:\n{dataset[0]}')

    batch = collator([dataset[0], dataset[1]])
    logger.info('Input-w check: {input_ids}'.format(input_ids=batch['inputs_w']['input_ids'][0]))
    logger.info('Mask-w check: {mask}'.format(mask=batch['inputs_w']['attention_mask'][0]))
    logger.info('Input-l check: {input_ids}'.format(input_ids=batch['inputs_l']['input_ids'][0]))
    logger.info('Mask-l check: {mask}'.format(mask=batch['inputs_l']['attention_mask'][0]))

@staticmethod
def _get_training_args(experiment_settings: LDDPOTrainExperimentSettings) -> LDDPOTrainingArguments:
return LDDPOTrainingArguments(
output_dir=str(experiment_settings.log_path / TRAINER_LOGS_FOLDER),
label_names=[],
remove_unused_columns=False,
**experiment_settings.trainer_settings.dict(),
)

@staticmethod
def _get_trainer(
training_args: LDDPOTrainingArguments,
experiment_settings: LDDPOTrainExperimentSettings,
model: PreTrainedModel,
tokenizer: PreTrainedTokenizerBase,
train_dataset: Dataset,
val_dataset: Dataset,
data_collator: Callable,
) -> LDDPOTrainer:
model.config.use_cache = not training_args.gradient_checkpointing

extra_args = {}
if experiment_settings.trainer_settings.use_ref_model:
ref_model = load_model(experiment_settings.model_settings, tokenizer)
for _, param in ref_model.named_parameters():
param.requires_grad = False

extra_args['ref_model'] = ref_model

if experiment_settings.trainer_settings.use_sft_model:
sft_model = load_model(experiment_settings.model_settings, tokenizer)
for _, param in sft_model.named_parameters():
param.requires_grad = False

extra_args['sft_model'] = sft_model

return LDDPOTrainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=val_dataset,
callbacks=[],
data_collator=data_collator,
tokenizer=tokenizer,
**extra_args,
)
1 change: 1 addition & 0 deletions turbo_alignment/settings/pipelines/__init__.py
@@ -10,6 +10,7 @@
from turbo_alignment.settings.pipelines.train.ddpo import DDPOTrainExperimentSettings
from turbo_alignment.settings.pipelines.train.dpo import DPOTrainExperimentSettings
from turbo_alignment.settings.pipelines.train.kto import KTOTrainExperimentSettings
from turbo_alignment.settings.pipelines.train.lddpo import LDDPOTrainExperimentSettings
from turbo_alignment.settings.pipelines.train.multimodal import (
MultimodalTrainExperimentSettings,
)
12 changes: 12 additions & 0 deletions turbo_alignment/settings/pipelines/train/lddpo.py
@@ -0,0 +1,12 @@
from turbo_alignment.settings.pipelines.train.dpo import (
DPOTrainerSettings,
DPOTrainExperimentSettings,
)


class LDDPOTrainerSettings(DPOTrainerSettings):
lc_alpha: float


class LDDPOTrainExperimentSettings(DPOTrainExperimentSettings):
trainer_settings: LDDPOTrainerSettings
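As described in the LD-DPO paper, `lc_alpha` is the length-desensitization coefficient: the log-probabilities of the tokens that the longer response has beyond the public length (the length of the shorter response in the pair) are scaled by alpha before the preference loss is computed, so verbosity alone can no longer inflate the implicit reward. The trainer itself (`turbo_alignment/trainers/lddpo.py`) is not part of this diff, so the following is only an illustrative sketch of that reweighting with hypothetical names, not the repository's implementation:

```python
import torch


def length_desensitized_logps(
    per_token_logps: torch.Tensor,  # (batch, seq_len) log-probs of the answer tokens
    loss_mask: torch.Tensor,        # (batch, seq_len) 1 for answer tokens, 0 elsewhere
    public_length: torch.Tensor,    # (batch,) length of the shorter answer in each pair
    lc_alpha: float,
) -> torch.Tensor:
    # Answer tokens up to the public (shared) length keep full weight; the excess
    # tokens of the longer answer are down-weighted by lc_alpha before summing,
    # which removes the incentive to win the preference simply by being longer.
    answer_position = torch.cumsum(loss_mask, dim=-1)
    within_public = (answer_position <= public_length.unsqueeze(-1)).to(per_token_logps.dtype)
    weights = within_public + lc_alpha * (1.0 - within_public)
    return (per_token_logps * loss_mask * weights).sum(-1)
```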
1 change: 1 addition & 0 deletions turbo_alignment/trainers/__init__.py
@@ -3,6 +3,7 @@
from .ddpo import DDPOTrainer
from .dpo import DPOTrainer
from .kto import KTOTrainer
from .lddpo import LDDPOTrainer
from .multigpu import MultiGPUCherryPicksTrainer
from .multimodal import MultimodalTrainer
from .rm import RMTrainer
38 changes: 24 additions & 14 deletions turbo_alignment/trainers/dpo.py
@@ -606,12 +606,18 @@ def _get_batch_logps(
return (per_token_logps * loss_mask).sum(-1)

def concatenated_forward(
self, model: nn.Module, batch: dict[str, Any]
self,
model: nn.Module,
batch: dict[str, Any],
get_from_dataset: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None]:
concatenated_batch = concatenated_inputs(batch, device=self.accelerator.device)

precomputed_margins: torch.Tensor | None = concatenated_batch.pop('margin', None)

if get_from_dataset:
return concatenated_batch.pop('ref_chosen_logps'), concatenated_batch.pop('ref_rejected_logps')

all_logits = model(
concatenated_batch['input_ids'],
attention_mask=concatenated_batch['attention_mask'],
Expand All @@ -632,16 +638,19 @@ def concatenated_forward(
return chosen_logps, rejected_logps, chosen_logits, rejected_logits, precomputed_margins

def _get_logps(self, model: nn.Module | None, batch: dict[str, Any]) -> tuple[torch.Tensor, torch.Tensor]:
with torch.no_grad():
if model is not None:
(chosen_logps, rejected_logps, *_) = self.concatenated_forward(model, batch)
else:
with self.accelerator.unwrap_model(self.model).disable_adapter():
(
chosen_logps,
rejected_logps,
*_,
) = self.concatenated_forward(self.model, batch)
if self.ref_model is None:
chosen_logps, rejected_logps = self.concatenated_forward(model, batch, get_from_dataset=True)
else:
with torch.no_grad():
if model is not None:
(chosen_logps, rejected_logps, *_) = self.concatenated_forward(model, batch)
else:
with self.accelerator.unwrap_model(self.model).disable_adapter():
(
chosen_logps,
rejected_logps,
*_,
) = self.concatenated_forward(self.model, batch)

return chosen_logps, rejected_logps

@@ -780,9 +789,9 @@ def _compute_metrics(
metrics[f'{prefix_name}grad_term'] = (
(self.dpo_loss_registry.beta * F.sigmoid(rejected_rewards - chosen_rewards)).detach().cpu().mean().item()
)
metrics[f'{prefix_name}grad_term_std'] = (
(self.dpo_loss_registry.beta * F.sigmoid(rejected_rewards - chosen_rewards)).detach().cpu().std().item()
)
# metrics[f'{prefix_name}grad_term_std'] = (
# (self.dpo_loss_registry.beta * F.sigmoid(rejected_rewards - chosen_rewards)).detach().cpu().std().item()
# )

return metrics

@@ -822,6 +831,7 @@ def compute_loss(
model: PreTrainedModel | nn.Module,
inputs: dict[str, torch.Tensor | Any],
return_outputs=False,
num_items_in_batch=None,
) -> torch.Tensor | tuple[torch.Tensor, dict[str, float]]:
loss, metrics = self.get_batch_metrics(model, inputs, train_eval='train')

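The new `get_from_dataset` branch lets `_get_logps` skip the reference forward pass when `self.ref_model is None`, falling back to log-probabilities that were precomputed offline and shipped with the preference pairs. A hypothetical collated batch for that path is sketched below; the `inputs_w`/`inputs_l` layout follows the sanity-check logging in the new pipeline, while the exact placement of the precomputed keys is an assumption:

```python
import torch

# Hypothetical collated pair-preference batch for the precomputed-logps path.
# Only ref_chosen_logps / ref_rejected_logps are consumed by the new branch;
# the tensor contents here are placeholders.
batch = {
    'inputs_w': {  # chosen responses
        'input_ids': torch.randint(5, 100, (2, 16)),
        'attention_mask': torch.ones(2, 16, dtype=torch.long),
    },
    'inputs_l': {  # rejected responses
        'input_ids': torch.randint(5, 100, (2, 16)),
        'attention_mask': torch.ones(2, 16, dtype=torch.long),
    },
    'ref_chosen_logps': torch.tensor([-42.7, -35.1]),
    'ref_rejected_logps': torch.tensor([-51.3, -40.8]),
}
```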