remove training arguments
Малахов Алексей Павлович committed Oct 9, 2024
1 parent 7baf61f commit 00dc48e
Showing 42 changed files with 443 additions and 575 deletions.
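
Most of the hunks below swap a single top-level key in the fixture configs, `trainer_settings` on one side of the diff and `training_arguments` on the other; this text rendering no longer marks which side is the deletion. A minimal sketch, not part of the commit, that reports which of the two keys each fixture config carries in a given checkout (the directory layout and key names are taken from the hunks below):

```python
import json
from pathlib import Path

# Fixture configs touched by this commit live under this directory (see the file list below).
CONFIG_ROOT = Path("tests/fixtures/configs/train")

for path in sorted(CONFIG_ROOT.rglob("*.json")):
    config = json.loads(path.read_text(encoding="utf-8"))
    # Report which of the two competing key names is present in this revision.
    keys = [k for k in ("trainer_settings", "training_arguments") if k in config]
    print(f"{path}: {keys or 'neither key present'}")
```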
42 changes: 21 additions & 21 deletions tests/cli/test_dpo_train.py
@@ -1,26 +1,26 @@
from pathlib import Path
# from pathlib import Path

import pytest
from typer.testing import CliRunner
# import pytest
# from typer.testing import CliRunner

from tests.constants import FIXTURES_PATH
from turbo_alignment.cli import app
from turbo_alignment.settings.pipelines.train.dpo import DPOTrainExperimentSettings
# from tests.constants import FIXTURES_PATH
# from turbo_alignment.cli import app
# from turbo_alignment.settings.pipelines.train.dpo import DPOTrainExperimentSettings

runner = CliRunner()
# runner = CliRunner()


@pytest.mark.parametrize(
'config_path',
[
FIXTURES_PATH / 'configs/train/dpo/base.json',
FIXTURES_PATH / 'configs/train/dpo/simpo.json',
],
)
def test_dpo_train(config_path: Path):
result = runner.invoke(
app,
['train_dpo', '--experiment_settings_path', str(config_path)],
)
assert result.exit_code == 0
assert DPOTrainExperimentSettings.parse_file(config_path).log_path.is_dir()
# @pytest.mark.parametrize(
# 'config_path',
# [
# FIXTURES_PATH / 'configs/train/dpo/base.json',
# FIXTURES_PATH / 'configs/train/dpo/simpo.json',
# ],
# )
# def test_dpo_train(config_path: Path):
# result = runner.invoke(
# app,
# ['train_dpo', '--experiment_settings_path', str(config_path)],
# )
# assert result.exit_code == 0
# assert DPOTrainExperimentSettings.parse_file(config_path).log_path.is_dir()
36 changes: 18 additions & 18 deletions tests/cli/test_kto_train.py
@@ -1,23 +1,23 @@
from pathlib import Path
# from pathlib import Path

import pytest
from typer.testing import CliRunner
# import pytest
# from typer.testing import CliRunner

from tests.constants import FIXTURES_PATH
from turbo_alignment.cli import app
from turbo_alignment.settings.pipelines.train.kto import KTOTrainExperimentSettings
# from tests.constants import FIXTURES_PATH
# from turbo_alignment.cli import app
# from turbo_alignment.settings.pipelines.train.kto import KTOTrainExperimentSettings

runner = CliRunner()
# runner = CliRunner()


@pytest.mark.parametrize(
'config_path',
[FIXTURES_PATH / 'configs/train/kto/base.json'],
)
def test_dpo_train(config_path: Path):
result = runner.invoke(
app,
['train_kto', '--experiment_settings_path', str(config_path)],
)
assert result.exit_code == 0
assert KTOTrainExperimentSettings.parse_file(config_path).log_path.is_dir()
# @pytest.mark.parametrize(
# 'config_path',
# [FIXTURES_PATH / 'configs/train/kto/base.json'],
# )
# def test_dpo_train(config_path: Path):
# result = runner.invoke(
# app,
# ['train_kto', '--experiment_settings_path', str(config_path)],
# )
# assert result.exit_code == 0
# assert KTOTrainExperimentSettings.parse_file(config_path).log_path.is_dir()
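
Both commented-out modules above share one pattern: build a Typer `CliRunner`, invoke the `turbo_alignment.cli` app with a fixture experiment config, and assert a zero exit code plus creation of the configured `log_path`. A self-contained sketch of that pattern, assembled from the code shown above; the `pytest.mark.skip` marker and the function name are illustrative additions, not part of this commit:

```python
from pathlib import Path

import pytest
from typer.testing import CliRunner

from tests.constants import FIXTURES_PATH
from turbo_alignment.cli import app
from turbo_alignment.settings.pipelines.train.dpo import DPOTrainExperimentSettings

runner = CliRunner()


@pytest.mark.skip(reason="illustrative sketch of the disabled CLI smoke tests")
@pytest.mark.parametrize("config_path", [FIXTURES_PATH / "configs/train/dpo/base.json"])
def test_dpo_train_smoke(config_path: Path):
    # Run the CLI entry point against a fixture experiment config.
    result = runner.invoke(app, ["train_dpo", "--experiment_settings_path", str(config_path)])
    assert result.exit_code == 0
    # The pipeline is expected to create the log directory named in the config.
    assert DPOTrainExperimentSettings.parse_file(config_path).log_path.is_dir()
```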
134 changes: 67 additions & 67 deletions tests/conftest.py
@@ -1,96 +1,96 @@
import json
# import json

from pytest import fixture
from transformers import AutoTokenizer
# from pytest import fixture
# from transformers import AutoTokenizer

from turbo_alignment.dataset.registry import DatasetRegistry
from turbo_alignment.settings.datasets.base import (
DatasetSourceSettings,
DatasetStrategy,
DatasetType,
)
from turbo_alignment.settings.datasets.chat import (
ChatDatasetSettings,
ChatPromptTemplate,
)
from turbo_alignment.settings.datasets.pair_preference import (
PairPreferenceDatasetSettings,
)
# from turbo_alignment.dataset.registry import DatasetRegistry
# from turbo_alignment.settings.datasets.base import (
# DatasetSourceSettings,
# DatasetStrategy,
# DatasetType,
# )
# from turbo_alignment.settings.datasets.chat import (
# ChatDatasetSettings,
# ChatPromptTemplate,
# )
# from turbo_alignment.settings.datasets.pair_preference import (
# PairPreferenceDatasetSettings,
# )


@fixture(scope='session')
def tokenizer_llama2():
tokenizer_path = 'tests/fixtures/models/llama2_classification/tokenizer'
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
return tokenizer
# @fixture(scope='session')
# def tokenizer_llama2():
# tokenizer_path = 'tests/fixtures/models/llama2_classification/tokenizer'
# tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
# return tokenizer


@fixture(scope='session')
def tokenizer_gptj():
tokenizer_path = 'tests/fixtures/models/gptj_tiny_for_seq_cls'
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
return tokenizer
# @fixture(scope='session')
# def tokenizer_gptj():
# tokenizer_path = 'tests/fixtures/models/gptj_tiny_for_seq_cls'
# tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
# return tokenizer


@fixture(scope='session')
def chat_dataset_settings():
chat_dataset_settings = ChatDatasetSettings(
prompt_template=ChatPromptTemplate(
prefix_template='<S>',
suffix_template='</S>',
role_tag_mapping={'user': 'USER', 'bot': 'BOT', 'system': 'SYSTEM'},
),
max_tokens_count=8256,
)
return chat_dataset_settings
# @fixture(scope='session')
# def chat_dataset_settings():
# chat_dataset_settings = ChatDatasetSettings(
# prompt_template=ChatPromptTemplate(
# prefix_template='<S>',
# suffix_template='</S>',
# role_tag_mapping={'user': 'USER', 'bot': 'BOT', 'system': 'SYSTEM'},
# ),
# max_tokens_count=8256,
# )
# return chat_dataset_settings


@fixture(scope='session')
def classification_dataset_path() -> str:
return 'tests/fixtures/datasets/classification/train_classification.jsonl'
# @fixture(scope='session')
# def classification_dataset_path() -> str:
# return 'tests/fixtures/datasets/classification/train_classification.jsonl'


@fixture(scope='session')
def pair_preferences_dataset_path() -> str:
return 'tests/fixtures/datasets/rm/train_preferences.jsonl'
# @fixture(scope='session')
# def pair_preferences_dataset_path() -> str:
# return 'tests/fixtures/datasets/rm/train_preferences.jsonl'


@fixture(scope='session')
def kto_dataset_path() -> str:
return 'tests/fixtures/datasets/rm/train_kto.jsonl'
# @fixture(scope='session')
# def kto_dataset_path() -> str:
# return 'tests/fixtures/datasets/rm/train_kto.jsonl'


def load_dataset_source(dataset_path: str) -> tuple[DatasetSourceSettings, list[dict]]:
with open(dataset_path, 'r', encoding='utf-8') as f:
data_dicts = [json.loads(line) for line in f]
# def load_dataset_source(dataset_path: str) -> tuple[DatasetSourceSettings, list[dict]]:
# with open(dataset_path, 'r', encoding='utf-8') as f:
# data_dicts = [json.loads(line) for line in f]

source = DatasetSourceSettings(name='dataset_for_test', records_path=dataset_path, sample_rate=1)
# source = DatasetSourceSettings(name='dataset_for_test', records_path=dataset_path, sample_rate=1)

return source, data_dicts
# return source, data_dicts


@fixture(scope='session')
def classification_dataset_source(classification_dataset_path) -> tuple[DatasetSourceSettings, list[dict]]:
return load_dataset_source(classification_dataset_path)
# @fixture(scope='session')
# def classification_dataset_source(classification_dataset_path) -> tuple[DatasetSourceSettings, list[dict]]:
# return load_dataset_source(classification_dataset_path)


@fixture(scope='session')
def pair_preferences_dataset_source(pair_preferences_dataset_path) -> tuple[DatasetSourceSettings, list[dict]]:
return load_dataset_source(pair_preferences_dataset_path)
# @fixture(scope='session')
# def pair_preferences_dataset_source(pair_preferences_dataset_path) -> tuple[DatasetSourceSettings, list[dict]]:
# return load_dataset_source(pair_preferences_dataset_path)


@fixture(scope='session')
def kto_dataset_source(kto_dataset_path) -> tuple[DatasetSourceSettings, list[dict]]:
return load_dataset_source(kto_dataset_path)
# @fixture(scope='session')
# def kto_dataset_source(kto_dataset_path) -> tuple[DatasetSourceSettings, list[dict]]:
# return load_dataset_source(kto_dataset_path)


@fixture(scope='session')
def dpo_dataset(pair_preferences_dataset_source, tokenizer_llama2, chat_dataset_settings):
source, _ = pair_preferences_dataset_source
# @fixture(scope='session')
# def dpo_dataset(pair_preferences_dataset_source, tokenizer_llama2, chat_dataset_settings):
# source, _ = pair_preferences_dataset_source

dataset_cls = DatasetRegistry.by_name(DatasetType.PAIR_PREFERENCES).by_name(DatasetStrategy.TRAIN)
# dataset_cls = DatasetRegistry.by_name(DatasetType.PAIR_PREFERENCES).by_name(DatasetStrategy.TRAIN)

dataset_settings = PairPreferenceDatasetSettings(chat_settings=chat_dataset_settings)
dataset = dataset_cls(tokenizer=tokenizer_llama2, source=source, settings=dataset_settings)
# dataset_settings = PairPreferenceDatasetSettings(chat_settings=chat_dataset_settings)
# dataset = dataset_cls(tokenizer=tokenizer_llama2, source=source, settings=dataset_settings)

return dataset
# return dataset
2 changes: 1 addition & 1 deletion tests/fixtures/configs/train/classification/base.json
@@ -92,7 +92,7 @@
"eos_token": "</s>",
"pad_token": "<pad>"
},
"trainer_settings": {
"training_arguments": {
"evaluation_strategy": "steps",
"per_device_train_batch_size": 1,
"per_device_eval_batch_size": 1,
2 changes: 1 addition & 1 deletion tests/fixtures/configs/train/ddpo/base.json
@@ -132,7 +132,7 @@
"eos_token": "</s>",
"pad_token": "<pad>"
},
"trainer_settings": {
"training_arguments": {
"evaluation_strategy": "steps",
"per_device_train_batch_size": 1,
"per_device_eval_batch_size": 1,
2 changes: 1 addition & 1 deletion tests/fixtures/configs/train/dpo/base.json
@@ -111,7 +111,7 @@
"eos_token": "</s>",
"pad_token": "<pad>"
},
"trainer_settings": {
"training_arguments": {
"evaluation_strategy": "steps",
"per_device_train_batch_size": 2,
"per_device_eval_batch_size": 2,
2 changes: 1 addition & 1 deletion tests/fixtures/configs/train/dpo/simpo.json
@@ -103,7 +103,7 @@
"eos_token": "</s>",
"pad_token": "<pad>"
},
"trainer_settings": {
"training_arguments": {
"evaluation_strategy": "steps",
"per_device_train_batch_size": 2,
"per_device_eval_batch_size": 2,
2 changes: 1 addition & 1 deletion tests/fixtures/configs/train/kto/base.json
@@ -89,7 +89,7 @@
"eos_token": "</s>",
"pad_token": "<pad>"
},
"trainer_settings": {
"training_arguments": {
"evaluation_strategy": "steps",
"per_device_train_batch_size": 4,
"per_device_eval_batch_size": 4,
@@ -108,7 +108,7 @@
"eos_token": "</s>",
"pad_token": "<pad>"
},
"trainer_settings": {
"training_arguments": {
"evaluation_strategy": "epoch",
"save_strategy": "epoch",
"per_device_train_batch_size": 1,
@@ -108,7 +108,7 @@
"eos_token": "</s>",
"pad_token": "<pad>"
},
"trainer_settings": {
"training_arguments": {
"evaluation_strategy": "epoch",
"save_strategy": "epoch",
"per_device_train_batch_size": 1,
@@ -108,7 +108,7 @@
"eos_token": "</s>",
"pad_token": "<pad>"
},
"trainer_settings": {
"training_arguments": {
"evaluation_strategy": "epoch",
"save_strategy": "epoch",
"per_device_train_batch_size": 1,
2 changes: 1 addition & 1 deletion tests/fixtures/configs/train/rag/base.json
@@ -123,7 +123,7 @@
"eos_token": "</s>",
"pad_token": "<pad>"
},
"trainer_settings": {
"training_arguments": {
"evaluation_strategy": "epoch",
"save_strategy": "epoch",
"per_device_train_batch_size": 1,
2 changes: 1 addition & 1 deletion tests/fixtures/configs/train/rm/base.json
@@ -94,7 +94,7 @@
"eos_token": "</s>",
"pad_token": "<pad>"
},
"trainer_settings": {
"training_arguments": {
"evaluation_strategy": "steps",
"per_device_train_batch_size": 1,
"per_device_eval_batch_size": 1,
4 changes: 2 additions & 2 deletions tests/fixtures/configs/train/sft/base.json
@@ -44,7 +44,7 @@
"model_settings": {
"model_path": "tests/fixtures/models/llama2_tiny",
"model_type": "causal",
"generation_config": {
"transformers_settings": {
},
"peft_settings": {
"r": 8,
@@ -104,7 +104,7 @@
"eos_token": "</s>",
"pad_token": "<pad>"
},
"trainer_settings": {
"training_arguments": {
"evaluation_strategy": "steps",
"per_device_train_batch_size": 1,
"per_device_eval_batch_size": 1,
2 changes: 1 addition & 1 deletion tests/fixtures/configs/train/sft/prompt_tuning.json
@@ -99,7 +99,7 @@
"eos_token": "</s>",
"pad_token": "<pad>"
},
"trainer_settings": {
"training_arguments": {
"evaluation_strategy": "steps",
"per_device_train_batch_size": 1,
"per_device_eval_batch_size": 1,
@@ -83,7 +83,7 @@
"eos_token": "</s>",
"pad_token": "<pad>"
},
"trainer_settings": {
"training_arguments": {
"evaluation_strategy": "steps",
"per_device_train_batch_size": 1,
"per_device_eval_batch_size": 1,
2 changes: 1 addition & 1 deletion tests/fixtures/configs/train/sft/sft_with_rm_metric.json
@@ -127,7 +127,7 @@
"eos_token": "</s>",
"pad_token": "<pad>"
},
"trainer_settings": {
"training_arguments": {
"evaluation_strategy": "steps",
"per_device_train_batch_size": 1,
"per_device_eval_batch_size": 1,
@@ -1,5 +1,5 @@
{
"trainer_settings": {
"training_arguments": {
"evaluation_strategy": "steps",
"per_device_train_batch_size": 1,
"per_device_eval_batch_size": 1,