-
Notifications
You must be signed in to change notification settings - Fork 22
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Малахов Алексей Павлович committed Oct 17, 2024
1 parent 00dc48e, commit 49eff3d
Showing 26 changed files with 400 additions and 412 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,26 +1,24 @@ | ||
from pathlib import Path | ||
|
||
import pytest | ||
from typer.testing import CliRunner | ||
|
||
from tests.constants import FIXTURES_PATH | ||
from turbo_alignment.cli import app | ||
from turbo_alignment.settings.pipelines.train.classification import ( | ||
ClassificationTrainExperimentSettings, | ||
) | ||
|
||
runner = CliRunner() | ||
|
||
|
||
def test_classification_train(): | ||
result = runner.invoke( | ||
app, | ||
[ | ||
'train_classification', | ||
'--experiment_settings_path', | ||
FIXTURES_PATH / 'configs/train/classification/base.json', | ||
], | ||
catch_exceptions=False, | ||
) | ||
@pytest.mark.parametrize( | ||
'config_path', | ||
[ | ||
FIXTURES_PATH / 'configs/train/classification/base.json', | ||
], | ||
) | ||
def test_classification_train(config_path: Path): | ||
result = runner.invoke(app, ['train_classification', '--experiment_settings_path', str(config_path)]) | ||
assert result.exit_code == 0 | ||
assert Path('test_train_classification_output').is_dir() | ||
|
||
|
||
if __name__ == '__main__': | ||
test_classification_train() | ||
assert ClassificationTrainExperimentSettings.parse_file(config_path).log_path.is_dir() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,26 +1,26 @@ | ||
# from pathlib import Path | ||
from pathlib import Path | ||
|
||
# import pytest | ||
# from typer.testing import CliRunner | ||
import pytest | ||
from typer.testing import CliRunner | ||
|
||
# from tests.constants import FIXTURES_PATH | ||
# from turbo_alignment.cli import app | ||
# from turbo_alignment.settings.pipelines.train.dpo import DPOTrainExperimentSettings | ||
from tests.constants import FIXTURES_PATH | ||
from turbo_alignment.cli import app | ||
from turbo_alignment.settings.pipelines.train.dpo import DPOTrainExperimentSettings | ||
|
||
# runner = CliRunner() | ||
runner = CliRunner() | ||
|
||
|
||
# @pytest.mark.parametrize( | ||
# 'config_path', | ||
# [ | ||
# FIXTURES_PATH / 'configs/train/dpo/base.json', | ||
# FIXTURES_PATH / 'configs/train/dpo/simpo.json', | ||
# ], | ||
# ) | ||
# def test_dpo_train(config_path: Path): | ||
# result = runner.invoke( | ||
# app, | ||
# ['train_dpo', '--experiment_settings_path', str(config_path)], | ||
# ) | ||
# assert result.exit_code == 0 | ||
# assert DPOTrainExperimentSettings.parse_file(config_path).log_path.is_dir() | ||
@pytest.mark.parametrize( | ||
'config_path', | ||
[ | ||
FIXTURES_PATH / 'configs/train/dpo/base.json', | ||
FIXTURES_PATH / 'configs/train/dpo/simpo.json', | ||
], | ||
) | ||
def test_dpo_train(config_path: Path): | ||
result = runner.invoke( | ||
app, | ||
['train_dpo', '--experiment_settings_path', str(config_path)], | ||
) | ||
assert result.exit_code == 0 | ||
assert DPOTrainExperimentSettings.parse_file(config_path).log_path.is_dir() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,23 +1,23 @@ | ||
# from pathlib import Path | ||
from pathlib import Path | ||
|
||
# import pytest | ||
# from typer.testing import CliRunner | ||
import pytest | ||
from typer.testing import CliRunner | ||
|
||
# from tests.constants import FIXTURES_PATH | ||
# from turbo_alignment.cli import app | ||
# from turbo_alignment.settings.pipelines.train.kto import KTOTrainExperimentSettings | ||
from tests.constants import FIXTURES_PATH | ||
from turbo_alignment.cli import app | ||
from turbo_alignment.settings.pipelines.train.kto import KTOTrainExperimentSettings | ||
|
||
# runner = CliRunner() | ||
runner = CliRunner() | ||
|
||
|
||
# @pytest.mark.parametrize( | ||
# 'config_path', | ||
# [FIXTURES_PATH / 'configs/train/kto/base.json'], | ||
# ) | ||
# def test_dpo_train(config_path: Path): | ||
# result = runner.invoke( | ||
# app, | ||
# ['train_kto', '--experiment_settings_path', str(config_path)], | ||
# ) | ||
# assert result.exit_code == 0 | ||
# assert KTOTrainExperimentSettings.parse_file(config_path).log_path.is_dir() | ||
@pytest.mark.parametrize( | ||
'config_path', | ||
[FIXTURES_PATH / 'configs/train/kto/base.json'], | ||
) | ||
def test_dpo_train(config_path: Path): | ||
result = runner.invoke( | ||
app, | ||
['train_kto', '--experiment_settings_path', str(config_path)], | ||
) | ||
assert result.exit_code == 0 | ||
assert KTOTrainExperimentSettings.parse_file(config_path).log_path.is_dir() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,96 +1,96 @@ | ||
# import json | ||
import json | ||
|
||
# from pytest import fixture | ||
# from transformers import AutoTokenizer | ||
from pytest import fixture | ||
from transformers import AutoTokenizer | ||
|
||
# from turbo_alignment.dataset.registry import DatasetRegistry | ||
# from turbo_alignment.settings.datasets.base import ( | ||
# DatasetSourceSettings, | ||
# DatasetStrategy, | ||
# DatasetType, | ||
# ) | ||
# from turbo_alignment.settings.datasets.chat import ( | ||
# ChatDatasetSettings, | ||
# ChatPromptTemplate, | ||
# ) | ||
# from turbo_alignment.settings.datasets.pair_preference import ( | ||
# PairPreferenceDatasetSettings, | ||
# ) | ||
from turbo_alignment.dataset.registry import DatasetRegistry | ||
from turbo_alignment.settings.datasets.base import ( | ||
DatasetSourceSettings, | ||
DatasetStrategy, | ||
DatasetType, | ||
) | ||
from turbo_alignment.settings.datasets.chat import ( | ||
ChatDatasetSettings, | ||
ChatPromptTemplate, | ||
) | ||
from turbo_alignment.settings.datasets.pair_preference import ( | ||
PairPreferenceDatasetSettings, | ||
) | ||
|
||
|
||
# @fixture(scope='session') | ||
# def tokenizer_llama2(): | ||
# tokenizer_path = 'tests/fixtures/models/llama2_classification/tokenizer' | ||
# tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) | ||
# return tokenizer | ||
@fixture(scope='session') | ||
def tokenizer_llama2(): | ||
tokenizer_path = 'tests/fixtures/models/llama2_classification/tokenizer' | ||
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) | ||
return tokenizer | ||
|
||
|
||
# @fixture(scope='session') | ||
# def tokenizer_gptj(): | ||
# tokenizer_path = 'tests/fixtures/models/gptj_tiny_for_seq_cls' | ||
# tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) | ||
# return tokenizer | ||
@fixture(scope='session') | ||
def tokenizer_gptj(): | ||
tokenizer_path = 'tests/fixtures/models/gptj_tiny_for_seq_cls' | ||
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) | ||
return tokenizer | ||
|
||
|
||
# @fixture(scope='session') | ||
# def chat_dataset_settings(): | ||
# chat_dataset_settings = ChatDatasetSettings( | ||
# prompt_template=ChatPromptTemplate( | ||
# prefix_template='<S>', | ||
# suffix_template='</S>', | ||
# role_tag_mapping={'user': 'USER', 'bot': 'BOT', 'system': 'SYSTEM'}, | ||
# ), | ||
# max_tokens_count=8256, | ||
# ) | ||
# return chat_dataset_settings | ||
@fixture(scope='session') | ||
def chat_dataset_settings(): | ||
chat_dataset_settings = ChatDatasetSettings( | ||
prompt_template=ChatPromptTemplate( | ||
prefix_template='<S>', | ||
suffix_template='</S>', | ||
role_tag_mapping={'user': 'USER', 'bot': 'BOT', 'system': 'SYSTEM'}, | ||
), | ||
max_tokens_count=8256, | ||
) | ||
return chat_dataset_settings | ||
|
||
|
||
# @fixture(scope='session') | ||
# def classification_dataset_path() -> str: | ||
# return 'tests/fixtures/datasets/classification/train_classification.jsonl' | ||
@fixture(scope='session') | ||
def classification_dataset_path() -> str: | ||
return 'tests/fixtures/datasets/classification/train_classification.jsonl' | ||
|
||
|
||
# @fixture(scope='session') | ||
# def pair_preferences_dataset_path() -> str: | ||
# return 'tests/fixtures/datasets/rm/train_preferences.jsonl' | ||
@fixture(scope='session') | ||
def pair_preferences_dataset_path() -> str: | ||
return 'tests/fixtures/datasets/rm/train_preferences.jsonl' | ||
|
||
|
||
# @fixture(scope='session') | ||
# def kto_dataset_path() -> str: | ||
# return 'tests/fixtures/datasets/rm/train_kto.jsonl' | ||
@fixture(scope='session') | ||
def kto_dataset_path() -> str: | ||
return 'tests/fixtures/datasets/rm/train_kto.jsonl' | ||
|
||
|
||
# def load_dataset_source(dataset_path: str) -> tuple[DatasetSourceSettings, list[dict]]: | ||
# with open(dataset_path, 'r', encoding='utf-8') as f: | ||
# data_dicts = [json.loads(line) for line in f] | ||
def load_dataset_source(dataset_path: str) -> tuple[DatasetSourceSettings, list[dict]]: | ||
with open(dataset_path, 'r', encoding='utf-8') as f: | ||
data_dicts = [json.loads(line) for line in f] | ||
|
||
# source = DatasetSourceSettings(name='dataset_for_test', records_path=dataset_path, sample_rate=1) | ||
source = DatasetSourceSettings(name='dataset_for_test', records_path=dataset_path, sample_rate=1) | ||
|
||
# return source, data_dicts | ||
return source, data_dicts | ||
|
||
|
||
# @fixture(scope='session') | ||
# def classification_dataset_source(classification_dataset_path) -> tuple[DatasetSourceSettings, list[dict]]: | ||
# return load_dataset_source(classification_dataset_path) | ||
@fixture(scope='session') | ||
def classification_dataset_source(classification_dataset_path) -> tuple[DatasetSourceSettings, list[dict]]: | ||
return load_dataset_source(classification_dataset_path) | ||
|
||
|
||
# @fixture(scope='session') | ||
# def pair_preferences_dataset_source(pair_preferences_dataset_path) -> tuple[DatasetSourceSettings, list[dict]]: | ||
# return load_dataset_source(pair_preferences_dataset_path) | ||
@fixture(scope='session') | ||
def pair_preferences_dataset_source(pair_preferences_dataset_path) -> tuple[DatasetSourceSettings, list[dict]]: | ||
return load_dataset_source(pair_preferences_dataset_path) | ||
|
||
|
||
# @fixture(scope='session') | ||
# def kto_dataset_source(kto_dataset_path) -> tuple[DatasetSourceSettings, list[dict]]: | ||
# return load_dataset_source(kto_dataset_path) | ||
@fixture(scope='session') | ||
def kto_dataset_source(kto_dataset_path) -> tuple[DatasetSourceSettings, list[dict]]: | ||
return load_dataset_source(kto_dataset_path) | ||
|
||
|
||
# @fixture(scope='session') | ||
# def dpo_dataset(pair_preferences_dataset_source, tokenizer_llama2, chat_dataset_settings): | ||
# source, _ = pair_preferences_dataset_source | ||
@fixture(scope='session') | ||
def dpo_dataset(pair_preferences_dataset_source, tokenizer_llama2, chat_dataset_settings): | ||
source, _ = pair_preferences_dataset_source | ||
|
||
# dataset_cls = DatasetRegistry.by_name(DatasetType.PAIR_PREFERENCES).by_name(DatasetStrategy.TRAIN) | ||
dataset_cls = DatasetRegistry.by_name(DatasetType.PAIR_PREFERENCES).by_name(DatasetStrategy.TRAIN) | ||
|
||
# dataset_settings = PairPreferenceDatasetSettings(chat_settings=chat_dataset_settings) | ||
# dataset = dataset_cls(tokenizer=tokenizer_llama2, source=source, settings=dataset_settings) | ||
dataset_settings = PairPreferenceDatasetSettings(chat_settings=chat_dataset_settings) | ||
dataset = dataset_cls(tokenizer=tokenizer_llama2, source=source, settings=dataset_settings) | ||
|
||
# return dataset | ||
return dataset |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.