From fb0752a5e409a3462055099204eb584ff53df8f2 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Tue, 17 Dec 2024 13:55:45 -0500
Subject: [PATCH] add e2e for flattening

---
 tests/e2e/test_llama.py | 40 ++++++++++++++++++++++++++++++++++++++++
 1 file changed, 40 insertions(+)

diff --git a/tests/e2e/test_llama.py b/tests/e2e/test_llama.py
index 33d12157a..c7d000744 100644
--- a/tests/e2e/test_llama.py
+++ b/tests/e2e/test_llama.py
@@ -104,3 +104,43 @@ def test_fix_untrained_tokens(self, temp_dir):
 
         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
         assert (Path(temp_dir) / "model.safetensors").exists()
+
+
+    def test_batch_flattening(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "HuggingFaceTB/SmolLM2-135M",
+                "trust_remote_code": True,
+                "sequence_len": 512,
+                "val_set_size": 0.01,
+                "special_tokens": {
+                    "pad_token": "<|endoftext|>",
+                },
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 1,
+                "max_steps": 5,
+                "micro_batch_size": 4,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_8bit",
+                "lr_scheduler": "cosine",
+                "flash_attention": True,
+                "sample_packing": False,
+                "batch_flattening": True,
+                "bf16": True,
+                "save_safetensors": True,
+            }
+        )
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "model.safetensors").exists()
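
Note (not part of the patch): for context on what the new test exercises, batch flattening drops pad tokens from a micro-batch and concatenates the remaining tokens into one long sequence, handing the attention kernel cumulative sequence lengths so each original example still attends only to itself. The sketch below is a minimal illustration of that idea, not axolotl's actual implementation; the helper name `flatten_batch` and the tensor layout are assumptions.

```python
# Illustrative sketch only: flatten a padded batch for varlen-style attention.
# The helper name and layout are assumptions, not axolotl's real code path.
import torch


def flatten_batch(input_ids: torch.Tensor, attention_mask: torch.Tensor):
    """Drop pad tokens and concatenate the batch into a single sequence.

    Returns the flattened token ids plus cu_seqlens, the cumulative sequence
    lengths a variable-length attention kernel needs to keep examples separate.
    """
    seq_lens = attention_mask.sum(dim=1)          # real length of each row
    flat_ids = input_ids[attention_mask.bool()]   # (total_non_pad_tokens,)
    cu_seqlens = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(seq_lens, dim=0)
    return flat_ids, cu_seqlens


if __name__ == "__main__":
    ids = torch.tensor([[5, 6, 7, 0], [8, 9, 0, 0]])   # 0 = pad token id
    mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
    flat, cu = flatten_batch(ids, mask)
    print(flat.tolist())  # [5, 6, 7, 8, 9]
    print(cu.tolist())    # [0, 3, 5]
```

This is why the test config pairs `"batch_flattening": True` with `"flash_attention": True` and `"sample_packing": False`: flattening only pays off when the attention implementation can consume a packed, padding-free batch, and it is an alternative to (not a companion of) sample packing.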