-
Notifications
You must be signed in to change notification settings - Fork 22
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
lmeribal
committed
Nov 27, 2024
1 parent
809eb0e
commit 20b97d1
Showing
8 changed files
with
406 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
{ | ||
"train_dataset_settings": { | ||
"sources": [ | ||
{ | ||
"name": "rm_preferences_test", | ||
"records_path": "tests/fixtures/datasets/rm/train_preferences.jsonl", | ||
"sample_rate": 1 | ||
} | ||
], | ||
"chat_settings":{ | ||
"prompt_template": { | ||
"role_tag_mapping": { | ||
"bot": "<bot>", | ||
"user": "<user>", | ||
"system": "<system>" | ||
}, | ||
"prefix_template": "<RS>{role}", | ||
"suffix_template": "</RS>" | ||
}, | ||
"max_tokens_count": 120 | ||
}, | ||
"add_labels": true, | ||
"dataset_type": "pair_preferences" | ||
}, | ||
"val_dataset_settings": { | ||
"sources": [ | ||
{ | ||
"name": "rm_preferences_test", | ||
"records_path": "tests/fixtures/datasets/rm/val_preferences.jsonl", | ||
"sample_rate": 1 | ||
} | ||
], | ||
"chat_settings":{ | ||
"prompt_template": { | ||
"role_tag_mapping": { | ||
"bot": "<bot>", | ||
"user": "<user>", | ||
"system": "<system>" | ||
}, | ||
"prefix_template": "<RS>{role}", | ||
"suffix_template": "</RS>" | ||
}, | ||
"max_tokens_count": 120 | ||
}, | ||
"add_labels": true, | ||
"dataset_type": "pair_preferences" | ||
}, | ||
"model_settings": { | ||
"model_path": "tests/fixtures/models/llama2_tiny", | ||
"model_type": "causal", | ||
"transformers_settings": {}, | ||
"adapter_path": "tests/fixtures/models/llama2_tiny_fine_tuned_with_adapters/trainer", | ||
"is_trainable": true | ||
}, | ||
"cherry_pick_settings": { | ||
"generator_transformers_settings": { | ||
"num_beams": 1, | ||
"do_sample": false, | ||
"stop_strings": "</RS>", | ||
"max_new_tokens": 8 | ||
}, | ||
"custom_generation_settings": { | ||
"skip_special_tokens": false | ||
}, | ||
"dataset_settings": { | ||
"sources": [ | ||
{ | ||
"name": "chat_test", | ||
"records_path": "tests/fixtures/datasets/chat/train_chat.jsonl", | ||
"num_samples": 2 | ||
} | ||
], | ||
"prompt_template": { | ||
"role_tag_mapping": { | ||
"bot": "<bot>", | ||
"user": "<user>", | ||
"system": "<system>" | ||
}, | ||
"prefix_template": "<RS>{role}", | ||
"suffix_template": "</RS>" | ||
}, | ||
"dataset_type": "chat", | ||
"max_tokens_count": 150, | ||
"only_answer_loss": true | ||
}, | ||
"metric_settings": [ | ||
{ | ||
"type": "length", | ||
"parameters": {"need_average": [true]} | ||
}, | ||
{ | ||
"type": "kl", | ||
"parameters": { | ||
"need_average": [true], | ||
"ref_logits_type": "sft" | ||
} | ||
}, | ||
{ | ||
"type": "kl", | ||
"parameters": { | ||
"need_average": [true], | ||
"ref_logits_type": "reference" | ||
} | ||
      }
    ]
}, | ||
"tokenizer_settings": {}, | ||
"special_tokens_settings": { | ||
"bos_token": "<s>", | ||
"eos_token": "</s>", | ||
"pad_token": "<pad>" | ||
}, | ||
"trainer_settings": { | ||
"evaluation_strategy": "steps", | ||
"per_device_train_batch_size": 2, | ||
"per_device_eval_batch_size": 2, | ||
"gradient_accumulation_steps": 2, | ||
"eval_steps": 4, | ||
"save_steps": 4, | ||
"logging_steps": 1, | ||
"learning_rate": 0.0003, | ||
"num_train_epochs": 2, | ||
"lr_scheduler_type": "cosine", | ||
"warmup_steps": 2, | ||
"fp16": false, | ||
"bf16": false, | ||
"optim": "adamw_torch", | ||
"save_total_limit": 1, | ||
"loss_settings": { | ||
"loss_type": "ipo" | ||
}, | ||
"sync_ref_settings": { | ||
"sync_ref_model": true | ||
}, | ||
"use_sft_model": true, | ||
"no_cuda": true | ||
}, | ||
"logging_settings": { | ||
"project_name": "alignment", | ||
"run_name": "dpo", | ||
"entity": "turbo-alignment" | ||
}, | ||
"log_path": "test_dpo_llama_train_output" | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
from typing import Callable | ||
|
||
from torch.utils.data import Dataset | ||
from transformers import PreTrainedModel, PreTrainedTokenizerBase | ||
from transformers.data.data_collator import DataCollatorMixin | ||
|
||
from turbo_alignment.cherry_picks.chat import ChatCherryPickCallback | ||
from turbo_alignment.common.logging import get_project_logger | ||
from turbo_alignment.common.tf.loaders.model import load_model | ||
from turbo_alignment.constants import TRAINER_LOGS_FOLDER | ||
from turbo_alignment.dataset.chat.chat import InferenceChatDataset | ||
from turbo_alignment.dataset.loader import DatasetLoader | ||
from turbo_alignment.dataset.pair_preferences import PairPreferenceDataCollator | ||
from turbo_alignment.metrics.metric import Metric | ||
from turbo_alignment.metrics.registry import MetricSettingsRegistry | ||
from turbo_alignment.pipelines.train.base import BaseTrainStrategy | ||
from turbo_alignment.settings.datasets.base import DatasetStrategy | ||
from turbo_alignment.settings.pipelines.train.dpo import DPOTrainExperimentSettings | ||
from turbo_alignment.settings.pipelines.train.lddpo import LDDPOTrainExperimentSettings | ||
from turbo_alignment.trainers.lddpo import LDDPOTrainer, LDDPOTrainingArguments | ||
|
||
logger = get_project_logger() | ||
|
||
|
||
class TrainLDDPOStrategy(BaseTrainStrategy[LDDPOTrainExperimentSettings]): | ||
@staticmethod | ||
def _get_data_collator( | ||
experiment_settings: LDDPOTrainExperimentSettings, | ||
tokenizer: PreTrainedTokenizerBase, | ||
**kwargs, | ||
) -> Callable: | ||
return PairPreferenceDataCollator(tokenizer=tokenizer, add_labels=True) | ||
|
||
@staticmethod | ||
def _get_cherry_pick_callback( | ||
experiment_settings: LDDPOTrainExperimentSettings, | ||
tokenizer: PreTrainedTokenizerBase, | ||
**kwargs, | ||
) -> ChatCherryPickCallback: | ||
cherry_pick_settings = experiment_settings.cherry_pick_settings | ||
|
||
cherry_pick_datasets = DatasetLoader[InferenceChatDataset](InferenceChatDataset).load_datasets( | ||
cherry_pick_settings.dataset_settings, tokenizer=tokenizer, strategy=DatasetStrategy.INFERENCE | ||
) | ||
|
||
metrics = [ | ||
Metric.by_name(metric.type)(MetricSettingsRegistry.by_name(metric.type)(**metric.parameters)) | ||
for metric in cherry_pick_settings.metric_settings | ||
] | ||
|
||
return ChatCherryPickCallback( | ||
cherry_pick_settings=cherry_pick_settings, | ||
datasets=cherry_pick_datasets, | ||
metrics=metrics, | ||
) | ||
|
||
def _dataset_and_collator_sanity_check(self, dataset: Dataset, collator: DataCollatorMixin) -> None: | ||
logger.info(f'Train sample example:\n{dataset[0]}') | ||
|
||
logger.info( | ||
'Input-w check: {input_ids}'.format( | ||
input_ids=collator([dataset[0], dataset[1]])['inputs_w']['input_ids'][0] | ||
) | ||
) | ||
logger.info( | ||
'Mask-w check: {mask}'.format(mask=collator([dataset[0], dataset[1]])['inputs_w']['attention_mask'][0]) | ||
) | ||
logger.info( | ||
'Input-l check: {input_ids}'.format( | ||
input_ids=collator([dataset[0], dataset[1]])['inputs_l']['input_ids'][0] | ||
) | ||
) | ||
logger.info( | ||
'Mask-l check: {mask}'.format(mask=collator([dataset[0], dataset[1]])['inputs_l']['attention_mask'][0]) | ||
) | ||
|
||
@staticmethod | ||
def _get_training_args(experiment_settings: LDDPOTrainExperimentSettings) -> LDDPOTrainingArguments: | ||
return LDDPOTrainingArguments( | ||
output_dir=str(experiment_settings.log_path / TRAINER_LOGS_FOLDER), | ||
label_names=[], | ||
remove_unused_columns=False, | ||
**experiment_settings.trainer_settings.dict(), | ||
) | ||
|
||
@staticmethod | ||
def _get_trainer( | ||
training_args: LDDPOTrainingArguments, | ||
experiment_settings: DPOTrainExperimentSettings, | ||
model: PreTrainedModel, | ||
tokenizer: PreTrainedTokenizerBase, | ||
train_dataset: Dataset, | ||
val_dataset: Dataset, | ||
data_collator: Callable, | ||
): | ||
model.config.use_cache = not training_args.gradient_checkpointing | ||
|
||
extra_args = {} | ||
if experiment_settings.trainer_settings.use_ref_model: | ||
ref_model = load_model(experiment_settings.model_settings, tokenizer) | ||
for _, param in ref_model.named_parameters(): | ||
param.requires_grad = False | ||
|
||
extra_args['ref_model'] = ref_model | ||
|
||
if experiment_settings.trainer_settings.use_sft_model: | ||
sft_model = load_model(experiment_settings.model_settings, tokenizer) | ||
for _, param in sft_model.named_parameters(): | ||
param.requires_grad = False | ||
|
||
extra_args['sft_model'] = sft_model | ||
|
||
return LDDPOTrainer( | ||
model=model, | ||
args=training_args, | ||
train_dataset=train_dataset, | ||
eval_dataset=val_dataset, | ||
callbacks=[], | ||
data_collator=data_collator, | ||
tokenizer=tokenizer, | ||
**extra_args, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
from turbo_alignment.settings.pipelines.train.dpo import ( | ||
DPOTrainerSettings, | ||
DPOTrainExperimentSettings, | ||
) | ||
|
||
|
||
class LDDPOTrainerSettings(DPOTrainerSettings): | ||
alpha: float = 1.0 | ||
|
||
|
||
class LDDPOTrainExperimentSettings(DPOTrainExperimentSettings): | ||
trainer_settings: LDDPOTrainerSettings |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.