From 49eff3d7816d381f1d53e8c1a00406ed1dd4e338 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=9C=D0=B0=D0=BB=D0=B0=D1=85=D0=BE=D0=B2=20=D0=90=D0=BB?= =?UTF-8?q?=D0=B5=D0=BA=D1=81=D0=B5=D0=B9=20=D0=9F=D0=B0=D0=B2=D0=BB=D0=BE?= =?UTF-8?q?=D0=B2=D0=B8=D1=87?= Date: Thu, 17 Oct 2024 19:37:00 +0000 Subject: [PATCH] fix tests and linters --- tests/cli/test_classification.py | 28 ++-- tests/cli/test_dpo_train.py | 42 +++--- tests/cli/test_kto_train.py | 36 ++--- tests/conftest.py | 134 +++++++++--------- .../configs/train/classification/base.json | 17 ++- .../multimodal/llama_c_abs_clip_pickle.json | 25 ++-- .../multimodal/llama_llava_base_clip.json | 25 ++-- .../multimodal/llama_llava_clip_pickle.json | 25 ++-- tests/fixtures/configs/train/rag/base.json | 22 +-- tests/fixtures/configs/train/rm/base.json | 17 ++- tests/fixtures/configs/train/sft/base.json | 22 +-- .../configs/train/sft/prompt_tuning.json | 6 +- .../configs/train/sft/sft_with_rm_metric.json | 22 +-- tests/integration/test_trainers.py | 104 +++++++------- tests/unit/test_collators.py | 42 +++--- tests/unit/test_datasets.py | 124 ++++++++-------- .../common/tf/loaders/model/model.py | 8 +- .../common/tf/loaders/model/registry.py | 16 --- .../settings/pipelines/train/dpo.py | 8 +- .../settings/pipelines/train/kto.py | 11 +- turbo_alignment/settings/s3.py | 8 -- turbo_alignment/settings/tf/peft.py | 27 ++-- turbo_alignment/trainers/classification.py | 2 +- turbo_alignment/trainers/dpo.py | 10 +- turbo_alignment/trainers/kto.py | 6 +- tutorials/multimodal/multimodal.json | 25 ++-- 26 files changed, 400 insertions(+), 412 deletions(-) diff --git a/tests/cli/test_classification.py b/tests/cli/test_classification.py index 8de6b4b..115e475 100755 --- a/tests/cli/test_classification.py +++ b/tests/cli/test_classification.py @@ -1,26 +1,24 @@ from pathlib import Path +import pytest from typer.testing import CliRunner from tests.constants import FIXTURES_PATH from turbo_alignment.cli import app +from turbo_alignment.settings.pipelines.train.classification import ( + ClassificationTrainExperimentSettings, +) runner = CliRunner() -def test_classification_train(): - result = runner.invoke( - app, - [ - 'train_classification', - '--experiment_settings_path', - FIXTURES_PATH / 'configs/train/classification/base.json', - ], - catch_exceptions=False, - ) +@pytest.mark.parametrize( + 'config_path', + [ + FIXTURES_PATH / 'configs/train/classification/base.json', + ], +) +def test_classification_train(config_path: Path): + result = runner.invoke(app, ['train_classification', '--experiment_settings_path', str(config_path)]) assert result.exit_code == 0 - assert Path('test_train_classification_output').is_dir() - - -if __name__ == '__main__': - test_classification_train() + assert ClassificationTrainExperimentSettings.parse_file(config_path).log_path.is_dir() diff --git a/tests/cli/test_dpo_train.py b/tests/cli/test_dpo_train.py index f652a64..23bd4af 100755 --- a/tests/cli/test_dpo_train.py +++ b/tests/cli/test_dpo_train.py @@ -1,26 +1,26 @@ -# from pathlib import Path +from pathlib import Path -# import pytest -# from typer.testing import CliRunner +import pytest +from typer.testing import CliRunner -# from tests.constants import FIXTURES_PATH -# from turbo_alignment.cli import app -# from turbo_alignment.settings.pipelines.train.dpo import DPOTrainExperimentSettings +from tests.constants import FIXTURES_PATH +from turbo_alignment.cli import app +from turbo_alignment.settings.pipelines.train.dpo import DPOTrainExperimentSettings -# runner = CliRunner() +runner = CliRunner() -# @pytest.mark.parametrize( -# 'config_path', -# [ -# FIXTURES_PATH / 'configs/train/dpo/base.json', -# FIXTURES_PATH / 'configs/train/dpo/simpo.json', -# ], -# ) -# def test_dpo_train(config_path: Path): -# result = runner.invoke( -# app, -# ['train_dpo', '--experiment_settings_path', str(config_path)], -# ) -# assert result.exit_code == 0 -# assert DPOTrainExperimentSettings.parse_file(config_path).log_path.is_dir() +@pytest.mark.parametrize( + 'config_path', + [ + FIXTURES_PATH / 'configs/train/dpo/base.json', + FIXTURES_PATH / 'configs/train/dpo/simpo.json', + ], +) +def test_dpo_train(config_path: Path): + result = runner.invoke( + app, + ['train_dpo', '--experiment_settings_path', str(config_path)], + ) + assert result.exit_code == 0 + assert DPOTrainExperimentSettings.parse_file(config_path).log_path.is_dir() diff --git a/tests/cli/test_kto_train.py b/tests/cli/test_kto_train.py index b7ddb7d..18cb27d 100755 --- a/tests/cli/test_kto_train.py +++ b/tests/cli/test_kto_train.py @@ -1,23 +1,23 @@ -# from pathlib import Path +from pathlib import Path -# import pytest -# from typer.testing import CliRunner +import pytest +from typer.testing import CliRunner -# from tests.constants import FIXTURES_PATH -# from turbo_alignment.cli import app -# from turbo_alignment.settings.pipelines.train.kto import KTOTrainExperimentSettings +from tests.constants import FIXTURES_PATH +from turbo_alignment.cli import app +from turbo_alignment.settings.pipelines.train.kto import KTOTrainExperimentSettings -# runner = CliRunner() +runner = CliRunner() -# @pytest.mark.parametrize( -# 'config_path', -# [FIXTURES_PATH / 'configs/train/kto/base.json'], -# ) -# def test_dpo_train(config_path: Path): -# result = runner.invoke( -# app, -# ['train_kto', '--experiment_settings_path', str(config_path)], -# ) -# assert result.exit_code == 0 -# assert KTOTrainExperimentSettings.parse_file(config_path).log_path.is_dir() +@pytest.mark.parametrize( + 'config_path', + [FIXTURES_PATH / 'configs/train/kto/base.json'], +) +def test_dpo_train(config_path: Path): + result = runner.invoke( + app, + ['train_kto', '--experiment_settings_path', str(config_path)], + ) + assert result.exit_code == 0 + assert KTOTrainExperimentSettings.parse_file(config_path).log_path.is_dir() diff --git a/tests/conftest.py b/tests/conftest.py index da98770..371a7f6 100755 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,96 +1,96 @@ -# import json +import json -# from pytest import fixture -# from transformers import AutoTokenizer +from pytest import fixture +from transformers import AutoTokenizer -# from turbo_alignment.dataset.registry import DatasetRegistry -# from turbo_alignment.settings.datasets.base import ( -# DatasetSourceSettings, -# DatasetStrategy, -# DatasetType, -# ) -# from turbo_alignment.settings.datasets.chat import ( -# ChatDatasetSettings, -# ChatPromptTemplate, -# ) -# from turbo_alignment.settings.datasets.pair_preference import ( -# PairPreferenceDatasetSettings, -# ) +from turbo_alignment.dataset.registry import DatasetRegistry +from turbo_alignment.settings.datasets.base import ( + DatasetSourceSettings, + DatasetStrategy, + DatasetType, +) +from turbo_alignment.settings.datasets.chat import ( + ChatDatasetSettings, + ChatPromptTemplate, +) +from turbo_alignment.settings.datasets.pair_preference import ( + PairPreferenceDatasetSettings, +) -# @fixture(scope='session') -# def tokenizer_llama2(): -# tokenizer_path = 'tests/fixtures/models/llama2_classification/tokenizer' -# tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) -# return tokenizer +@fixture(scope='session') +def tokenizer_llama2(): + tokenizer_path = 'tests/fixtures/models/llama2_classification/tokenizer' + tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) + return tokenizer -# @fixture(scope='session') -# def tokenizer_gptj(): -# tokenizer_path = 'tests/fixtures/models/gptj_tiny_for_seq_cls' -# tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) -# return tokenizer +@fixture(scope='session') +def tokenizer_gptj(): + tokenizer_path = 'tests/fixtures/models/gptj_tiny_for_seq_cls' + tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) + return tokenizer -# @fixture(scope='session') -# def chat_dataset_settings(): -# chat_dataset_settings = ChatDatasetSettings( -# prompt_template=ChatPromptTemplate( -# prefix_template='', -# suffix_template='', -# role_tag_mapping={'user': 'USER', 'bot': 'BOT', 'system': 'SYSTEM'}, -# ), -# max_tokens_count=8256, -# ) -# return chat_dataset_settings +@fixture(scope='session') +def chat_dataset_settings(): + chat_dataset_settings = ChatDatasetSettings( + prompt_template=ChatPromptTemplate( + prefix_template='', + suffix_template='', + role_tag_mapping={'user': 'USER', 'bot': 'BOT', 'system': 'SYSTEM'}, + ), + max_tokens_count=8256, + ) + return chat_dataset_settings -# @fixture(scope='session') -# def classification_dataset_path() -> str: -# return 'tests/fixtures/datasets/classification/train_classification.jsonl' +@fixture(scope='session') +def classification_dataset_path() -> str: + return 'tests/fixtures/datasets/classification/train_classification.jsonl' -# @fixture(scope='session') -# def pair_preferences_dataset_path() -> str: -# return 'tests/fixtures/datasets/rm/train_preferences.jsonl' +@fixture(scope='session') +def pair_preferences_dataset_path() -> str: + return 'tests/fixtures/datasets/rm/train_preferences.jsonl' -# @fixture(scope='session') -# def kto_dataset_path() -> str: -# return 'tests/fixtures/datasets/rm/train_kto.jsonl' +@fixture(scope='session') +def kto_dataset_path() -> str: + return 'tests/fixtures/datasets/rm/train_kto.jsonl' -# def load_dataset_source(dataset_path: str) -> tuple[DatasetSourceSettings, list[dict]]: -# with open(dataset_path, 'r', encoding='utf-8') as f: -# data_dicts = [json.loads(line) for line in f] +def load_dataset_source(dataset_path: str) -> tuple[DatasetSourceSettings, list[dict]]: + with open(dataset_path, 'r', encoding='utf-8') as f: + data_dicts = [json.loads(line) for line in f] -# source = DatasetSourceSettings(name='dataset_for_test', records_path=dataset_path, sample_rate=1) + source = DatasetSourceSettings(name='dataset_for_test', records_path=dataset_path, sample_rate=1) -# return source, data_dicts + return source, data_dicts -# @fixture(scope='session') -# def classification_dataset_source(classification_dataset_path) -> tuple[DatasetSourceSettings, list[dict]]: -# return load_dataset_source(classification_dataset_path) +@fixture(scope='session') +def classification_dataset_source(classification_dataset_path) -> tuple[DatasetSourceSettings, list[dict]]: + return load_dataset_source(classification_dataset_path) -# @fixture(scope='session') -# def pair_preferences_dataset_source(pair_preferences_dataset_path) -> tuple[DatasetSourceSettings, list[dict]]: -# return load_dataset_source(pair_preferences_dataset_path) +@fixture(scope='session') +def pair_preferences_dataset_source(pair_preferences_dataset_path) -> tuple[DatasetSourceSettings, list[dict]]: + return load_dataset_source(pair_preferences_dataset_path) -# @fixture(scope='session') -# def kto_dataset_source(kto_dataset_path) -> tuple[DatasetSourceSettings, list[dict]]: -# return load_dataset_source(kto_dataset_path) +@fixture(scope='session') +def kto_dataset_source(kto_dataset_path) -> tuple[DatasetSourceSettings, list[dict]]: + return load_dataset_source(kto_dataset_path) -# @fixture(scope='session') -# def dpo_dataset(pair_preferences_dataset_source, tokenizer_llama2, chat_dataset_settings): -# source, _ = pair_preferences_dataset_source +@fixture(scope='session') +def dpo_dataset(pair_preferences_dataset_source, tokenizer_llama2, chat_dataset_settings): + source, _ = pair_preferences_dataset_source -# dataset_cls = DatasetRegistry.by_name(DatasetType.PAIR_PREFERENCES).by_name(DatasetStrategy.TRAIN) + dataset_cls = DatasetRegistry.by_name(DatasetType.PAIR_PREFERENCES).by_name(DatasetStrategy.TRAIN) -# dataset_settings = PairPreferenceDatasetSettings(chat_settings=chat_dataset_settings) -# dataset = dataset_cls(tokenizer=tokenizer_llama2, source=source, settings=dataset_settings) + dataset_settings = PairPreferenceDatasetSettings(chat_settings=chat_dataset_settings) + dataset = dataset_cls(tokenizer=tokenizer_llama2, source=source, settings=dataset_settings) -# return dataset + return dataset diff --git a/tests/fixtures/configs/train/classification/base.json b/tests/fixtures/configs/train/classification/base.json index 960188b..eb396fe 100755 --- a/tests/fixtures/configs/train/classification/base.json +++ b/tests/fixtures/configs/train/classification/base.json @@ -79,11 +79,18 @@ "problem_type": "single_label_classification" }, "peft_settings": { - "r": 8, - "lora_alpha": 16, - "lora_dropout": 0.05, - "target_modules": ["q_proj", "v_proj"], - "task_type": "SEQ_CLS" + "name": "LORA", + "config": { + "r": 8, + "lora_alpha": 16, + "lora_dropout": 0.05, + "target_modules": [ + "q_proj", + "v_proj" + ], + "task_type": "CAUSAL_LM", + "modules_to_save": ["embed_tokens", "lm_head"] + } } }, "tokenizer_settings": {}, diff --git a/tests/fixtures/configs/train/multimodal/llama_c_abs_clip_pickle.json b/tests/fixtures/configs/train/multimodal/llama_c_abs_clip_pickle.json index ffb2b13..6a03d75 100644 --- a/tests/fixtures/configs/train/multimodal/llama_c_abs_clip_pickle.json +++ b/tests/fixtures/configs/train/multimodal/llama_c_abs_clip_pickle.json @@ -76,19 +76,18 @@ "model_type": "causal", "transformers_settings": {}, "peft_settings": { - "r": 16, - "lora_alpha": 16, - "lora_dropout": 0.05, - "target_modules": [ - "q_proj", - "k_proj" - ], - "task_type": "CAUSAL_LM", - "modules_to_save": [ - "embed_tokens", - "lm_head" - ], - "name": "LORA" + "name": "LORA", + "config": { + "r": 8, + "lora_alpha": 16, + "lora_dropout": 0.05, + "target_modules": [ + "q_proj", + "v_proj" + ], + "task_type": "CAUSAL_LM", + "modules_to_save": ["embed_tokens", "lm_head"] + } }, "embeddings_initialization_strategy": { "": "bot", diff --git a/tests/fixtures/configs/train/multimodal/llama_llava_base_clip.json b/tests/fixtures/configs/train/multimodal/llama_llava_base_clip.json index ef90914..cd1b153 100644 --- a/tests/fixtures/configs/train/multimodal/llama_llava_base_clip.json +++ b/tests/fixtures/configs/train/multimodal/llama_llava_base_clip.json @@ -76,19 +76,18 @@ "model_type": "causal", "transformers_settings": {}, "peft_settings": { - "r": 16, - "lora_alpha": 16, - "lora_dropout": 0.05, - "target_modules": [ - "q_proj", - "k_proj" - ], - "task_type": "CAUSAL_LM", - "modules_to_save": [ - "embed_tokens", - "lm_head" - ], - "name": "LORA" + "name": "LORA", + "config": { + "r": 8, + "lora_alpha": 16, + "lora_dropout": 0.05, + "target_modules": [ + "q_proj", + "v_proj" + ], + "task_type": "CAUSAL_LM", + "modules_to_save": ["embed_tokens", "lm_head"] + } }, "embeddings_initialization_strategy": { "": "bot", diff --git a/tests/fixtures/configs/train/multimodal/llama_llava_clip_pickle.json b/tests/fixtures/configs/train/multimodal/llama_llava_clip_pickle.json index eae875a..1dc8f63 100644 --- a/tests/fixtures/configs/train/multimodal/llama_llava_clip_pickle.json +++ b/tests/fixtures/configs/train/multimodal/llama_llava_clip_pickle.json @@ -76,19 +76,18 @@ "model_type": "causal", "transformers_settings": {}, "peft_settings": { - "r": 16, - "lora_alpha": 16, - "lora_dropout": 0.05, - "target_modules": [ - "q_proj", - "k_proj" - ], - "task_type": "CAUSAL_LM", - "modules_to_save": [ - "embed_tokens", - "lm_head" - ], - "name": "LORA" + "name": "LORA", + "config": { + "r": 8, + "lora_alpha": 16, + "lora_dropout": 0.05, + "target_modules": [ + "q_proj", + "v_proj" + ], + "task_type": "CAUSAL_LM", + "modules_to_save": ["embed_tokens", "lm_head"] + } }, "embeddings_initialization_strategy": { "": "bot", diff --git a/tests/fixtures/configs/train/rag/base.json b/tests/fixtures/configs/train/rag/base.json index b64b826..d5f1e4c 100755 --- a/tests/fixtures/configs/train/rag/base.json +++ b/tests/fixtures/configs/train/rag/base.json @@ -54,16 +54,18 @@ "": "system" }, "peft_settings": { - "r": 16, - "lora_alpha": 16, - "lora_dropout": 0.05, - "target_modules": ["q_proj", "v_proj", "k_proj", "o_proj"], - "task_type": "CAUSAL_LM", - "modules_to_save": [ - "embed_tokens", - "lm_head" - ], - "name": "LORA" + "name": "LORA", + "config": { + "r": 8, + "lora_alpha": 16, + "lora_dropout": 0.05, + "target_modules": [ + "q_proj", + "v_proj" + ], + "task_type": "CAUSAL_LM", + "modules_to_save": ["embed_tokens", "lm_head"] + } } }, "question_encoder_settings": { diff --git a/tests/fixtures/configs/train/rm/base.json b/tests/fixtures/configs/train/rm/base.json index c2dd749..c9a3a6b 100755 --- a/tests/fixtures/configs/train/rm/base.json +++ b/tests/fixtures/configs/train/rm/base.json @@ -81,11 +81,18 @@ "return_dict": true }, "peft_settings": { - "r": 8, - "lora_alpha": 16, - "lora_dropout": 0.05, - "target_modules": ["q_proj", "v_proj"], - "task_type": "SEQ_CLS" + "name": "LORA", + "config": { + "r": 8, + "lora_alpha": 16, + "lora_dropout": 0.05, + "target_modules": [ + "q_proj", + "v_proj" + ], + "task_type": "SEQ_CLS", + "modules_to_save": ["embed_tokens", "lm_head"] + } } }, "tokenizer_settings": {}, diff --git a/tests/fixtures/configs/train/sft/base.json b/tests/fixtures/configs/train/sft/base.json index 99c1ac0..c83f08b 100755 --- a/tests/fixtures/configs/train/sft/base.json +++ b/tests/fixtures/configs/train/sft/base.json @@ -47,16 +47,18 @@ "transformers_settings": { }, "peft_settings": { - "r": 8, - "lora_alpha": 16, - "lora_dropout": 0.05, - "target_modules": [ - "q_proj", - "v_proj" - ], - "task_type": "CAUSAL_LM", - "modules_to_save": ["embed_tokens", "lm_head"], - "name": "LORA" + "name": "LORA", + "config": { + "r": 8, + "lora_alpha": 16, + "lora_dropout": 0.05, + "target_modules": [ + "q_proj", + "v_proj" + ], + "task_type": "CAUSAL_LM", + "modules_to_save": ["embed_tokens", "lm_head"] + } }, "embeddings_initialization_strategy": { "": "", diff --git a/tests/fixtures/configs/train/sft/prompt_tuning.json b/tests/fixtures/configs/train/sft/prompt_tuning.json index 920d1f3..fc58e42 100755 --- a/tests/fixtures/configs/train/sft/prompt_tuning.json +++ b/tests/fixtures/configs/train/sft/prompt_tuning.json @@ -46,9 +46,11 @@ "model_type": "causal", "transformers_settings": {}, "peft_settings": { - "task_type": "CAUSAL_LM", "name": "PROMPT_TUNING", - "num_virtual_tokens": 32 + "config": { + "task_type": "CAUSAL_LM", + "num_virtual_tokens": 32 + } }, "embeddings_initialization_strategy": { "": "", diff --git a/tests/fixtures/configs/train/sft/sft_with_rm_metric.json b/tests/fixtures/configs/train/sft/sft_with_rm_metric.json index 8807122..ec6526c 100755 --- a/tests/fixtures/configs/train/sft/sft_with_rm_metric.json +++ b/tests/fixtures/configs/train/sft/sft_with_rm_metric.json @@ -47,16 +47,18 @@ "transformers_settings": { }, "peft_settings": { - "r": 8, - "lora_alpha": 16, - "lora_dropout": 0.05, - "target_modules": [ - "q_proj", - "v_proj" - ], - "task_type": "CAUSAL_LM", - "modules_to_save": ["embed_tokens", "lm_head"], - "name": "LORA" + "name": "LORA", + "config": { + "r": 8, + "lora_alpha": 16, + "lora_dropout": 0.05, + "target_modules": [ + "q_proj", + "v_proj" + ], + "task_type": "CAUSAL_LM", + "modules_to_save": ["embed_tokens", "lm_head"] + } }, "embeddings_initialization_strategy": { "": "", diff --git a/tests/integration/test_trainers.py b/tests/integration/test_trainers.py index 56941bb..97f8aa8 100644 --- a/tests/integration/test_trainers.py +++ b/tests/integration/test_trainers.py @@ -1,69 +1,69 @@ -# from tempfile import TemporaryDirectory +from tempfile import TemporaryDirectory -# import torch -# from transformers import AutoModelForCausalLM, AutoTokenizer +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer -# from turbo_alignment.common.tf.callbacks.sync_ref_model import SyncRefModelSettings -# from turbo_alignment.dataset.pair_preferences import PairPreferenceDataCollator -# from turbo_alignment.settings.pipelines.train.dpo import ( -# DPOLossesType, -# SigmoidLossSettings, -# ) -# from turbo_alignment.trainers.dpo import DPOTrainer, DPOTrainingArguments +from turbo_alignment.common.tf.callbacks.sync_ref_model import SyncRefModelSettings +from turbo_alignment.dataset.pair_preferences import PairPreferenceDataCollator +from turbo_alignment.settings.pipelines.train.utils import ( + DPOLossesType, + SigmoidLossSettings, +) +from turbo_alignment.trainers.dpo import DPOTrainer, DPOTrainingArguments -# def test_dpo_trainer(dpo_dataset): -# model_path = 'tests/fixtures/models/llama2_tiny' +def test_dpo_trainer(dpo_dataset): + model_path = 'tests/fixtures/models/llama2_tiny' -# model = AutoModelForCausalLM.from_pretrained(model_path) + model = AutoModelForCausalLM.from_pretrained(model_path) -# ref_model = AutoModelForCausalLM.from_pretrained(model_path) + ref_model = AutoModelForCausalLM.from_pretrained(model_path) -# with TemporaryDirectory() as tmp_dir: -# args = DPOTrainingArguments( -# do_train=True, -# loss_settings=SigmoidLossSettings(loss_type=DPOLossesType.SIGMOID).dict(), -# sync_ref_settings=SyncRefModelSettings().dict(), -# do_eval=False, -# learning_rate=1.0e-4, -# use_cpu=True, -# num_train_epochs=10, -# report_to=[], -# remove_unused_columns=False, -# output_dir=tmp_dir, -# ) + with TemporaryDirectory() as tmp_dir: + args = DPOTrainingArguments( + do_train=True, + loss_settings=SigmoidLossSettings(loss_type=DPOLossesType.SIGMOID).dict(), + sync_ref_settings=SyncRefModelSettings().dict(), + do_eval=False, + learning_rate=1.0e-4, + use_cpu=True, + num_train_epochs=10, + report_to=[], + remove_unused_columns=False, + output_dir=tmp_dir, + ) -# tokenizer = AutoTokenizer.from_pretrained(model_path) + tokenizer = AutoTokenizer.from_pretrained(model_path) -# data_collator = PairPreferenceDataCollator(tokenizer=tokenizer) + data_collator = PairPreferenceDataCollator(tokenizer=tokenizer) -# trainer = DPOTrainer( -# model=model, -# ref_model=ref_model, -# args=args, -# train_dataset=dpo_dataset, -# eval_dataset=dpo_dataset, -# data_collator=data_collator, -# ) + trainer = DPOTrainer( + model=model, + ref_model=ref_model, + args=args, + train_dataset=dpo_dataset, + eval_dataset=dpo_dataset, + data_collator=data_collator, + ) -# batch = data_collator(list(dpo_dataset)) + batch = data_collator(list(dpo_dataset)) -# loss_before, _ = trainer.get_batch_metrics(model, batch, 'train') -# trainer.train() -# loss_after, _ = trainer.get_batch_metrics(trainer.model, batch, 'train') + loss_before, _ = trainer.get_batch_metrics(model, batch, 'train') + trainer.train() + loss_after, _ = trainer.get_batch_metrics(trainer.model, batch, 'train') -# assert torch.greater(loss_before, loss_after) + assert torch.greater(loss_before, loss_after) -# initial_model = AutoModelForCausalLM.from_pretrained(model_path) + initial_model = AutoModelForCausalLM.from_pretrained(model_path) -# trainer.save_model(tmp_dir) -# trained_model = AutoModelForCausalLM.from_pretrained(tmp_dir) + trainer.save_model(tmp_dir) + trained_model = AutoModelForCausalLM.from_pretrained(tmp_dir) -# initial_state_dict = initial_model.state_dict() -# trained_state_dict = trained_model.state_dict() -# for k, v in trained_state_dict.items(): -# assert any(torch.not_equal(v, initial_state_dict[k]).tolist()) + initial_state_dict = initial_model.state_dict() + trained_state_dict = trained_model.state_dict() + for k, v in trained_state_dict.items(): + assert any(torch.not_equal(v, initial_state_dict[k]).tolist()) -# ref_model_state_dict = trainer.ref_model.state_dict() -# for k, v in ref_model_state_dict.items(): -# assert all(torch.eq(v, initial_state_dict[k]).tolist()) + ref_model_state_dict = trainer.ref_model.state_dict() + for k, v in ref_model_state_dict.items(): + assert all(torch.eq(v, initial_state_dict[k]).tolist()) diff --git a/tests/unit/test_collators.py b/tests/unit/test_collators.py index 5231b40..150436a 100644 --- a/tests/unit/test_collators.py +++ b/tests/unit/test_collators.py @@ -1,28 +1,28 @@ -# from tests.utils import is_sample_build_from_content -# from turbo_alignment.dataset.kto.collators import KTODataCollator -# from turbo_alignment.dataset.registry import DatasetRegistry -# from turbo_alignment.settings.datasets.base import DatasetStrategy, DatasetType -# from turbo_alignment.settings.datasets.kto import KTODatasetSettings +from tests.utils import is_sample_build_from_content +from turbo_alignment.dataset.kto.collators import KTODataCollator +from turbo_alignment.dataset.registry import DatasetRegistry +from turbo_alignment.settings.datasets.base import DatasetStrategy, DatasetType +from turbo_alignment.settings.datasets.kto import KTODatasetSettings -# def test_kto_collator(tokenizer_llama2, chat_dataset_settings, kto_dataset_source): -# tokenizer = tokenizer_llama2 -# source, data_dicts = kto_dataset_source +def test_kto_collator(tokenizer_llama2, chat_dataset_settings, kto_dataset_source): + tokenizer = tokenizer_llama2 + source, data_dicts = kto_dataset_source -# dataset_cls = DatasetRegistry.by_name(DatasetType.KTO).by_name(DatasetStrategy.TRAIN) + dataset_cls = DatasetRegistry.by_name(DatasetType.KTO).by_name(DatasetStrategy.TRAIN) -# dataset_settings = KTODatasetSettings(chat_settings=chat_dataset_settings) -# dataset = dataset_cls(tokenizer=tokenizer, source=source, settings=dataset_settings) + dataset_settings = KTODatasetSettings(chat_settings=chat_dataset_settings) + dataset = dataset_cls(tokenizer=tokenizer, source=source, settings=dataset_settings) -# batch_size = min(len(dataset), 8) -# examples = list(dataset)[:batch_size] + batch_size = min(len(dataset), 8) + examples = list(dataset)[:batch_size] -# collator = KTODataCollator(tokenizer=tokenizer) -# batch = collator(examples) + collator = KTODataCollator(tokenizer=tokenizer) + batch = collator(examples) -# ignore_index = -100 -# for answer_labels, is_desirable, raw_data in zip(batch['labels'], batch['is_desirable'], data_dicts): -# answer = raw_data['answer']['content'] -# answer_tokens = answer_labels[answer_labels != ignore_index] -# assert is_sample_build_from_content(answer_tokens, [answer], tokenizer) -# assert raw_data['is_desirable'] == is_desirable + ignore_index = -100 + for answer_labels, is_desirable, raw_data in zip(batch['labels'], batch['is_desirable'], data_dicts): + answer = raw_data['answer']['content'] + answer_tokens = answer_labels[answer_labels != ignore_index] + assert is_sample_build_from_content(answer_tokens, [answer], tokenizer) + assert raw_data['is_desirable'] == is_desirable diff --git a/tests/unit/test_datasets.py b/tests/unit/test_datasets.py index 089417b..1ecd323 100644 --- a/tests/unit/test_datasets.py +++ b/tests/unit/test_datasets.py @@ -1,88 +1,88 @@ -# from tests.utils import is_sample_build_from_content -# from turbo_alignment.dataset.classification.models import ClassificationDatasetRecord -# from turbo_alignment.dataset.pair_preferences.models import PairPreferenceRecord -# from turbo_alignment.dataset.registry import DatasetRegistry -# from turbo_alignment.settings.datasets.base import DatasetStrategy, DatasetType -# from turbo_alignment.settings.datasets.classification import ( -# ClassificationDatasetSettings, -# ) -# from turbo_alignment.settings.datasets.ddpo import DDPODatasetSettings -# from turbo_alignment.settings.datasets.pair_preference import ( -# PairPreferenceDatasetSettings, -# ) +from tests.utils import is_sample_build_from_content +from turbo_alignment.dataset.classification.models import ClassificationDatasetRecord +from turbo_alignment.dataset.pair_preferences.models import PairPreferenceRecord +from turbo_alignment.dataset.registry import DatasetRegistry +from turbo_alignment.settings.datasets.base import DatasetStrategy, DatasetType +from turbo_alignment.settings.datasets.classification import ( + ClassificationDatasetSettings, +) +from turbo_alignment.settings.datasets.ddpo import DDPODatasetSettings +from turbo_alignment.settings.datasets.pair_preference import ( + PairPreferenceDatasetSettings, +) -# def test_classification(tokenizer_llama2, chat_dataset_settings, classification_dataset_source): -# # load dataset and check that samples have required fields +def test_classification(tokenizer_llama2, chat_dataset_settings, classification_dataset_source): + # load dataset and check that samples have required fields -# source, data_dicts = classification_dataset_source + source, data_dicts = classification_dataset_source -# dataset_cls = DatasetRegistry.by_name(DatasetType.CLASSIFICATION).by_name(DatasetStrategy.TRAIN) + dataset_cls = DatasetRegistry.by_name(DatasetType.CLASSIFICATION).by_name(DatasetStrategy.TRAIN) -# dataset_settings = ClassificationDatasetSettings(chat_settings=chat_dataset_settings) + dataset_settings = ClassificationDatasetSettings(chat_settings=chat_dataset_settings) -# dataset = dataset_cls(tokenizer=tokenizer_llama2, source=source, settings=dataset_settings) + dataset = dataset_cls(tokenizer=tokenizer_llama2, source=source, settings=dataset_settings) -# assert len(data_dicts) == len(dataset) + assert len(data_dicts) == len(dataset) -# for data_dict, sample in zip(data_dicts, dataset): -# record = ClassificationDatasetRecord.model_validate(data_dict) + for data_dict, sample in zip(data_dicts, dataset): + record = ClassificationDatasetRecord.model_validate(data_dict) -# assert record.label == sample['labels'] + assert record.label == sample['labels'] -# assert is_sample_build_from_content( -# sample['input_ids'], [m.content for m in record.messages], tokenizer_llama2 -# ) + assert is_sample_build_from_content( + sample['input_ids'], [m.content for m in record.messages], tokenizer_llama2 + ) -# def test_pair_preferences(tokenizer_llama2, chat_dataset_settings, pair_preferences_dataset_source): -# # load dataset and check that samples have required fields +def test_pair_preferences(tokenizer_llama2, chat_dataset_settings, pair_preferences_dataset_source): + # load dataset and check that samples have required fields -# source, data_dicts = pair_preferences_dataset_source + source, data_dicts = pair_preferences_dataset_source -# dataset_cls = DatasetRegistry.by_name(DatasetType.PAIR_PREFERENCES).by_name(DatasetStrategy.TRAIN) + dataset_cls = DatasetRegistry.by_name(DatasetType.PAIR_PREFERENCES).by_name(DatasetStrategy.TRAIN) -# dataset_settings = PairPreferenceDatasetSettings(chat_settings=chat_dataset_settings) -# dataset = dataset_cls(tokenizer=tokenizer_llama2, source=source, settings=dataset_settings) + dataset_settings = PairPreferenceDatasetSettings(chat_settings=chat_dataset_settings) + dataset = dataset_cls(tokenizer=tokenizer_llama2, source=source, settings=dataset_settings) -# assert len(data_dicts) == len(dataset) + assert len(data_dicts) == len(dataset) -# for data_dict, sample in zip(data_dicts, dataset): -# record = PairPreferenceRecord.model_validate(data_dict) -# context: list[str] = [c.content for c in record.context] -# contents_w = [*context, record.answer_w.content] -# assert is_sample_build_from_content(sample['inputs_w']['input_ids'], contents_w, tokenizer_llama2) + for data_dict, sample in zip(data_dicts, dataset): + record = PairPreferenceRecord.model_validate(data_dict) + context: list[str] = [c.content for c in record.context] + contents_w = [*context, record.answer_w.content] + assert is_sample_build_from_content(sample['inputs_w']['input_ids'], contents_w, tokenizer_llama2) -# contents_l = [*context, record.answer_l.content] -# assert is_sample_build_from_content(sample['inputs_l']['input_ids'], contents_l, tokenizer_llama2) + contents_l = [*context, record.answer_l.content] + assert is_sample_build_from_content(sample['inputs_l']['input_ids'], contents_l, tokenizer_llama2) -# def test_ddpo(tokenizer_llama2, tokenizer_gptj, chat_dataset_settings, pair_preferences_dataset_source): -# sft_tokenizer = tokenizer_llama2 -# rm_tokenizer = tokenizer_gptj -# # load dataset and check that samples have required fields +def test_ddpo(tokenizer_llama2, tokenizer_gptj, chat_dataset_settings, pair_preferences_dataset_source): + sft_tokenizer = tokenizer_llama2 + rm_tokenizer = tokenizer_gptj + # load dataset and check that samples have required fields -# source, data_dicts = pair_preferences_dataset_source + source, data_dicts = pair_preferences_dataset_source -# dataset_cls = DatasetRegistry.by_name(DatasetType.DDPO).by_name(DatasetStrategy.TRAIN) + dataset_cls = DatasetRegistry.by_name(DatasetType.DDPO).by_name(DatasetStrategy.TRAIN) -# pair_preferences_dataset_settings = PairPreferenceDatasetSettings(chat_settings=chat_dataset_settings) -# dataset_settings = DDPODatasetSettings( -# chat_settings=chat_dataset_settings, pair_preferences=pair_preferences_dataset_settings -# ) -# dataset = dataset_cls( -# chat_tokenizer=sft_tokenizer, rm_tokenizer=rm_tokenizer, source=source, settings=dataset_settings -# ) + pair_preferences_dataset_settings = PairPreferenceDatasetSettings(chat_settings=chat_dataset_settings) + dataset_settings = DDPODatasetSettings( + chat_settings=chat_dataset_settings, pair_preferences=pair_preferences_dataset_settings + ) + dataset = dataset_cls( + chat_tokenizer=sft_tokenizer, rm_tokenizer=rm_tokenizer, source=source, settings=dataset_settings + ) -# assert len(data_dicts) == len(dataset) + assert len(data_dicts) == len(dataset) -# for data_dict, sample in zip(data_dicts, dataset): -# record = PairPreferenceRecord.model_validate(data_dict) -# context: list[str] = [c.content for c in record.context] -# contents_w = [*context, record.answer_w.content] -# assert is_sample_build_from_content(sample['sft_inputs_w']['input_ids'], contents_w, sft_tokenizer) -# assert is_sample_build_from_content(sample['rm_inputs_w']['input_ids'], contents_w, rm_tokenizer) + for data_dict, sample in zip(data_dicts, dataset): + record = PairPreferenceRecord.model_validate(data_dict) + context: list[str] = [c.content for c in record.context] + contents_w = [*context, record.answer_w.content] + assert is_sample_build_from_content(sample['sft_inputs_w']['input_ids'], contents_w, sft_tokenizer) + assert is_sample_build_from_content(sample['rm_inputs_w']['input_ids'], contents_w, rm_tokenizer) -# contents_l = [*context, record.answer_l.content] -# assert is_sample_build_from_content(sample['sft_inputs_l']['input_ids'], contents_l, sft_tokenizer) -# assert is_sample_build_from_content(sample['rm_inputs_l']['input_ids'], contents_l, rm_tokenizer) + contents_l = [*context, record.answer_l.content] + assert is_sample_build_from_content(sample['sft_inputs_l']['input_ids'], contents_l, sft_tokenizer) + assert is_sample_build_from_content(sample['rm_inputs_l']['input_ids'], contents_l, rm_tokenizer) diff --git a/turbo_alignment/common/tf/loaders/model/model.py b/turbo_alignment/common/tf/loaders/model/model.py index aba80e9..49a5f11 100755 --- a/turbo_alignment/common/tf/loaders/model/model.py +++ b/turbo_alignment/common/tf/loaders/model/model.py @@ -4,7 +4,6 @@ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled from turbo_alignment.common.tf.loaders.model.registry import ( - PeftConfigRegistry, TransformersAutoModelRegistry, ) from turbo_alignment.modeling.liger_kernels import apply_liger_kernel_to_gemma2 @@ -18,12 +17,7 @@ def _prepare_model_for_peft(model: PreTrainedModel, peft_settings: PEFT_TYPE) -> PeftModel: - peft_params = peft_settings.dict() - peft_params.pop('name') - - peft_config = PeftConfigRegistry.by_name(peft_settings.name)(**peft_params) - - return get_peft_model(model, peft_config) + return get_peft_model(model, peft_settings.config) def _load_pretrained_adapters( diff --git a/turbo_alignment/common/tf/loaders/model/registry.py b/turbo_alignment/common/tf/loaders/model/registry.py index e2c5e08..565a61d 100755 --- a/turbo_alignment/common/tf/loaders/model/registry.py +++ b/turbo_alignment/common/tf/loaders/model/registry.py @@ -1,10 +1,3 @@ -from peft import ( - LoraConfig, - PeftType, - PrefixTuningConfig, - PromptEncoderConfig, - PromptTuningConfig, -) from transformers import ( AutoModel, AutoModelForCausalLM, @@ -19,15 +12,6 @@ class TransformersAutoModelRegistry(Registrable): ... -class PeftConfigRegistry(Registrable): - ... - - TransformersAutoModelRegistry.register(ModelType.CAUSAL)(AutoModelForCausalLM) TransformersAutoModelRegistry.register(ModelType.SEQ_CLS)(AutoModelForSequenceClassification) TransformersAutoModelRegistry.register(ModelType.ENC)(AutoModel) - -PeftConfigRegistry.register(PeftType.LORA)(LoraConfig) -PeftConfigRegistry.register(PeftType.PREFIX_TUNING)(PrefixTuningConfig) -PeftConfigRegistry.register(PeftType.PROMPT_TUNING)(PromptTuningConfig) -PeftConfigRegistry.register(PeftType.P_TUNING)(PromptEncoderConfig) diff --git a/turbo_alignment/settings/pipelines/train/dpo.py b/turbo_alignment/settings/pipelines/train/dpo.py index 0270063..795b84b 100755 --- a/turbo_alignment/settings/pipelines/train/dpo.py +++ b/turbo_alignment/settings/pipelines/train/dpo.py @@ -17,8 +17,8 @@ class DPOTrainExperimentSettings(BaseTrainExperimentSettings): cherry_pick_settings: ChatCherryPickSettings - # training_arguments: DPOTrainingArguments + training_arguments: DPOTrainingArguments - # @field_validator('training_arguments', mode='before') - # def convert_generation_settings(cls, values: dict[str, Any]) -> DPOTrainingArguments: - # return DPOTrainingArguments(**values, output_dir=TRAINER_LOGS_FOLDER, report_to=[]) + @field_validator('training_arguments', mode='before') + def create_training_arguments(cls, values: dict[str, Any]) -> DPOTrainingArguments: + return DPOTrainingArguments(**values, output_dir=TRAINER_LOGS_FOLDER, report_to=[]) diff --git a/turbo_alignment/settings/pipelines/train/kto.py b/turbo_alignment/settings/pipelines/train/kto.py index 0d459ef..bdf09f3 100755 --- a/turbo_alignment/settings/pipelines/train/kto.py +++ b/turbo_alignment/settings/pipelines/train/kto.py @@ -1,3 +1,8 @@ +from typing import Any + +from pydantic import field_validator + +from turbo_alignment.constants import TRAINER_LOGS_FOLDER from turbo_alignment.settings.cherry_pick import ChatCherryPickSettings from turbo_alignment.settings.datasets.kto import KTOMultiDatasetSettings from turbo_alignment.settings.pipelines.train.base import BaseTrainExperimentSettings @@ -10,4 +15,8 @@ class KTOTrainExperimentSettings(BaseTrainExperimentSettings): cherry_pick_settings: ChatCherryPickSettings - # training_arguments: KTOTrainingArguments + training_arguments: KTOTrainingArguments + + @field_validator('training_arguments', mode='before') + def create_training_arguments(cls, values: dict[str, Any]) -> KTOTrainingArguments: + return KTOTrainingArguments(**values, output_dir=TRAINER_LOGS_FOLDER, report_to=[]) diff --git a/turbo_alignment/settings/s3.py b/turbo_alignment/settings/s3.py index 2e964fc..22c63ca 100755 --- a/turbo_alignment/settings/s3.py +++ b/turbo_alignment/settings/s3.py @@ -1,6 +1,5 @@ from pathlib import Path -from pydantic import field_validator from pydantic_settings import BaseSettings @@ -14,13 +13,6 @@ class S3HandlerParameters(BaseSettings): aws_access_key_id: str aws_secret_access_key: str - @field_validator('bucket') - def bucket_name_biglm_is_not_allowed(cls, values: str) -> str: - if values == 'biglm': - raise S3HandlerParametersWrongBucketException('Usage of biglm bucket is not allowed') - - return values - class Config: env_file: str = '.env' env_prefix: str = 'S3_CHECKPOINTS_' diff --git a/turbo_alignment/settings/tf/peft.py b/turbo_alignment/settings/tf/peft.py index f479a27..d8aafd1 100755 --- a/turbo_alignment/settings/tf/peft.py +++ b/turbo_alignment/settings/tf/peft.py @@ -1,41 +1,40 @@ from typing import Literal -from peft import PeftType, TaskType +from peft import ( + LoraConfig, + PeftConfig, + PeftType, + PrefixTuningConfig, + PromptEncoderConfig, + PromptTuningConfig, +) from turbo_alignment.settings.base import ExtraFieldsNotAllowedBaseModel class BasePeftSettings(ExtraFieldsNotAllowedBaseModel): name: PeftType - task_type: TaskType = TaskType.CAUSAL_LM + config: PeftConfig class LoraSettings(BasePeftSettings): name: Literal[PeftType.LORA] = PeftType.LORA - r: int = 16 - lora_alpha: int = 16 - lora_dropout: float = 0.05 - bias: str = 'none' - target_modules: list[str] = ['q_proj', 'v_proj'] - modules_to_save: list[str] | None = None + config: LoraConfig class PrefixTuningSettings(BasePeftSettings): name: Literal[PeftType.PREFIX_TUNING] = PeftType.PREFIX_TUNING - encoder_hidden_size: int - prefix_projection: bool + config: PrefixTuningConfig class PromptTuningSettings(BasePeftSettings): name: Literal[PeftType.PROMPT_TUNING] = PeftType.PROMPT_TUNING - num_virtual_tokens: int = 32 - prompt_tuning_init_text: str | None = None + config: PromptTuningConfig class PTuningSettings(BasePeftSettings): name: Literal[PeftType.P_TUNING] = PeftType.P_TUNING - num_virtual_tokens: int = 32 - encoder_reparameterization_type: str = 'MLP' + config: PromptEncoderConfig PEFT_TYPE = PrefixTuningSettings | LoraSettings | PromptTuningSettings | PTuningSettings diff --git a/turbo_alignment/trainers/classification.py b/turbo_alignment/trainers/classification.py index 76a26e4..bfc1500 100755 --- a/turbo_alignment/trainers/classification.py +++ b/turbo_alignment/trainers/classification.py @@ -60,7 +60,7 @@ def auto_class_weights(dataset: Dataset) -> list[float]: class ClassificationTrainer(MultiGPUCherryPicksTrainer): def __init__(self, **kwargs) -> None: args = kwargs.get('args') - self.loss_settings = args.loss_settings + self.loss_settings = args.loss_settings # type: ignore[union-attr] super().__init__( compute_metrics=compute_clf_metrics, **kwargs, diff --git a/turbo_alignment/trainers/dpo.py b/turbo_alignment/trainers/dpo.py index a69da2e..744b17b 100755 --- a/turbo_alignment/trainers/dpo.py +++ b/turbo_alignment/trainers/dpo.py @@ -1,5 +1,5 @@ from collections import defaultdict -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import Any, Callable, Literal import torch @@ -414,12 +414,8 @@ class DPOTrainingArguments(TrainingArguments): | SigmoidLossWithMarginSettings | APOZeroLossSettings | APODownLossSettings - ) = field( - default_factory=SigmoidLossSettings(loss_type=DPOLossesType.SIGMOID) - ) # type: ignore[call-overload] - sync_ref_settings: SyncRefModelSettings = field( # type: ignore[call-overload] - default_factory=SyncRefModelSettings() - ) + ) = SigmoidLossSettings(loss_type=DPOLossesType.SIGMOID) + sync_ref_settings: SyncRefModelSettings = SyncRefModelSettings() use_ref_model: bool = True use_sft_model: bool = False average_log_prob: bool = False diff --git a/turbo_alignment/trainers/kto.py b/turbo_alignment/trainers/kto.py index 4b35bca..6a11774 100755 --- a/turbo_alignment/trainers/kto.py +++ b/turbo_alignment/trainers/kto.py @@ -1,5 +1,5 @@ from collections import defaultdict -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import Any, Callable, Literal import torch @@ -30,9 +30,7 @@ @dataclass class KTOTrainingArguments(TrainingArguments): beta: float = 0.1 - sync_ref_settings: SyncRefModelSettings = field( - default_factory=SyncRefModelSettings() - ) # type: ignore[call-overload] + sync_ref_settings: SyncRefModelSettings = SyncRefModelSettings() use_ref_model: bool = True average_log_prob: bool = False undesirable_weight: float = 1.0 diff --git a/tutorials/multimodal/multimodal.json b/tutorials/multimodal/multimodal.json index 54b7948..7746c79 100644 --- a/tutorials/multimodal/multimodal.json +++ b/tutorials/multimodal/multimodal.json @@ -76,19 +76,18 @@ "model_type": "causal", "transformers_settings": {}, "peft_settings": { - "r": 16, - "lora_alpha": 16, - "lora_dropout": 0.05, - "target_modules": [ - "q_proj", - "k_proj" - ], - "task_type": "CAUSAL_LM", - "modules_to_save": [ - "embed_tokens", - "lm_head" - ], - "name": "LORA" + "name": "LORA", + "config": { + "r": 8, + "lora_alpha": 16, + "lora_dropout": 0.05, + "target_modules": [ + "q_proj", + "v_proj" + ], + "task_type": "CAUSAL_LM", + "modules_to_save": ["embed_tokens", "lm_head"] + } }, "embeddings_initialization_strategy": { "": "bot",