ld-dpo trainer
lmeribal committed Nov 27, 2024
1 parent 809eb0e commit 20b97d1
Showing 8 changed files with 406 additions and 0 deletions.
145 changes: 145 additions & 0 deletions tests/fixtures/configs/train/lddpo/base.json
@@ -0,0 +1,145 @@
{
"train_dataset_settings": {
"sources": [
{
"name": "rm_preferences_test",
"records_path": "tests/fixtures/datasets/rm/train_preferences.jsonl",
"sample_rate": 1
}
],
"chat_settings":{
"prompt_template": {
"role_tag_mapping": {
"bot": "<bot>",
"user": "<user>",
"system": "<system>"
},
"prefix_template": "<RS>{role}",
"suffix_template": "</RS>"
},
"max_tokens_count": 120
},
"add_labels": true,
"dataset_type": "pair_preferences"
},
"val_dataset_settings": {
"sources": [
{
"name": "rm_preferences_test",
"records_path": "tests/fixtures/datasets/rm/val_preferences.jsonl",
"sample_rate": 1
}
],
"chat_settings":{
"prompt_template": {
"role_tag_mapping": {
"bot": "<bot>",
"user": "<user>",
"system": "<system>"
},
"prefix_template": "<RS>{role}",
"suffix_template": "</RS>"
},
"max_tokens_count": 120
},
"add_labels": true,
"dataset_type": "pair_preferences"
},
"model_settings": {
"model_path": "tests/fixtures/models/llama2_tiny",
"model_type": "causal",
"transformers_settings": {},
"adapter_path": "tests/fixtures/models/llama2_tiny_fine_tuned_with_adapters/trainer",
"is_trainable": true
},
"cherry_pick_settings": {
"generator_transformers_settings": {
"num_beams": 1,
"do_sample": false,
"stop_strings": "</RS>",
"max_new_tokens": 8
},
"custom_generation_settings": {
"skip_special_tokens": false
},
"dataset_settings": {
"sources": [
{
"name": "chat_test",
"records_path": "tests/fixtures/datasets/chat/train_chat.jsonl",
"num_samples": 2
}
],
"prompt_template": {
"role_tag_mapping": {
"bot": "<bot>",
"user": "<user>",
"system": "<system>"
},
"prefix_template": "<RS>{role}",
"suffix_template": "</RS>"
},
"dataset_type": "chat",
"max_tokens_count": 150,
"only_answer_loss": true
},
"metric_settings": [
{
"type": "length",
"parameters": {"need_average": [true]}
},
{
"type": "kl",
"parameters": {
"need_average": [true],
"ref_logits_type": "sft"
}
},
{
"type": "kl",
"parameters": {
"need_average": [true],
"ref_logits_type": "reference"
}
}

]
},
"tokenizer_settings": {},
"special_tokens_settings": {
"bos_token": "<s>",
"eos_token": "</s>",
"pad_token": "<pad>"
},
"trainer_settings": {
"evaluation_strategy": "steps",
"per_device_train_batch_size": 2,
"per_device_eval_batch_size": 2,
"gradient_accumulation_steps": 2,
"eval_steps": 4,
"save_steps": 4,
"logging_steps": 1,
"learning_rate": 0.0003,
"num_train_epochs": 2,
"lr_scheduler_type": "cosine",
"warmup_steps": 2,
"fp16": false,
"bf16": false,
"optim": "adamw_torch",
"save_total_limit": 1,
"loss_settings": {
"loss_type": "ipo"
},
"sync_ref_settings": {
"sync_ref_model": true
},
"use_sft_model": true,
"no_cuda": true
},
"logging_settings": {
"project_name": "alignment",
"run_name": "dpo",
"entity": "turbo-alignment"
},
"log_path": "test_dpo_llama_train_output"
}
13 changes: 13 additions & 0 deletions turbo_alignment/cli/train.py
@@ -59,6 +59,19 @@ def train_ddpo_entrypoint(
pipelines.TrainDDPOStrategy().run(experiment_settings)


@app.command(name='train_lddpo', help='Run LDDPO pipeline')
def train_lddpo_entrypoint(
experiment_settings_path: Path = typer.Option(
...,
'--experiment_settings_path',
exists=True,
help='Path to experiment config file',
)
) -> None:
experiment_settings = pipeline_settings.LDDPOTrainExperimentSettings.parse_file(experiment_settings_path)
pipelines.TrainLDDPOStrategy().run(experiment_settings)


@app.command(name='train_rm', help='Run RM pipeline')
def train_rm_entrypoint(
experiment_settings_path: Path = typer.Option(
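The new train_lddpo command mirrors the existing entrypoints: it parses the experiment config into LDDPOTrainExperimentSettings and hands it to TrainLDDPOStrategy. Driving the same pipeline programmatically with the fixture config added above would look roughly like the sketch below; the import paths assume the package-level exports shown in the __init__.py diffs later in this commit.

from turbo_alignment.pipelines.train import TrainLDDPOStrategy
from turbo_alignment.settings.pipelines import LDDPOTrainExperimentSettings

# Parse the pydantic experiment settings from the fixture config added in this commit,
# then run the LD-DPO training strategy on it (the same two steps the CLI entrypoint performs).
experiment_settings = LDDPOTrainExperimentSettings.parse_file(
    "tests/fixtures/configs/train/lddpo/base.json"
)
TrainLDDPOStrategy().run(experiment_settings)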
1 change: 1 addition & 0 deletions turbo_alignment/pipelines/train/__init__.py
@@ -2,6 +2,7 @@
from .ddpo import TrainDDPOStrategy
from .dpo import TrainDPOStrategy
from .kto import TrainKTOStrategy
from .lddpo import TrainLDDPOStrategy
from .multimodal import TrainMultimodalStrategy
from .rag import TrainRAGStrategy
from .rm import TrainRMStrategy
122 changes: 122 additions & 0 deletions turbo_alignment/pipelines/train/lddpo.py
@@ -0,0 +1,122 @@
from typing import Callable

from torch.utils.data import Dataset
from transformers import PreTrainedModel, PreTrainedTokenizerBase
from transformers.data.data_collator import DataCollatorMixin

from turbo_alignment.cherry_picks.chat import ChatCherryPickCallback
from turbo_alignment.common.logging import get_project_logger
from turbo_alignment.common.tf.loaders.model import load_model
from turbo_alignment.constants import TRAINER_LOGS_FOLDER
from turbo_alignment.dataset.chat.chat import InferenceChatDataset
from turbo_alignment.dataset.loader import DatasetLoader
from turbo_alignment.dataset.pair_preferences import PairPreferenceDataCollator
from turbo_alignment.metrics.metric import Metric
from turbo_alignment.metrics.registry import MetricSettingsRegistry
from turbo_alignment.pipelines.train.base import BaseTrainStrategy
from turbo_alignment.settings.datasets.base import DatasetStrategy
from turbo_alignment.settings.pipelines.train.dpo import DPOTrainExperimentSettings
from turbo_alignment.settings.pipelines.train.lddpo import LDDPOTrainExperimentSettings
from turbo_alignment.trainers.lddpo import LDDPOTrainer, LDDPOTrainingArguments

logger = get_project_logger()


class TrainLDDPOStrategy(BaseTrainStrategy[LDDPOTrainExperimentSettings]):
@staticmethod
def _get_data_collator(
experiment_settings: LDDPOTrainExperimentSettings,
tokenizer: PreTrainedTokenizerBase,
**kwargs,
) -> Callable:
return PairPreferenceDataCollator(tokenizer=tokenizer, add_labels=True)

@staticmethod
def _get_cherry_pick_callback(
experiment_settings: LDDPOTrainExperimentSettings,
tokenizer: PreTrainedTokenizerBase,
**kwargs,
) -> ChatCherryPickCallback:
cherry_pick_settings = experiment_settings.cherry_pick_settings

cherry_pick_datasets = DatasetLoader[InferenceChatDataset](InferenceChatDataset).load_datasets(
cherry_pick_settings.dataset_settings, tokenizer=tokenizer, strategy=DatasetStrategy.INFERENCE
)

metrics = [
Metric.by_name(metric.type)(MetricSettingsRegistry.by_name(metric.type)(**metric.parameters))
for metric in cherry_pick_settings.metric_settings
]

return ChatCherryPickCallback(
cherry_pick_settings=cherry_pick_settings,
datasets=cherry_pick_datasets,
metrics=metrics,
)

def _dataset_and_collator_sanity_check(self, dataset: Dataset, collator: DataCollatorMixin) -> None:
logger.info(f'Train sample example:\n{dataset[0]}')

logger.info(
'Input-w check: {input_ids}'.format(
input_ids=collator([dataset[0], dataset[1]])['inputs_w']['input_ids'][0]
)
)
logger.info(
'Mask-w check: {mask}'.format(mask=collator([dataset[0], dataset[1]])['inputs_w']['attention_mask'][0])
)
logger.info(
'Input-l check: {input_ids}'.format(
input_ids=collator([dataset[0], dataset[1]])['inputs_l']['input_ids'][0]
)
)
logger.info(
'Mask-l check: {mask}'.format(mask=collator([dataset[0], dataset[1]])['inputs_l']['attention_mask'][0])
)

@staticmethod
def _get_training_args(experiment_settings: LDDPOTrainExperimentSettings) -> LDDPOTrainingArguments:
return LDDPOTrainingArguments(
output_dir=str(experiment_settings.log_path / TRAINER_LOGS_FOLDER),
label_names=[],
remove_unused_columns=False,
**experiment_settings.trainer_settings.dict(),
)

@staticmethod
def _get_trainer(
training_args: LDDPOTrainingArguments,
experiment_settings: DPOTrainExperimentSettings,
model: PreTrainedModel,
tokenizer: PreTrainedTokenizerBase,
train_dataset: Dataset,
val_dataset: Dataset,
data_collator: Callable,
):
model.config.use_cache = not training_args.gradient_checkpointing

extra_args = {}
if experiment_settings.trainer_settings.use_ref_model:
ref_model = load_model(experiment_settings.model_settings, tokenizer)
for _, param in ref_model.named_parameters():
param.requires_grad = False

extra_args['ref_model'] = ref_model

if experiment_settings.trainer_settings.use_sft_model:
sft_model = load_model(experiment_settings.model_settings, tokenizer)
for _, param in sft_model.named_parameters():
param.requires_grad = False

extra_args['sft_model'] = sft_model

return LDDPOTrainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=val_dataset,
callbacks=[],
data_collator=data_collator,
tokenizer=tokenizer,
**extra_args,
)
1 change: 1 addition & 0 deletions turbo_alignment/settings/pipelines/__init__.py
@@ -10,6 +10,7 @@
from turbo_alignment.settings.pipelines.train.ddpo import DDPOTrainExperimentSettings
from turbo_alignment.settings.pipelines.train.dpo import DPOTrainExperimentSettings
from turbo_alignment.settings.pipelines.train.kto import KTOTrainExperimentSettings
from turbo_alignment.settings.pipelines.train.lddpo import LDDPOTrainExperimentSettings
from turbo_alignment.settings.pipelines.train.multimodal import (
MultimodalTrainExperimentSettings,
)
12 changes: 12 additions & 0 deletions turbo_alignment/settings/pipelines/train/lddpo.py
@@ -0,0 +1,12 @@
from turbo_alignment.settings.pipelines.train.dpo import (
DPOTrainerSettings,
DPOTrainExperimentSettings,
)


class LDDPOTrainerSettings(DPOTrainerSettings):
alpha: float = 1.0


class LDDPOTrainExperimentSettings(DPOTrainExperimentSettings):
trainer_settings: LDDPOTrainerSettings
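The new alpha field defaults to 1.0, which should leave the objective equivalent to ordinary DPO; smaller values damp the contribution of tokens beyond the shared length of the chosen/rejected pair. Since LDDPOTrainerSettings extends DPOTrainerSettings, alpha can simply be added to the trainer_settings block of an experiment config. A hedged sketch (the 0.7 value is purely illustrative):

import json

# Load the fixture config added in this commit and override the new LD-DPO field.
with open("tests/fixtures/configs/train/lddpo/base.json") as f:
    config = json.load(f)
config["trainer_settings"]["alpha"] = 0.7  # hypothetical value; 1.0 presumably keeps plain DPO behaviour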
1 change: 1 addition & 0 deletions turbo_alignment/trainers/__init__.py
@@ -3,6 +3,7 @@
from .ddpo import DDPOTrainer
from .dpo import DPOTrainer
from .kto import KTOTrainer
from .lddpo import LDDPOTrainer
from .multigpu import MultiGPUCherryPicksTrainer
from .multimodal import MultimodalTrainer
from .rm import RMTrainer
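The diff for the trainer module itself (turbo_alignment/trainers/lddpo.py, imported above) is not expanded in this view. For context, LD-DPO (length-desensitized DPO) is usually described as down-weighting the log-probability contribution of tokens that extend beyond the shared "public" length of the chosen/rejected pair by the factor alpha before computing the usual DPO loss. A minimal illustrative sketch of that idea follows; it is not the repository's actual implementation.

import torch

def length_desensitized_logps(
    per_token_logps_w: torch.Tensor,  # chosen-response token log-probs, shape (len_w,)
    per_token_logps_l: torch.Tensor,  # rejected-response token log-probs, shape (len_l,)
    alpha: float = 1.0,
) -> tuple[torch.Tensor, torch.Tensor]:
    # Tokens up to the shared length count fully; the excess tail is scaled by alpha.
    # With alpha = 1.0 this collapses to the ordinary summed log-probs used by DPO.
    public_len = min(per_token_logps_w.shape[0], per_token_logps_l.shape[0])
    logp_w = per_token_logps_w[:public_len].sum() + alpha * per_token_logps_w[public_len:].sum()
    logp_l = per_token_logps_l[:public_len].sum() + alpha * per_token_logps_l[public_len:].sum()
    return logp_w, logp_l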