diff --git a/tests/cli/test_classification.py b/tests/cli/test_classification.py
index 8de6b4b..115e475 100755
--- a/tests/cli/test_classification.py
+++ b/tests/cli/test_classification.py
@@ -1,26 +1,24 @@
from pathlib import Path
+import pytest
from typer.testing import CliRunner
from tests.constants import FIXTURES_PATH
from turbo_alignment.cli import app
+from turbo_alignment.settings.pipelines.train.classification import (
+ ClassificationTrainExperimentSettings,
+)
runner = CliRunner()
-def test_classification_train():
- result = runner.invoke(
- app,
- [
- 'train_classification',
- '--experiment_settings_path',
- FIXTURES_PATH / 'configs/train/classification/base.json',
- ],
- catch_exceptions=False,
- )
+@pytest.mark.parametrize(
+ 'config_path',
+ [
+ FIXTURES_PATH / 'configs/train/classification/base.json',
+ ],
+)
+def test_classification_train(config_path: Path):
+ result = runner.invoke(app, ['train_classification', '--experiment_settings_path', str(config_path)])
assert result.exit_code == 0
- assert Path('test_train_classification_output').is_dir()
-
-
-if __name__ == '__main__':
- test_classification_train()
+ assert ClassificationTrainExperimentSettings.parse_file(config_path).log_path.is_dir()
diff --git a/tests/cli/test_dpo_train.py b/tests/cli/test_dpo_train.py
index f652a64..23bd4af 100755
--- a/tests/cli/test_dpo_train.py
+++ b/tests/cli/test_dpo_train.py
@@ -1,26 +1,26 @@
-# from pathlib import Path
+from pathlib import Path
-# import pytest
-# from typer.testing import CliRunner
+import pytest
+from typer.testing import CliRunner
-# from tests.constants import FIXTURES_PATH
-# from turbo_alignment.cli import app
-# from turbo_alignment.settings.pipelines.train.dpo import DPOTrainExperimentSettings
+from tests.constants import FIXTURES_PATH
+from turbo_alignment.cli import app
+from turbo_alignment.settings.pipelines.train.dpo import DPOTrainExperimentSettings
-# runner = CliRunner()
+runner = CliRunner()
-# @pytest.mark.parametrize(
-# 'config_path',
-# [
-# FIXTURES_PATH / 'configs/train/dpo/base.json',
-# FIXTURES_PATH / 'configs/train/dpo/simpo.json',
-# ],
-# )
-# def test_dpo_train(config_path: Path):
-# result = runner.invoke(
-# app,
-# ['train_dpo', '--experiment_settings_path', str(config_path)],
-# )
-# assert result.exit_code == 0
-# assert DPOTrainExperimentSettings.parse_file(config_path).log_path.is_dir()
+@pytest.mark.parametrize(
+ 'config_path',
+ [
+ FIXTURES_PATH / 'configs/train/dpo/base.json',
+ FIXTURES_PATH / 'configs/train/dpo/simpo.json',
+ ],
+)
+def test_dpo_train(config_path: Path):
+ result = runner.invoke(
+ app,
+ ['train_dpo', '--experiment_settings_path', str(config_path)],
+ )
+ assert result.exit_code == 0
+ assert DPOTrainExperimentSettings.parse_file(config_path).log_path.is_dir()
diff --git a/tests/cli/test_kto_train.py b/tests/cli/test_kto_train.py
index b7ddb7d..18cb27d 100755
--- a/tests/cli/test_kto_train.py
+++ b/tests/cli/test_kto_train.py
@@ -1,23 +1,23 @@
-# from pathlib import Path
+from pathlib import Path
-# import pytest
-# from typer.testing import CliRunner
+import pytest
+from typer.testing import CliRunner
-# from tests.constants import FIXTURES_PATH
-# from turbo_alignment.cli import app
-# from turbo_alignment.settings.pipelines.train.kto import KTOTrainExperimentSettings
+from tests.constants import FIXTURES_PATH
+from turbo_alignment.cli import app
+from turbo_alignment.settings.pipelines.train.kto import KTOTrainExperimentSettings
-# runner = CliRunner()
+runner = CliRunner()
-# @pytest.mark.parametrize(
-# 'config_path',
-# [FIXTURES_PATH / 'configs/train/kto/base.json'],
-# )
-# def test_dpo_train(config_path: Path):
-# result = runner.invoke(
-# app,
-# ['train_kto', '--experiment_settings_path', str(config_path)],
-# )
-# assert result.exit_code == 0
-# assert KTOTrainExperimentSettings.parse_file(config_path).log_path.is_dir()
+@pytest.mark.parametrize(
+ 'config_path',
+ [FIXTURES_PATH / 'configs/train/kto/base.json'],
+)
+def test_kto_train(config_path: Path):
+ result = runner.invoke(
+ app,
+ ['train_kto', '--experiment_settings_path', str(config_path)],
+ )
+ assert result.exit_code == 0
+ assert KTOTrainExperimentSettings.parse_file(config_path).log_path.is_dir()
diff --git a/tests/conftest.py b/tests/conftest.py
index da98770..371a7f6 100755
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,96 +1,96 @@
-# import json
+import json
-# from pytest import fixture
-# from transformers import AutoTokenizer
+from pytest import fixture
+from transformers import AutoTokenizer
-# from turbo_alignment.dataset.registry import DatasetRegistry
-# from turbo_alignment.settings.datasets.base import (
-# DatasetSourceSettings,
-# DatasetStrategy,
-# DatasetType,
-# )
-# from turbo_alignment.settings.datasets.chat import (
-# ChatDatasetSettings,
-# ChatPromptTemplate,
-# )
-# from turbo_alignment.settings.datasets.pair_preference import (
-# PairPreferenceDatasetSettings,
-# )
+from turbo_alignment.dataset.registry import DatasetRegistry
+from turbo_alignment.settings.datasets.base import (
+ DatasetSourceSettings,
+ DatasetStrategy,
+ DatasetType,
+)
+from turbo_alignment.settings.datasets.chat import (
+ ChatDatasetSettings,
+ ChatPromptTemplate,
+)
+from turbo_alignment.settings.datasets.pair_preference import (
+ PairPreferenceDatasetSettings,
+)
-# @fixture(scope='session')
-# def tokenizer_llama2():
-# tokenizer_path = 'tests/fixtures/models/llama2_classification/tokenizer'
-# tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
-# return tokenizer
+@fixture(scope='session')
+def tokenizer_llama2():
+ tokenizer_path = 'tests/fixtures/models/llama2_classification/tokenizer'
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
+ return tokenizer
-# @fixture(scope='session')
-# def tokenizer_gptj():
-# tokenizer_path = 'tests/fixtures/models/gptj_tiny_for_seq_cls'
-# tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
-# return tokenizer
+@fixture(scope='session')
+def tokenizer_gptj():
+ tokenizer_path = 'tests/fixtures/models/gptj_tiny_for_seq_cls'
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
+ return tokenizer
-# @fixture(scope='session')
-# def chat_dataset_settings():
-# chat_dataset_settings = ChatDatasetSettings(
-# prompt_template=ChatPromptTemplate(
-# prefix_template='',
-# suffix_template='',
-# role_tag_mapping={'user': 'USER', 'bot': 'BOT', 'system': 'SYSTEM'},
-# ),
-# max_tokens_count=8256,
-# )
-# return chat_dataset_settings
+@fixture(scope='session')
+def chat_dataset_settings():
+ chat_dataset_settings = ChatDatasetSettings(
+ prompt_template=ChatPromptTemplate(
+ prefix_template='',
+ suffix_template='',
+ role_tag_mapping={'user': 'USER', 'bot': 'BOT', 'system': 'SYSTEM'},
+ ),
+ max_tokens_count=8256,
+ )
+ return chat_dataset_settings
-# @fixture(scope='session')
-# def classification_dataset_path() -> str:
-# return 'tests/fixtures/datasets/classification/train_classification.jsonl'
+@fixture(scope='session')
+def classification_dataset_path() -> str:
+ return 'tests/fixtures/datasets/classification/train_classification.jsonl'
-# @fixture(scope='session')
-# def pair_preferences_dataset_path() -> str:
-# return 'tests/fixtures/datasets/rm/train_preferences.jsonl'
+@fixture(scope='session')
+def pair_preferences_dataset_path() -> str:
+ return 'tests/fixtures/datasets/rm/train_preferences.jsonl'
-# @fixture(scope='session')
-# def kto_dataset_path() -> str:
-# return 'tests/fixtures/datasets/rm/train_kto.jsonl'
+@fixture(scope='session')
+def kto_dataset_path() -> str:
+ return 'tests/fixtures/datasets/rm/train_kto.jsonl'
-# def load_dataset_source(dataset_path: str) -> tuple[DatasetSourceSettings, list[dict]]:
-# with open(dataset_path, 'r', encoding='utf-8') as f:
-# data_dicts = [json.loads(line) for line in f]
+def load_dataset_source(dataset_path: str) -> tuple[DatasetSourceSettings, list[dict]]:
+ with open(dataset_path, 'r', encoding='utf-8') as f:
+ data_dicts = [json.loads(line) for line in f]
-# source = DatasetSourceSettings(name='dataset_for_test', records_path=dataset_path, sample_rate=1)
+ source = DatasetSourceSettings(name='dataset_for_test', records_path=dataset_path, sample_rate=1)
-# return source, data_dicts
+ return source, data_dicts
-# @fixture(scope='session')
-# def classification_dataset_source(classification_dataset_path) -> tuple[DatasetSourceSettings, list[dict]]:
-# return load_dataset_source(classification_dataset_path)
+@fixture(scope='session')
+def classification_dataset_source(classification_dataset_path) -> tuple[DatasetSourceSettings, list[dict]]:
+ return load_dataset_source(classification_dataset_path)
-# @fixture(scope='session')
-# def pair_preferences_dataset_source(pair_preferences_dataset_path) -> tuple[DatasetSourceSettings, list[dict]]:
-# return load_dataset_source(pair_preferences_dataset_path)
+@fixture(scope='session')
+def pair_preferences_dataset_source(pair_preferences_dataset_path) -> tuple[DatasetSourceSettings, list[dict]]:
+ return load_dataset_source(pair_preferences_dataset_path)
-# @fixture(scope='session')
-# def kto_dataset_source(kto_dataset_path) -> tuple[DatasetSourceSettings, list[dict]]:
-# return load_dataset_source(kto_dataset_path)
+@fixture(scope='session')
+def kto_dataset_source(kto_dataset_path) -> tuple[DatasetSourceSettings, list[dict]]:
+ return load_dataset_source(kto_dataset_path)
-# @fixture(scope='session')
-# def dpo_dataset(pair_preferences_dataset_source, tokenizer_llama2, chat_dataset_settings):
-# source, _ = pair_preferences_dataset_source
+@fixture(scope='session')
+def dpo_dataset(pair_preferences_dataset_source, tokenizer_llama2, chat_dataset_settings):
+ source, _ = pair_preferences_dataset_source
-# dataset_cls = DatasetRegistry.by_name(DatasetType.PAIR_PREFERENCES).by_name(DatasetStrategy.TRAIN)
+ dataset_cls = DatasetRegistry.by_name(DatasetType.PAIR_PREFERENCES).by_name(DatasetStrategy.TRAIN)
-# dataset_settings = PairPreferenceDatasetSettings(chat_settings=chat_dataset_settings)
-# dataset = dataset_cls(tokenizer=tokenizer_llama2, source=source, settings=dataset_settings)
+ dataset_settings = PairPreferenceDatasetSettings(chat_settings=chat_dataset_settings)
+ dataset = dataset_cls(tokenizer=tokenizer_llama2, source=source, settings=dataset_settings)
-# return dataset
+ return dataset
diff --git a/tests/fixtures/configs/train/classification/base.json b/tests/fixtures/configs/train/classification/base.json
index 960188b..eb396fe 100755
--- a/tests/fixtures/configs/train/classification/base.json
+++ b/tests/fixtures/configs/train/classification/base.json
@@ -79,11 +79,18 @@
"problem_type": "single_label_classification"
},
"peft_settings": {
- "r": 8,
- "lora_alpha": 16,
- "lora_dropout": 0.05,
- "target_modules": ["q_proj", "v_proj"],
- "task_type": "SEQ_CLS"
+ "name": "LORA",
+ "config": {
+ "r": 8,
+ "lora_alpha": 16,
+ "lora_dropout": 0.05,
+ "target_modules": [
+ "q_proj",
+ "v_proj"
+ ],
+      "task_type": "SEQ_CLS",
+ "modules_to_save": ["embed_tokens", "lm_head"]
+ }
}
},
"tokenizer_settings": {},
diff --git a/tests/fixtures/configs/train/multimodal/llama_c_abs_clip_pickle.json b/tests/fixtures/configs/train/multimodal/llama_c_abs_clip_pickle.json
index ffb2b13..6a03d75 100644
--- a/tests/fixtures/configs/train/multimodal/llama_c_abs_clip_pickle.json
+++ b/tests/fixtures/configs/train/multimodal/llama_c_abs_clip_pickle.json
@@ -76,19 +76,18 @@
"model_type": "causal",
"transformers_settings": {},
"peft_settings": {
- "r": 16,
- "lora_alpha": 16,
- "lora_dropout": 0.05,
- "target_modules": [
- "q_proj",
- "k_proj"
- ],
- "task_type": "CAUSAL_LM",
- "modules_to_save": [
- "embed_tokens",
- "lm_head"
- ],
- "name": "LORA"
+ "name": "LORA",
+ "config": {
+ "r": 8,
+ "lora_alpha": 16,
+ "lora_dropout": 0.05,
+ "target_modules": [
+ "q_proj",
+ "v_proj"
+ ],
+ "task_type": "CAUSAL_LM",
+ "modules_to_save": ["embed_tokens", "lm_head"]
+ }
},
"embeddings_initialization_strategy": {
"": "bot",
diff --git a/tests/fixtures/configs/train/multimodal/llama_llava_base_clip.json b/tests/fixtures/configs/train/multimodal/llama_llava_base_clip.json
index ef90914..cd1b153 100644
--- a/tests/fixtures/configs/train/multimodal/llama_llava_base_clip.json
+++ b/tests/fixtures/configs/train/multimodal/llama_llava_base_clip.json
@@ -76,19 +76,18 @@
"model_type": "causal",
"transformers_settings": {},
"peft_settings": {
- "r": 16,
- "lora_alpha": 16,
- "lora_dropout": 0.05,
- "target_modules": [
- "q_proj",
- "k_proj"
- ],
- "task_type": "CAUSAL_LM",
- "modules_to_save": [
- "embed_tokens",
- "lm_head"
- ],
- "name": "LORA"
+ "name": "LORA",
+ "config": {
+ "r": 8,
+ "lora_alpha": 16,
+ "lora_dropout": 0.05,
+ "target_modules": [
+ "q_proj",
+ "v_proj"
+ ],
+ "task_type": "CAUSAL_LM",
+ "modules_to_save": ["embed_tokens", "lm_head"]
+ }
},
"embeddings_initialization_strategy": {
"": "bot",
diff --git a/tests/fixtures/configs/train/multimodal/llama_llava_clip_pickle.json b/tests/fixtures/configs/train/multimodal/llama_llava_clip_pickle.json
index eae875a..1dc8f63 100644
--- a/tests/fixtures/configs/train/multimodal/llama_llava_clip_pickle.json
+++ b/tests/fixtures/configs/train/multimodal/llama_llava_clip_pickle.json
@@ -76,19 +76,18 @@
"model_type": "causal",
"transformers_settings": {},
"peft_settings": {
- "r": 16,
- "lora_alpha": 16,
- "lora_dropout": 0.05,
- "target_modules": [
- "q_proj",
- "k_proj"
- ],
- "task_type": "CAUSAL_LM",
- "modules_to_save": [
- "embed_tokens",
- "lm_head"
- ],
- "name": "LORA"
+ "name": "LORA",
+ "config": {
+ "r": 8,
+ "lora_alpha": 16,
+ "lora_dropout": 0.05,
+ "target_modules": [
+ "q_proj",
+ "v_proj"
+ ],
+ "task_type": "CAUSAL_LM",
+ "modules_to_save": ["embed_tokens", "lm_head"]
+ }
},
"embeddings_initialization_strategy": {
"": "bot",
diff --git a/tests/fixtures/configs/train/rag/base.json b/tests/fixtures/configs/train/rag/base.json
index b64b826..d5f1e4c 100755
--- a/tests/fixtures/configs/train/rag/base.json
+++ b/tests/fixtures/configs/train/rag/base.json
@@ -54,16 +54,18 @@
"": "system"
},
"peft_settings": {
- "r": 16,
- "lora_alpha": 16,
- "lora_dropout": 0.05,
- "target_modules": ["q_proj", "v_proj", "k_proj", "o_proj"],
- "task_type": "CAUSAL_LM",
- "modules_to_save": [
- "embed_tokens",
- "lm_head"
- ],
- "name": "LORA"
+ "name": "LORA",
+ "config": {
+ "r": 8,
+ "lora_alpha": 16,
+ "lora_dropout": 0.05,
+ "target_modules": [
+ "q_proj",
+ "v_proj"
+ ],
+ "task_type": "CAUSAL_LM",
+ "modules_to_save": ["embed_tokens", "lm_head"]
+ }
}
},
"question_encoder_settings": {
diff --git a/tests/fixtures/configs/train/rm/base.json b/tests/fixtures/configs/train/rm/base.json
index c2dd749..c9a3a6b 100755
--- a/tests/fixtures/configs/train/rm/base.json
+++ b/tests/fixtures/configs/train/rm/base.json
@@ -81,11 +81,18 @@
"return_dict": true
},
"peft_settings": {
- "r": 8,
- "lora_alpha": 16,
- "lora_dropout": 0.05,
- "target_modules": ["q_proj", "v_proj"],
- "task_type": "SEQ_CLS"
+ "name": "LORA",
+ "config": {
+ "r": 8,
+ "lora_alpha": 16,
+ "lora_dropout": 0.05,
+ "target_modules": [
+ "q_proj",
+ "v_proj"
+ ],
+ "task_type": "SEQ_CLS",
+ "modules_to_save": ["embed_tokens", "lm_head"]
+ }
}
},
"tokenizer_settings": {},
diff --git a/tests/fixtures/configs/train/sft/base.json b/tests/fixtures/configs/train/sft/base.json
index 99c1ac0..c83f08b 100755
--- a/tests/fixtures/configs/train/sft/base.json
+++ b/tests/fixtures/configs/train/sft/base.json
@@ -47,16 +47,18 @@
"transformers_settings": {
},
"peft_settings": {
- "r": 8,
- "lora_alpha": 16,
- "lora_dropout": 0.05,
- "target_modules": [
- "q_proj",
- "v_proj"
- ],
- "task_type": "CAUSAL_LM",
- "modules_to_save": ["embed_tokens", "lm_head"],
- "name": "LORA"
+ "name": "LORA",
+ "config": {
+ "r": 8,
+ "lora_alpha": 16,
+ "lora_dropout": 0.05,
+ "target_modules": [
+ "q_proj",
+ "v_proj"
+ ],
+ "task_type": "CAUSAL_LM",
+ "modules_to_save": ["embed_tokens", "lm_head"]
+ }
},
"embeddings_initialization_strategy": {
"": "",
diff --git a/tests/fixtures/configs/train/sft/prompt_tuning.json b/tests/fixtures/configs/train/sft/prompt_tuning.json
index 920d1f3..fc58e42 100755
--- a/tests/fixtures/configs/train/sft/prompt_tuning.json
+++ b/tests/fixtures/configs/train/sft/prompt_tuning.json
@@ -46,9 +46,11 @@
"model_type": "causal",
"transformers_settings": {},
"peft_settings": {
- "task_type": "CAUSAL_LM",
"name": "PROMPT_TUNING",
- "num_virtual_tokens": 32
+ "config": {
+ "task_type": "CAUSAL_LM",
+ "num_virtual_tokens": 32
+ }
},
"embeddings_initialization_strategy": {
"": "",
diff --git a/tests/fixtures/configs/train/sft/sft_with_rm_metric.json b/tests/fixtures/configs/train/sft/sft_with_rm_metric.json
index 8807122..ec6526c 100755
--- a/tests/fixtures/configs/train/sft/sft_with_rm_metric.json
+++ b/tests/fixtures/configs/train/sft/sft_with_rm_metric.json
@@ -47,16 +47,18 @@
"transformers_settings": {
},
"peft_settings": {
- "r": 8,
- "lora_alpha": 16,
- "lora_dropout": 0.05,
- "target_modules": [
- "q_proj",
- "v_proj"
- ],
- "task_type": "CAUSAL_LM",
- "modules_to_save": ["embed_tokens", "lm_head"],
- "name": "LORA"
+ "name": "LORA",
+ "config": {
+ "r": 8,
+ "lora_alpha": 16,
+ "lora_dropout": 0.05,
+ "target_modules": [
+ "q_proj",
+ "v_proj"
+ ],
+ "task_type": "CAUSAL_LM",
+ "modules_to_save": ["embed_tokens", "lm_head"]
+ }
},
"embeddings_initialization_strategy": {
"": "",
diff --git a/tests/integration/test_trainers.py b/tests/integration/test_trainers.py
index 56941bb..97f8aa8 100644
--- a/tests/integration/test_trainers.py
+++ b/tests/integration/test_trainers.py
@@ -1,69 +1,69 @@
-# from tempfile import TemporaryDirectory
+from tempfile import TemporaryDirectory
-# import torch
-# from transformers import AutoModelForCausalLM, AutoTokenizer
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
-# from turbo_alignment.common.tf.callbacks.sync_ref_model import SyncRefModelSettings
-# from turbo_alignment.dataset.pair_preferences import PairPreferenceDataCollator
-# from turbo_alignment.settings.pipelines.train.dpo import (
-# DPOLossesType,
-# SigmoidLossSettings,
-# )
-# from turbo_alignment.trainers.dpo import DPOTrainer, DPOTrainingArguments
+from turbo_alignment.common.tf.callbacks.sync_ref_model import SyncRefModelSettings
+from turbo_alignment.dataset.pair_preferences import PairPreferenceDataCollator
+from turbo_alignment.settings.pipelines.train.utils import (
+ DPOLossesType,
+ SigmoidLossSettings,
+)
+from turbo_alignment.trainers.dpo import DPOTrainer, DPOTrainingArguments
-# def test_dpo_trainer(dpo_dataset):
-# model_path = 'tests/fixtures/models/llama2_tiny'
+def test_dpo_trainer(dpo_dataset):
+ model_path = 'tests/fixtures/models/llama2_tiny'
-# model = AutoModelForCausalLM.from_pretrained(model_path)
+ model = AutoModelForCausalLM.from_pretrained(model_path)
-# ref_model = AutoModelForCausalLM.from_pretrained(model_path)
+ ref_model = AutoModelForCausalLM.from_pretrained(model_path)
-# with TemporaryDirectory() as tmp_dir:
-# args = DPOTrainingArguments(
-# do_train=True,
-# loss_settings=SigmoidLossSettings(loss_type=DPOLossesType.SIGMOID).dict(),
-# sync_ref_settings=SyncRefModelSettings().dict(),
-# do_eval=False,
-# learning_rate=1.0e-4,
-# use_cpu=True,
-# num_train_epochs=10,
-# report_to=[],
-# remove_unused_columns=False,
-# output_dir=tmp_dir,
-# )
+ with TemporaryDirectory() as tmp_dir:
+ args = DPOTrainingArguments(
+ do_train=True,
+ loss_settings=SigmoidLossSettings(loss_type=DPOLossesType.SIGMOID).dict(),
+ sync_ref_settings=SyncRefModelSettings().dict(),
+ do_eval=False,
+ learning_rate=1.0e-4,
+ use_cpu=True,
+ num_train_epochs=10,
+ report_to=[],
+ remove_unused_columns=False,
+ output_dir=tmp_dir,
+ )
-# tokenizer = AutoTokenizer.from_pretrained(model_path)
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
-# data_collator = PairPreferenceDataCollator(tokenizer=tokenizer)
+ data_collator = PairPreferenceDataCollator(tokenizer=tokenizer)
-# trainer = DPOTrainer(
-# model=model,
-# ref_model=ref_model,
-# args=args,
-# train_dataset=dpo_dataset,
-# eval_dataset=dpo_dataset,
-# data_collator=data_collator,
-# )
+ trainer = DPOTrainer(
+ model=model,
+ ref_model=ref_model,
+ args=args,
+ train_dataset=dpo_dataset,
+ eval_dataset=dpo_dataset,
+ data_collator=data_collator,
+ )
-# batch = data_collator(list(dpo_dataset))
+ batch = data_collator(list(dpo_dataset))
-# loss_before, _ = trainer.get_batch_metrics(model, batch, 'train')
-# trainer.train()
-# loss_after, _ = trainer.get_batch_metrics(trainer.model, batch, 'train')
+ loss_before, _ = trainer.get_batch_metrics(model, batch, 'train')
+ trainer.train()
+ loss_after, _ = trainer.get_batch_metrics(trainer.model, batch, 'train')
-# assert torch.greater(loss_before, loss_after)
+ assert torch.greater(loss_before, loss_after)
-# initial_model = AutoModelForCausalLM.from_pretrained(model_path)
+ initial_model = AutoModelForCausalLM.from_pretrained(model_path)
-# trainer.save_model(tmp_dir)
-# trained_model = AutoModelForCausalLM.from_pretrained(tmp_dir)
+ trainer.save_model(tmp_dir)
+ trained_model = AutoModelForCausalLM.from_pretrained(tmp_dir)
-# initial_state_dict = initial_model.state_dict()
-# trained_state_dict = trained_model.state_dict()
-# for k, v in trained_state_dict.items():
-# assert any(torch.not_equal(v, initial_state_dict[k]).tolist())
+ initial_state_dict = initial_model.state_dict()
+ trained_state_dict = trained_model.state_dict()
+ for k, v in trained_state_dict.items():
+ assert any(torch.not_equal(v, initial_state_dict[k]).tolist())
-# ref_model_state_dict = trainer.ref_model.state_dict()
-# for k, v in ref_model_state_dict.items():
-# assert all(torch.eq(v, initial_state_dict[k]).tolist())
+ ref_model_state_dict = trainer.ref_model.state_dict()
+ for k, v in ref_model_state_dict.items():
+ assert all(torch.eq(v, initial_state_dict[k]).tolist())
diff --git a/tests/unit/test_collators.py b/tests/unit/test_collators.py
index 5231b40..150436a 100644
--- a/tests/unit/test_collators.py
+++ b/tests/unit/test_collators.py
@@ -1,28 +1,28 @@
-# from tests.utils import is_sample_build_from_content
-# from turbo_alignment.dataset.kto.collators import KTODataCollator
-# from turbo_alignment.dataset.registry import DatasetRegistry
-# from turbo_alignment.settings.datasets.base import DatasetStrategy, DatasetType
-# from turbo_alignment.settings.datasets.kto import KTODatasetSettings
+from tests.utils import is_sample_build_from_content
+from turbo_alignment.dataset.kto.collators import KTODataCollator
+from turbo_alignment.dataset.registry import DatasetRegistry
+from turbo_alignment.settings.datasets.base import DatasetStrategy, DatasetType
+from turbo_alignment.settings.datasets.kto import KTODatasetSettings
-# def test_kto_collator(tokenizer_llama2, chat_dataset_settings, kto_dataset_source):
-# tokenizer = tokenizer_llama2
-# source, data_dicts = kto_dataset_source
+def test_kto_collator(tokenizer_llama2, chat_dataset_settings, kto_dataset_source):
+ tokenizer = tokenizer_llama2
+ source, data_dicts = kto_dataset_source
-# dataset_cls = DatasetRegistry.by_name(DatasetType.KTO).by_name(DatasetStrategy.TRAIN)
+ dataset_cls = DatasetRegistry.by_name(DatasetType.KTO).by_name(DatasetStrategy.TRAIN)
-# dataset_settings = KTODatasetSettings(chat_settings=chat_dataset_settings)
-# dataset = dataset_cls(tokenizer=tokenizer, source=source, settings=dataset_settings)
+ dataset_settings = KTODatasetSettings(chat_settings=chat_dataset_settings)
+ dataset = dataset_cls(tokenizer=tokenizer, source=source, settings=dataset_settings)
-# batch_size = min(len(dataset), 8)
-# examples = list(dataset)[:batch_size]
+ batch_size = min(len(dataset), 8)
+ examples = list(dataset)[:batch_size]
-# collator = KTODataCollator(tokenizer=tokenizer)
-# batch = collator(examples)
+ collator = KTODataCollator(tokenizer=tokenizer)
+ batch = collator(examples)
-# ignore_index = -100
-# for answer_labels, is_desirable, raw_data in zip(batch['labels'], batch['is_desirable'], data_dicts):
-# answer = raw_data['answer']['content']
-# answer_tokens = answer_labels[answer_labels != ignore_index]
-# assert is_sample_build_from_content(answer_tokens, [answer], tokenizer)
-# assert raw_data['is_desirable'] == is_desirable
+ ignore_index = -100
+ for answer_labels, is_desirable, raw_data in zip(batch['labels'], batch['is_desirable'], data_dicts):
+ answer = raw_data['answer']['content']
+ answer_tokens = answer_labels[answer_labels != ignore_index]
+ assert is_sample_build_from_content(answer_tokens, [answer], tokenizer)
+ assert raw_data['is_desirable'] == is_desirable
diff --git a/tests/unit/test_datasets.py b/tests/unit/test_datasets.py
index 089417b..1ecd323 100644
--- a/tests/unit/test_datasets.py
+++ b/tests/unit/test_datasets.py
@@ -1,88 +1,88 @@
-# from tests.utils import is_sample_build_from_content
-# from turbo_alignment.dataset.classification.models import ClassificationDatasetRecord
-# from turbo_alignment.dataset.pair_preferences.models import PairPreferenceRecord
-# from turbo_alignment.dataset.registry import DatasetRegistry
-# from turbo_alignment.settings.datasets.base import DatasetStrategy, DatasetType
-# from turbo_alignment.settings.datasets.classification import (
-# ClassificationDatasetSettings,
-# )
-# from turbo_alignment.settings.datasets.ddpo import DDPODatasetSettings
-# from turbo_alignment.settings.datasets.pair_preference import (
-# PairPreferenceDatasetSettings,
-# )
+from tests.utils import is_sample_build_from_content
+from turbo_alignment.dataset.classification.models import ClassificationDatasetRecord
+from turbo_alignment.dataset.pair_preferences.models import PairPreferenceRecord
+from turbo_alignment.dataset.registry import DatasetRegistry
+from turbo_alignment.settings.datasets.base import DatasetStrategy, DatasetType
+from turbo_alignment.settings.datasets.classification import (
+ ClassificationDatasetSettings,
+)
+from turbo_alignment.settings.datasets.ddpo import DDPODatasetSettings
+from turbo_alignment.settings.datasets.pair_preference import (
+ PairPreferenceDatasetSettings,
+)
-# def test_classification(tokenizer_llama2, chat_dataset_settings, classification_dataset_source):
-# # load dataset and check that samples have required fields
+def test_classification(tokenizer_llama2, chat_dataset_settings, classification_dataset_source):
+ # load dataset and check that samples have required fields
-# source, data_dicts = classification_dataset_source
+ source, data_dicts = classification_dataset_source
-# dataset_cls = DatasetRegistry.by_name(DatasetType.CLASSIFICATION).by_name(DatasetStrategy.TRAIN)
+ dataset_cls = DatasetRegistry.by_name(DatasetType.CLASSIFICATION).by_name(DatasetStrategy.TRAIN)
-# dataset_settings = ClassificationDatasetSettings(chat_settings=chat_dataset_settings)
+ dataset_settings = ClassificationDatasetSettings(chat_settings=chat_dataset_settings)
-# dataset = dataset_cls(tokenizer=tokenizer_llama2, source=source, settings=dataset_settings)
+ dataset = dataset_cls(tokenizer=tokenizer_llama2, source=source, settings=dataset_settings)
-# assert len(data_dicts) == len(dataset)
+ assert len(data_dicts) == len(dataset)
-# for data_dict, sample in zip(data_dicts, dataset):
-# record = ClassificationDatasetRecord.model_validate(data_dict)
+ for data_dict, sample in zip(data_dicts, dataset):
+ record = ClassificationDatasetRecord.model_validate(data_dict)
-# assert record.label == sample['labels']
+ assert record.label == sample['labels']
-# assert is_sample_build_from_content(
-# sample['input_ids'], [m.content for m in record.messages], tokenizer_llama2
-# )
+ assert is_sample_build_from_content(
+ sample['input_ids'], [m.content for m in record.messages], tokenizer_llama2
+ )
-# def test_pair_preferences(tokenizer_llama2, chat_dataset_settings, pair_preferences_dataset_source):
-# # load dataset and check that samples have required fields
+def test_pair_preferences(tokenizer_llama2, chat_dataset_settings, pair_preferences_dataset_source):
+ # load dataset and check that samples have required fields
-# source, data_dicts = pair_preferences_dataset_source
+ source, data_dicts = pair_preferences_dataset_source
-# dataset_cls = DatasetRegistry.by_name(DatasetType.PAIR_PREFERENCES).by_name(DatasetStrategy.TRAIN)
+ dataset_cls = DatasetRegistry.by_name(DatasetType.PAIR_PREFERENCES).by_name(DatasetStrategy.TRAIN)
-# dataset_settings = PairPreferenceDatasetSettings(chat_settings=chat_dataset_settings)
-# dataset = dataset_cls(tokenizer=tokenizer_llama2, source=source, settings=dataset_settings)
+ dataset_settings = PairPreferenceDatasetSettings(chat_settings=chat_dataset_settings)
+ dataset = dataset_cls(tokenizer=tokenizer_llama2, source=source, settings=dataset_settings)
-# assert len(data_dicts) == len(dataset)
+ assert len(data_dicts) == len(dataset)
-# for data_dict, sample in zip(data_dicts, dataset):
-# record = PairPreferenceRecord.model_validate(data_dict)
-# context: list[str] = [c.content for c in record.context]
-# contents_w = [*context, record.answer_w.content]
-# assert is_sample_build_from_content(sample['inputs_w']['input_ids'], contents_w, tokenizer_llama2)
+ for data_dict, sample in zip(data_dicts, dataset):
+ record = PairPreferenceRecord.model_validate(data_dict)
+ context: list[str] = [c.content for c in record.context]
+ contents_w = [*context, record.answer_w.content]
+ assert is_sample_build_from_content(sample['inputs_w']['input_ids'], contents_w, tokenizer_llama2)
-# contents_l = [*context, record.answer_l.content]
-# assert is_sample_build_from_content(sample['inputs_l']['input_ids'], contents_l, tokenizer_llama2)
+ contents_l = [*context, record.answer_l.content]
+ assert is_sample_build_from_content(sample['inputs_l']['input_ids'], contents_l, tokenizer_llama2)
-# def test_ddpo(tokenizer_llama2, tokenizer_gptj, chat_dataset_settings, pair_preferences_dataset_source):
-# sft_tokenizer = tokenizer_llama2
-# rm_tokenizer = tokenizer_gptj
-# # load dataset and check that samples have required fields
+def test_ddpo(tokenizer_llama2, tokenizer_gptj, chat_dataset_settings, pair_preferences_dataset_source):
+ sft_tokenizer = tokenizer_llama2
+ rm_tokenizer = tokenizer_gptj
+ # load dataset and check that samples have required fields
-# source, data_dicts = pair_preferences_dataset_source
+ source, data_dicts = pair_preferences_dataset_source
-# dataset_cls = DatasetRegistry.by_name(DatasetType.DDPO).by_name(DatasetStrategy.TRAIN)
+ dataset_cls = DatasetRegistry.by_name(DatasetType.DDPO).by_name(DatasetStrategy.TRAIN)
-# pair_preferences_dataset_settings = PairPreferenceDatasetSettings(chat_settings=chat_dataset_settings)
-# dataset_settings = DDPODatasetSettings(
-# chat_settings=chat_dataset_settings, pair_preferences=pair_preferences_dataset_settings
-# )
-# dataset = dataset_cls(
-# chat_tokenizer=sft_tokenizer, rm_tokenizer=rm_tokenizer, source=source, settings=dataset_settings
-# )
+ pair_preferences_dataset_settings = PairPreferenceDatasetSettings(chat_settings=chat_dataset_settings)
+ dataset_settings = DDPODatasetSettings(
+ chat_settings=chat_dataset_settings, pair_preferences=pair_preferences_dataset_settings
+ )
+ dataset = dataset_cls(
+ chat_tokenizer=sft_tokenizer, rm_tokenizer=rm_tokenizer, source=source, settings=dataset_settings
+ )
-# assert len(data_dicts) == len(dataset)
+ assert len(data_dicts) == len(dataset)
-# for data_dict, sample in zip(data_dicts, dataset):
-# record = PairPreferenceRecord.model_validate(data_dict)
-# context: list[str] = [c.content for c in record.context]
-# contents_w = [*context, record.answer_w.content]
-# assert is_sample_build_from_content(sample['sft_inputs_w']['input_ids'], contents_w, sft_tokenizer)
-# assert is_sample_build_from_content(sample['rm_inputs_w']['input_ids'], contents_w, rm_tokenizer)
+ for data_dict, sample in zip(data_dicts, dataset):
+ record = PairPreferenceRecord.model_validate(data_dict)
+ context: list[str] = [c.content for c in record.context]
+ contents_w = [*context, record.answer_w.content]
+ assert is_sample_build_from_content(sample['sft_inputs_w']['input_ids'], contents_w, sft_tokenizer)
+ assert is_sample_build_from_content(sample['rm_inputs_w']['input_ids'], contents_w, rm_tokenizer)
-# contents_l = [*context, record.answer_l.content]
-# assert is_sample_build_from_content(sample['sft_inputs_l']['input_ids'], contents_l, sft_tokenizer)
-# assert is_sample_build_from_content(sample['rm_inputs_l']['input_ids'], contents_l, rm_tokenizer)
+ contents_l = [*context, record.answer_l.content]
+ assert is_sample_build_from_content(sample['sft_inputs_l']['input_ids'], contents_l, sft_tokenizer)
+ assert is_sample_build_from_content(sample['rm_inputs_l']['input_ids'], contents_l, rm_tokenizer)
diff --git a/turbo_alignment/common/tf/loaders/model/model.py b/turbo_alignment/common/tf/loaders/model/model.py
index aba80e9..49a5f11 100755
--- a/turbo_alignment/common/tf/loaders/model/model.py
+++ b/turbo_alignment/common/tf/loaders/model/model.py
@@ -4,7 +4,6 @@
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from turbo_alignment.common.tf.loaders.model.registry import (
- PeftConfigRegistry,
TransformersAutoModelRegistry,
)
from turbo_alignment.modeling.liger_kernels import apply_liger_kernel_to_gemma2
@@ -18,12 +17,7 @@
def _prepare_model_for_peft(model: PreTrainedModel, peft_settings: PEFT_TYPE) -> PeftModel:
- peft_params = peft_settings.dict()
- peft_params.pop('name')
-
- peft_config = PeftConfigRegistry.by_name(peft_settings.name)(**peft_params)
-
- return get_peft_model(model, peft_config)
+ return get_peft_model(model, peft_settings.config)
def _load_pretrained_adapters(
diff --git a/turbo_alignment/common/tf/loaders/model/registry.py b/turbo_alignment/common/tf/loaders/model/registry.py
index e2c5e08..565a61d 100755
--- a/turbo_alignment/common/tf/loaders/model/registry.py
+++ b/turbo_alignment/common/tf/loaders/model/registry.py
@@ -1,10 +1,3 @@
-from peft import (
- LoraConfig,
- PeftType,
- PrefixTuningConfig,
- PromptEncoderConfig,
- PromptTuningConfig,
-)
from transformers import (
AutoModel,
AutoModelForCausalLM,
@@ -19,15 +12,6 @@ class TransformersAutoModelRegistry(Registrable):
...
-class PeftConfigRegistry(Registrable):
- ...
-
-
TransformersAutoModelRegistry.register(ModelType.CAUSAL)(AutoModelForCausalLM)
TransformersAutoModelRegistry.register(ModelType.SEQ_CLS)(AutoModelForSequenceClassification)
TransformersAutoModelRegistry.register(ModelType.ENC)(AutoModel)
-
-PeftConfigRegistry.register(PeftType.LORA)(LoraConfig)
-PeftConfigRegistry.register(PeftType.PREFIX_TUNING)(PrefixTuningConfig)
-PeftConfigRegistry.register(PeftType.PROMPT_TUNING)(PromptTuningConfig)
-PeftConfigRegistry.register(PeftType.P_TUNING)(PromptEncoderConfig)
diff --git a/turbo_alignment/settings/pipelines/train/dpo.py b/turbo_alignment/settings/pipelines/train/dpo.py
index 0270063..795b84b 100755
--- a/turbo_alignment/settings/pipelines/train/dpo.py
+++ b/turbo_alignment/settings/pipelines/train/dpo.py
@@ -17,8 +17,8 @@ class DPOTrainExperimentSettings(BaseTrainExperimentSettings):
cherry_pick_settings: ChatCherryPickSettings
- # training_arguments: DPOTrainingArguments
+ training_arguments: DPOTrainingArguments
- # @field_validator('training_arguments', mode='before')
- # def convert_generation_settings(cls, values: dict[str, Any]) -> DPOTrainingArguments:
- # return DPOTrainingArguments(**values, output_dir=TRAINER_LOGS_FOLDER, report_to=[])
+ @field_validator('training_arguments', mode='before')
+ def create_training_arguments(cls, values: dict[str, Any]) -> DPOTrainingArguments:
+ return DPOTrainingArguments(**values, output_dir=TRAINER_LOGS_FOLDER, report_to=[])
diff --git a/turbo_alignment/settings/pipelines/train/kto.py b/turbo_alignment/settings/pipelines/train/kto.py
index 0d459ef..bdf09f3 100755
--- a/turbo_alignment/settings/pipelines/train/kto.py
+++ b/turbo_alignment/settings/pipelines/train/kto.py
@@ -1,3 +1,8 @@
+from typing import Any
+
+from pydantic import field_validator
+
+from turbo_alignment.constants import TRAINER_LOGS_FOLDER
from turbo_alignment.settings.cherry_pick import ChatCherryPickSettings
from turbo_alignment.settings.datasets.kto import KTOMultiDatasetSettings
from turbo_alignment.settings.pipelines.train.base import BaseTrainExperimentSettings
@@ -10,4 +15,8 @@ class KTOTrainExperimentSettings(BaseTrainExperimentSettings):
cherry_pick_settings: ChatCherryPickSettings
- # training_arguments: KTOTrainingArguments
+ training_arguments: KTOTrainingArguments
+
+ @field_validator('training_arguments', mode='before')
+ def create_training_arguments(cls, values: dict[str, Any]) -> KTOTrainingArguments:
+ return KTOTrainingArguments(**values, output_dir=TRAINER_LOGS_FOLDER, report_to=[])
diff --git a/turbo_alignment/settings/s3.py b/turbo_alignment/settings/s3.py
index 2e964fc..22c63ca 100755
--- a/turbo_alignment/settings/s3.py
+++ b/turbo_alignment/settings/s3.py
@@ -1,6 +1,5 @@
from pathlib import Path
-from pydantic import field_validator
from pydantic_settings import BaseSettings
@@ -14,13 +13,6 @@ class S3HandlerParameters(BaseSettings):
aws_access_key_id: str
aws_secret_access_key: str
- @field_validator('bucket')
- def bucket_name_biglm_is_not_allowed(cls, values: str) -> str:
- if values == 'biglm':
- raise S3HandlerParametersWrongBucketException('Usage of biglm bucket is not allowed')
-
- return values
-
class Config:
env_file: str = '.env'
env_prefix: str = 'S3_CHECKPOINTS_'
diff --git a/turbo_alignment/settings/tf/peft.py b/turbo_alignment/settings/tf/peft.py
index f479a27..d8aafd1 100755
--- a/turbo_alignment/settings/tf/peft.py
+++ b/turbo_alignment/settings/tf/peft.py
@@ -1,41 +1,40 @@
from typing import Literal
-from peft import PeftType, TaskType
+from peft import (
+ LoraConfig,
+ PeftConfig,
+ PeftType,
+ PrefixTuningConfig,
+ PromptEncoderConfig,
+ PromptTuningConfig,
+)
from turbo_alignment.settings.base import ExtraFieldsNotAllowedBaseModel
class BasePeftSettings(ExtraFieldsNotAllowedBaseModel):
name: PeftType
- task_type: TaskType = TaskType.CAUSAL_LM
+ config: PeftConfig
class LoraSettings(BasePeftSettings):
name: Literal[PeftType.LORA] = PeftType.LORA
- r: int = 16
- lora_alpha: int = 16
- lora_dropout: float = 0.05
- bias: str = 'none'
- target_modules: list[str] = ['q_proj', 'v_proj']
- modules_to_save: list[str] | None = None
+ config: LoraConfig
class PrefixTuningSettings(BasePeftSettings):
name: Literal[PeftType.PREFIX_TUNING] = PeftType.PREFIX_TUNING
- encoder_hidden_size: int
- prefix_projection: bool
+ config: PrefixTuningConfig
class PromptTuningSettings(BasePeftSettings):
name: Literal[PeftType.PROMPT_TUNING] = PeftType.PROMPT_TUNING
- num_virtual_tokens: int = 32
- prompt_tuning_init_text: str | None = None
+ config: PromptTuningConfig
class PTuningSettings(BasePeftSettings):
name: Literal[PeftType.P_TUNING] = PeftType.P_TUNING
- num_virtual_tokens: int = 32
- encoder_reparameterization_type: str = 'MLP'
+ config: PromptEncoderConfig
PEFT_TYPE = PrefixTuningSettings | LoraSettings | PromptTuningSettings | PTuningSettings
diff --git a/turbo_alignment/trainers/classification.py b/turbo_alignment/trainers/classification.py
index 76a26e4..bfc1500 100755
--- a/turbo_alignment/trainers/classification.py
+++ b/turbo_alignment/trainers/classification.py
@@ -60,7 +60,7 @@ def auto_class_weights(dataset: Dataset) -> list[float]:
class ClassificationTrainer(MultiGPUCherryPicksTrainer):
def __init__(self, **kwargs) -> None:
args = kwargs.get('args')
- self.loss_settings = args.loss_settings
+ self.loss_settings = args.loss_settings # type: ignore[union-attr]
super().__init__(
compute_metrics=compute_clf_metrics,
**kwargs,
diff --git a/turbo_alignment/trainers/dpo.py b/turbo_alignment/trainers/dpo.py
index a69da2e..744b17b 100755
--- a/turbo_alignment/trainers/dpo.py
+++ b/turbo_alignment/trainers/dpo.py
@@ -1,5 +1,5 @@
from collections import defaultdict
-from dataclasses import dataclass, field
+from dataclasses import dataclass
from typing import Any, Callable, Literal
import torch
@@ -414,12 +414,8 @@ class DPOTrainingArguments(TrainingArguments):
| SigmoidLossWithMarginSettings
| APOZeroLossSettings
| APODownLossSettings
- ) = field(
- default_factory=SigmoidLossSettings(loss_type=DPOLossesType.SIGMOID)
- ) # type: ignore[call-overload]
- sync_ref_settings: SyncRefModelSettings = field( # type: ignore[call-overload]
- default_factory=SyncRefModelSettings()
- )
+ ) = SigmoidLossSettings(loss_type=DPOLossesType.SIGMOID)
+ sync_ref_settings: SyncRefModelSettings = SyncRefModelSettings()
use_ref_model: bool = True
use_sft_model: bool = False
average_log_prob: bool = False
diff --git a/turbo_alignment/trainers/kto.py b/turbo_alignment/trainers/kto.py
index 4b35bca..6a11774 100755
--- a/turbo_alignment/trainers/kto.py
+++ b/turbo_alignment/trainers/kto.py
@@ -1,5 +1,5 @@
from collections import defaultdict
-from dataclasses import dataclass, field
+from dataclasses import dataclass
from typing import Any, Callable, Literal
import torch
@@ -30,9 +30,7 @@
@dataclass
class KTOTrainingArguments(TrainingArguments):
beta: float = 0.1
- sync_ref_settings: SyncRefModelSettings = field(
- default_factory=SyncRefModelSettings()
- ) # type: ignore[call-overload]
+ sync_ref_settings: SyncRefModelSettings = SyncRefModelSettings()
use_ref_model: bool = True
average_log_prob: bool = False
undesirable_weight: float = 1.0
diff --git a/tutorials/multimodal/multimodal.json b/tutorials/multimodal/multimodal.json
index 54b7948..7746c79 100644
--- a/tutorials/multimodal/multimodal.json
+++ b/tutorials/multimodal/multimodal.json
@@ -76,19 +76,18 @@
"model_type": "causal",
"transformers_settings": {},
"peft_settings": {
- "r": 16,
- "lora_alpha": 16,
- "lora_dropout": 0.05,
- "target_modules": [
- "q_proj",
- "k_proj"
- ],
- "task_type": "CAUSAL_LM",
- "modules_to_save": [
- "embed_tokens",
- "lm_head"
- ],
- "name": "LORA"
+ "name": "LORA",
+ "config": {
+ "r": 8,
+ "lora_alpha": 16,
+ "lora_dropout": 0.05,
+ "target_modules": [
+ "q_proj",
+ "v_proj"
+ ],
+ "task_type": "CAUSAL_LM",
+ "modules_to_save": ["embed_tokens", "lm_head"]
+ }
},
"embeddings_initialization_strategy": {
"": "bot",