Add fixtures for testing boilerplate, tiny MPT models, and a tiny finetune dataset
Showing 14 changed files with 272 additions and 243 deletions.
@@ -0,0 +1,2 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0
@@ -0,0 +1,39 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import gc

import pytest
import torch
from composer.utils import dist, get_device, reproducibility


@pytest.fixture(autouse=True)
def initialize_dist(request: pytest.FixtureRequest):
    """Initialize the default PyTorch distributed process group for tests."""
    # should we just always initialize dist like in train.py?
    _default = pytest.mark.world_size(1).mark
    world_size = request.node.get_closest_marker('world_size', _default).args[0]
    gpu = request.node.get_closest_marker('gpu')
    if world_size > 1:
        dist.initialize_dist(get_device('gpu' if gpu is not None else 'cpu'))


@pytest.fixture(autouse=True)
def clear_cuda_cache(request: pytest.FixtureRequest):
    """Clear memory between GPU tests."""
    marker = request.node.get_closest_marker('gpu')
    if marker is not None and torch.cuda.is_available():
        torch.cuda.empty_cache()
        gc.collect()  # Only gc on GPU tests as it 2x slows down CPU tests


@pytest.fixture
def random_seed() -> int:
    return 17


@pytest.fixture(autouse=True)
def seed_all(random_seed: int):
    """Sets the seed for reproducibility."""
    reproducibility.seed_all(random_seed)
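
These fixtures are autouse, so tests opt into the distributed and GPU behavior through pytest markers alone. Below is a minimal usage sketch; the test name, marker registration, and assertion are illustrative assumptions rather than part of this commit.

# Hypothetical usage sketch, assuming the world_size and gpu markers are
# registered in the project's pytest configuration.
import pytest
import torch


@pytest.mark.gpu
@pytest.mark.world_size(2)
def test_runs_on_two_gpu_ranks():
    # Because world_size > 1, the autouse initialize_dist fixture sets up the
    # process group before this body runs; clear_cuda_cache then empties the
    # CUDA cache because the test carries the gpu marker, and seed_all has
    # already seeded RNGs with the default seed of 17.
    assert torch.cuda.is_available()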
@@ -0,0 +1,58 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from pathlib import Path

from composer.utils import dist
from omegaconf import DictConfig
from pytest import fixture
from torch.utils.data import DataLoader
from transformers import PreTrainedTokenizerBase

from llmfoundry.data.finetuning.dataloader import build_finetuning_dataloader
from tests.data_utils import make_tiny_ft_dataset


@fixture
def tiny_ft_dataset_path(tmp_path: Path, dataset_size: int = 4) -> Path:
    """Creates a tiny dataset and returns the path."""
    tiny_dataset_path = tmp_path / 'test-ift-data-small'
    tiny_dataset_path.mkdir(exist_ok=True)
    tiny_dataset_file = tiny_dataset_path / 'train.jsonl'
    if dist.get_world_size() == 1 or dist.get_global_rank() == 0:
        make_tiny_ft_dataset(path=str(tiny_dataset_file), size=dataset_size)
    return tiny_dataset_path


@fixture
def tiny_ft_dataloader(tiny_ft_dataset_path: Path,
                       mpt_tokenizer: PreTrainedTokenizerBase,
                       max_seq_len: int = 128,
                       device_batch_size: int = 1) -> DataLoader:
    dataloader_cfg = DictConfig({
        'name': 'finetuning',
        'dataset': {
            'hf_name': str(tiny_ft_dataset_path),
            'split': 'train',
            'max_seq_len': max_seq_len,
            'decoder_only_format': True,
            'allow_pad_trimming': False,
            'packing_ratio': None,
            'shuffle': True,
        },
        'drop_last': False,
        'num_workers': 4,
        'pin_memory': False,
        'prefetch_factor': 2,
        'persistent_workers': False,
        'timeout': 0
    })

    dataloader = build_finetuning_dataloader(
        dataloader_cfg,
        mpt_tokenizer,
        device_batch_size,
    ).dataloader

    assert isinstance(dataloader, DataLoader)
    return dataloader
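
A hedged sketch of how a test could consume the dataloader fixture; the test name and the exact batch keys checked are assumptions, not part of the diff.

# Hypothetical usage sketch, not part of this commit.
from torch.utils.data import DataLoader


def test_tiny_ft_dataloader_yields_a_batch(tiny_ft_dataloader: DataLoader):
    # pytest injects the fixture, which writes a 4-sample train.jsonl under
    # tmp_path and wraps it in a finetuning dataloader with batch size 1.
    batch = next(iter(tiny_ft_dataloader))
    assert 'input_ids' in batch
    assert batch['input_ids'].shape[0] == 1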
@@ -0,0 +1,70 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Callable

from omegaconf import DictConfig
from pytest import fixture
from transformers import PreTrainedTokenizerBase

from llmfoundry.models.hf.hf_causal_lm import ComposerHFCausalLM
from llmfoundry.models.model_registry import COMPOSER_MODEL_REGISTRY
from llmfoundry.models.mpt.modeling_mpt import ComposerMPTCausalLM
from llmfoundry.utils.builders import build_tokenizer


def _build_model(config: DictConfig, tokenizer: PreTrainedTokenizerBase):
    model = COMPOSER_MODEL_REGISTRY[config.name](config, tokenizer)
    return model


@fixture
def mpt_tokenizer():
    return build_tokenizer('EleutherAI/gpt-neox-20b', {})


@fixture
def build_tiny_mpt(
        mpt_tokenizer: PreTrainedTokenizerBase
) -> Callable[..., ComposerMPTCausalLM]:

    def build(**kwargs: Any) -> ComposerMPTCausalLM:
        config = DictConfig({
            'name': 'mpt_causal_lm',
            'd_model': 128,
            'n_heads': 4,
            'n_layers': 2,
            'expansion_ratio': 2,
        })
        config.update(kwargs)
        model = _build_model(config, mpt_tokenizer)
        assert isinstance(model, ComposerMPTCausalLM)
        return model

    return build


@fixture
def build_tiny_hf_mpt(
        mpt_tokenizer: PreTrainedTokenizerBase
) -> Callable[..., ComposerHFCausalLM]:

    def build(**kwargs: Any) -> ComposerHFCausalLM:
        config_overrides = {
            'd_model': 128,
            'n_heads': 4,
            'n_layers': 2,
            'expansion_ratio': 2,
        }
        config_overrides.update(kwargs)
        config = DictConfig({
            'name': 'hf_causal_lm',
            'pretrained_model_name_or_path': 'mosaicml/mpt-7b',
            'pretrained': False,
            'config_overrides': config_overrides,
        })
        model = _build_model(config, mpt_tokenizer)
        assert isinstance(model, ComposerHFCausalLM)
        return model

    return build
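
Because both fixtures return builder callables rather than finished models, each test can construct a fresh tiny model and layer its own overrides on top of the 128-dim, 2-layer defaults. A hedged usage sketch follows; the test name and checks are assumptions, not part of this commit.

# Hypothetical usage sketch, not part of this commit.
def test_tiny_models_build(build_tiny_mpt, build_tiny_hf_mpt):
    # Keyword arguments passed to the builders override the tiny defaults.
    mpt = build_tiny_mpt(n_layers=3)
    hf_mpt = build_tiny_hf_mpt(n_layers=3)
    assert sum(p.numel() for p in mpt.parameters()) > 0
    assert sum(p.numel() for p in hf_mpt.parameters()) > 0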