From 1d9d2373525e1ebf305b8ff19cf44f9e70e65473 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 8 Jan 2025 23:54:20 -0500 Subject: [PATCH] add helper to verify the correct model output file exists --- tests/e2e/integrations/liger.py | 7 +++---- .../integrations/test_cut_cross_entropy.py | 7 +++---- tests/e2e/patched/test_4d_multipack_llama.py | 7 +++---- tests/e2e/patched/test_fa_xentropy.py | 5 ++--- tests/e2e/patched/test_falcon_samplepack.py | 7 +++---- tests/e2e/patched/test_llama_s2_attention.py | 7 +++---- .../e2e/patched/test_lora_llama_multipack.py | 7 +++---- tests/e2e/patched/test_mistral_samplepack.py | 7 +++---- tests/e2e/patched/test_mixtral_samplepack.py | 7 +++---- tests/e2e/patched/test_phi_multipack.py | 4 ++-- tests/e2e/patched/test_resume.py | 5 ++--- tests/e2e/patched/test_unsloth_qlora.py | 9 ++++---- tests/e2e/test_embeddings_lr.py | 7 +++---- tests/e2e/test_falcon.py | 6 +++--- tests/e2e/test_llama.py | 9 ++++---- tests/e2e/test_llama_pretrain.py | 5 ++--- tests/e2e/test_llama_vision.py | 7 +++---- tests/e2e/test_lora_llama.py | 5 ++--- tests/e2e/test_mistral.py | 4 ++-- tests/e2e/test_mixtral.py | 13 ++++++------ tests/e2e/test_optimizers.py | 9 ++++---- tests/e2e/test_phi.py | 4 ++-- tests/e2e/test_reward_model_llama.py | 5 ++--- tests/e2e/utils.py | 21 +++++++++++++++++++ 24 files changed, 89 insertions(+), 85 deletions(-) diff --git a/tests/e2e/integrations/liger.py b/tests/e2e/integrations/liger.py index 455c3d2818..fa3b23b64e 100644 --- a/tests/e2e/integrations/liger.py +++ b/tests/e2e/integrations/liger.py @@ -2,7 +2,6 @@ Simple end-to-end test for Liger integration """ import unittest -from pathlib import Path from axolotl.cli import load_datasets from axolotl.common.cli import TrainerCliArgs @@ -10,7 +9,7 @@ from axolotl.utils.config import normalize_config, prepare_plugins from axolotl.utils.dict import DictDefault -from ..utils import with_temp_dir +from ..utils import check_model_output_exists, with_temp_dir class LigerIntegrationTestCase(unittest.TestCase): @@ -63,7 +62,7 @@ def test_llama_wo_flce(self, temp_dir): dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) - assert (Path(temp_dir) / "model.safetensors").exists() + check_model_output_exists(temp_dir, cfg) @with_temp_dir def test_llama_w_flce(self, temp_dir): @@ -110,4 +109,4 @@ def test_llama_w_flce(self, temp_dir): dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) - assert (Path(temp_dir) / "model.safetensors").exists() + check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/integrations/test_cut_cross_entropy.py b/tests/e2e/integrations/test_cut_cross_entropy.py index a74813e3a0..e047274c09 100644 --- a/tests/e2e/integrations/test_cut_cross_entropy.py +++ b/tests/e2e/integrations/test_cut_cross_entropy.py @@ -2,9 +2,8 @@ Simple end-to-end test for Cut Cross Entropy integration """ -from pathlib import Path - import pytest +from e2e.utils import check_model_output_exists from axolotl.cli import load_datasets from axolotl.common.cli import TrainerCliArgs @@ -67,7 +66,7 @@ def test_llama_w_cce(self, min_cfg, temp_dir): train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) else: train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) - assert (Path(temp_dir) / "model.safetensors").exists() + check_model_output_exists(temp_dir, cfg) @pytest.mark.parametrize( "attention_type", @@ -95,4 +94,4 @@ def test_llama_w_cce_and_attention(self, min_cfg, 
temp_dir, attention_type): train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) else: train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) - assert (Path(temp_dir) / "model.safetensors").exists() + check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/patched/test_4d_multipack_llama.py b/tests/e2e/patched/test_4d_multipack_llama.py index b0ada92304..08b3bf0daf 100644 --- a/tests/e2e/patched/test_4d_multipack_llama.py +++ b/tests/e2e/patched/test_4d_multipack_llama.py @@ -5,7 +5,6 @@ import logging import os import unittest -from pathlib import Path from axolotl.cli import load_datasets from axolotl.common.cli import TrainerCliArgs @@ -13,7 +12,7 @@ from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault -from ..utils import require_torch_2_3_1, with_temp_dir +from ..utils import check_model_output_exists, require_torch_2_3_1, with_temp_dir LOG = logging.getLogger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" @@ -67,7 +66,7 @@ def test_sdp_lora_packing(self, temp_dir): dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) - assert (Path(temp_dir) / "adapter_model.bin").exists() + check_model_output_exists(temp_dir, cfg) @with_temp_dir def test_torch_lora_packing(self, temp_dir): @@ -111,4 +110,4 @@ def test_torch_lora_packing(self, temp_dir): dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) - assert (Path(temp_dir) / "adapter_model.bin").exists() + check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/patched/test_fa_xentropy.py b/tests/e2e/patched/test_fa_xentropy.py index 183843b7b1..791d955b28 100644 --- a/tests/e2e/patched/test_fa_xentropy.py +++ b/tests/e2e/patched/test_fa_xentropy.py @@ -4,7 +4,6 @@ import logging import os -from pathlib import Path import pytest from transformers.utils import is_torch_bf16_gpu_available @@ -15,7 +14,7 @@ from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault -from ..utils import check_tensorboard +from ..utils import check_model_output_exists, check_tensorboard LOG = logging.getLogger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" @@ -82,7 +81,7 @@ def test_lora_packing_fa_cross_entropy(self, temp_dir, gradient_accumulation_ste dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) - assert (Path(temp_dir) / "adapter_model.bin").exists() + check_model_output_exists(temp_dir, cfg) check_tensorboard( temp_dir + "/runs", "train/train_loss", 1.5, "Train Loss is too high" diff --git a/tests/e2e/patched/test_falcon_samplepack.py b/tests/e2e/patched/test_falcon_samplepack.py index d9d7151032..69516810f6 100644 --- a/tests/e2e/patched/test_falcon_samplepack.py +++ b/tests/e2e/patched/test_falcon_samplepack.py @@ -5,7 +5,6 @@ import logging import os import unittest -from pathlib import Path from axolotl.cli import load_datasets from axolotl.common.cli import TrainerCliArgs @@ -13,7 +12,7 @@ from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault -from ..utils import with_temp_dir +from ..utils import check_model_output_exists, with_temp_dir LOG = logging.getLogger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" @@ -69,7 +68,7 @@ def test_qlora(self, temp_dir): dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) - assert 
(Path(temp_dir) / "adapter_model.bin").exists() + check_model_output_exists(temp_dir, cfg) @with_temp_dir def test_ft(self, temp_dir): @@ -109,4 +108,4 @@ def test_ft(self, temp_dir): dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) - assert (Path(temp_dir) / "pytorch_model.bin").exists() + check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/patched/test_llama_s2_attention.py b/tests/e2e/patched/test_llama_s2_attention.py index 0f2539daf8..d0fdd918a1 100644 --- a/tests/e2e/patched/test_llama_s2_attention.py +++ b/tests/e2e/patched/test_llama_s2_attention.py @@ -5,7 +5,6 @@ import logging import os import unittest -from pathlib import Path import pytest @@ -15,7 +14,7 @@ from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault -from ..utils import with_temp_dir +from ..utils import check_model_output_exists, with_temp_dir LOG = logging.getLogger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" @@ -71,7 +70,7 @@ def test_lora_s2_attn(self, temp_dir): dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) - assert (Path(temp_dir) / "adapter_model.bin").exists() + check_model_output_exists(temp_dir, cfg) @with_temp_dir def test_fft_s2_attn(self, temp_dir): @@ -111,4 +110,4 @@ def test_fft_s2_attn(self, temp_dir): dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) - assert (Path(temp_dir) / "pytorch_model.bin").exists() + check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/patched/test_lora_llama_multipack.py b/tests/e2e/patched/test_lora_llama_multipack.py index be2f133fb0..634e544d20 100644 --- a/tests/e2e/patched/test_lora_llama_multipack.py +++ b/tests/e2e/patched/test_lora_llama_multipack.py @@ -5,7 +5,6 @@ import logging import os import unittest -from pathlib import Path import pytest from transformers.utils import is_auto_gptq_available, is_torch_bf16_gpu_available @@ -16,7 +15,7 @@ from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault -from ..utils import with_temp_dir +from ..utils import check_model_output_exists, with_temp_dir LOG = logging.getLogger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" @@ -76,7 +75,7 @@ def test_lora_packing(self, temp_dir): dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) - assert (Path(temp_dir) / "adapter_model.bin").exists() + check_model_output_exists(temp_dir, cfg) @pytest.mark.skipif(not is_auto_gptq_available(), reason="auto-gptq not available") @with_temp_dir @@ -126,4 +125,4 @@ def test_lora_gptq_packed(self, temp_dir): dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) - assert (Path(temp_dir) / "adapter_model.bin").exists() + check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/patched/test_mistral_samplepack.py b/tests/e2e/patched/test_mistral_samplepack.py index 6685fb9d57..e93863e09c 100644 --- a/tests/e2e/patched/test_mistral_samplepack.py +++ b/tests/e2e/patched/test_mistral_samplepack.py @@ -5,7 +5,6 @@ import logging import os import unittest -from pathlib import Path from axolotl.cli import load_datasets from axolotl.common.cli import TrainerCliArgs @@ -13,7 +12,7 @@ from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault -from ..utils 
import with_temp_dir +from ..utils import check_model_output_exists, with_temp_dir LOG = logging.getLogger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" @@ -69,7 +68,7 @@ def test_lora_packing(self, temp_dir): dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) - assert (Path(temp_dir) / "adapter_model.bin").exists() + check_model_output_exists(temp_dir, cfg) @with_temp_dir def test_ft_packing(self, temp_dir): @@ -110,4 +109,4 @@ def test_ft_packing(self, temp_dir): dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) - assert (Path(temp_dir) / "pytorch_model.bin").exists() + check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/patched/test_mixtral_samplepack.py b/tests/e2e/patched/test_mixtral_samplepack.py index 684baaaff8..f87c34fd10 100644 --- a/tests/e2e/patched/test_mixtral_samplepack.py +++ b/tests/e2e/patched/test_mixtral_samplepack.py @@ -5,7 +5,6 @@ import logging import os import unittest -from pathlib import Path from axolotl.cli import load_datasets from axolotl.common.cli import TrainerCliArgs @@ -13,7 +12,7 @@ from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault -from ..utils import with_temp_dir +from ..utils import check_model_output_exists, with_temp_dir LOG = logging.getLogger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" @@ -66,7 +65,7 @@ def test_qlora(self, temp_dir): dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) - assert (Path(temp_dir) / "adapter_model.bin").exists() + check_model_output_exists(temp_dir, cfg) @with_temp_dir def test_ft(self, temp_dir): @@ -108,4 +107,4 @@ def test_ft(self, temp_dir): "MixtralFlashAttention2" in model.model.layers[0].self_attn.__class__.__name__ ) - assert (Path(temp_dir) / "pytorch_model.bin").exists() + check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/patched/test_phi_multipack.py b/tests/e2e/patched/test_phi_multipack.py index 7b5bf92dfa..21064a9ff9 100644 --- a/tests/e2e/patched/test_phi_multipack.py +++ b/tests/e2e/patched/test_phi_multipack.py @@ -13,7 +13,7 @@ from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault -from ..utils import with_temp_dir +from ..utils import check_model_output_exists, with_temp_dir LOG = logging.getLogger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" @@ -120,4 +120,4 @@ def test_qlora_packed(self, temp_dir): dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) - assert (Path(temp_dir) / "adapter_model.bin").exists() + check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/patched/test_resume.py b/tests/e2e/patched/test_resume.py index 7d82ea8c37..5639d2eaee 100644 --- a/tests/e2e/patched/test_resume.py +++ b/tests/e2e/patched/test_resume.py @@ -6,7 +6,6 @@ import os import re import subprocess -from pathlib import Path from transformers.utils import is_torch_bf16_gpu_available @@ -16,7 +15,7 @@ from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault -from ..utils import most_recent_subdir +from ..utils import check_model_output_exists, most_recent_subdir LOG = logging.getLogger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" @@ -83,7 +82,7 @@ def test_resume_lora_packed(self, temp_dir): cli_args = TrainerCliArgs() train(cfg=resume_cfg, 
cli_args=cli_args, dataset_meta=dataset_meta) - assert (Path(temp_dir) / "adapter_model.bin").exists() + check_model_output_exists(temp_dir, cfg) tb_log_path_1 = most_recent_subdir(temp_dir + "/runs") cmd = f"tensorboard --inspect --logdir {tb_log_path_1}" diff --git a/tests/e2e/patched/test_unsloth_qlora.py b/tests/e2e/patched/test_unsloth_qlora.py index b58406185a..0c4095ba87 100644 --- a/tests/e2e/patched/test_unsloth_qlora.py +++ b/tests/e2e/patched/test_unsloth_qlora.py @@ -3,7 +3,6 @@ """ import logging import os -from pathlib import Path import pytest @@ -13,7 +12,7 @@ from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault -from ..utils import check_tensorboard +from ..utils import check_model_output_exists, check_tensorboard LOG = logging.getLogger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" @@ -74,7 +73,7 @@ def test_unsloth_llama_qlora_fa2(self, temp_dir, sample_packing): dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) - assert (Path(temp_dir) / "adapter_model.bin").exists() + check_model_output_exists(temp_dir, cfg) check_tensorboard( temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high" @@ -124,7 +123,7 @@ def test_unsloth_llama_qlora_unpacked(self, temp_dir): dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) - assert (Path(temp_dir) / "adapter_model.bin").exists() + check_model_output_exists(temp_dir, cfg) check_tensorboard( temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high" @@ -179,7 +178,7 @@ def test_unsloth_llama_qlora_unpacked_no_fa2_fp16(self, temp_dir, sdp_attention) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) - assert (Path(temp_dir) / "adapter_model.bin").exists() + check_model_output_exists(temp_dir, cfg) check_tensorboard( temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high" diff --git a/tests/e2e/test_embeddings_lr.py b/tests/e2e/test_embeddings_lr.py index 6e5ebd05f7..222d620ae7 100644 --- a/tests/e2e/test_embeddings_lr.py +++ b/tests/e2e/test_embeddings_lr.py @@ -5,7 +5,6 @@ import logging import os import unittest -from pathlib import Path from axolotl.cli import load_datasets from axolotl.common.cli import TrainerCliArgs @@ -13,7 +12,7 @@ from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault -from .utils import check_tensorboard, with_temp_dir +from .utils import check_model_output_exists, check_tensorboard, with_temp_dir LOG = logging.getLogger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" @@ -62,7 +61,7 @@ def test_train_w_embedding_lr_scale(self, temp_dir): dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) - assert (Path(temp_dir) / "model.safetensors").exists() + check_model_output_exists(temp_dir, cfg) check_tensorboard( temp_dir + "/runs", "train/train_loss", 2.0, "Loss is too high" @@ -106,7 +105,7 @@ def test_train_w_embedding_lr(self, temp_dir): dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) - assert (Path(temp_dir) / "model.safetensors").exists() + check_model_output_exists(temp_dir, cfg) check_tensorboard( temp_dir + "/runs", "train/train_loss", 2.0, "Loss is too high" diff --git a/tests/e2e/test_falcon.py b/tests/e2e/test_falcon.py index 
c76699a7c8..99600835f3 100644 --- a/tests/e2e/test_falcon.py +++ b/tests/e2e/test_falcon.py @@ -13,7 +13,7 @@ from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault -from .utils import with_temp_dir +from .utils import check_model_output_exists, with_temp_dir LOG = logging.getLogger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" @@ -71,7 +71,7 @@ def test_lora(self, temp_dir): dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) - assert (Path(temp_dir) / "adapter_model.bin").exists() + check_model_output_exists(temp_dir, cfg) @with_temp_dir def test_lora_added_vocab(self, temp_dir): @@ -124,7 +124,7 @@ def test_lora_added_vocab(self, temp_dir): dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) - assert (Path(temp_dir) / "adapter_model.bin").exists() + check_model_output_exists(temp_dir, cfg) @with_temp_dir def test_ft(self, temp_dir): diff --git a/tests/e2e/test_llama.py b/tests/e2e/test_llama.py index 1ce9d60b98..4384bb61e7 100644 --- a/tests/e2e/test_llama.py +++ b/tests/e2e/test_llama.py @@ -4,7 +4,8 @@ import logging import os -from pathlib import Path + +from e2e.utils import check_model_output_exists from axolotl.cli import load_datasets from axolotl.common.cli import TrainerCliArgs @@ -60,7 +61,7 @@ def test_fft_trust_remote_code(self, temp_dir): dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) - assert (Path(temp_dir) / "model.safetensors").exists() + check_model_output_exists(temp_dir, cfg) def test_fix_untrained_tokens(self, temp_dir): # pylint: disable=duplicate-code @@ -103,7 +104,7 @@ def test_fix_untrained_tokens(self, temp_dir): dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) - assert (Path(temp_dir) / "model.safetensors").exists() + check_model_output_exists(temp_dir, cfg) def test_batch_flattening(self, temp_dir): # pylint: disable=duplicate-code @@ -142,4 +143,4 @@ def test_batch_flattening(self, temp_dir): dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) - assert (Path(temp_dir) / "model.safetensors").exists() + check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/test_llama_pretrain.py b/tests/e2e/test_llama_pretrain.py index 62fb63c471..d13b10659a 100644 --- a/tests/e2e/test_llama_pretrain.py +++ b/tests/e2e/test_llama_pretrain.py @@ -5,7 +5,6 @@ import logging import os import unittest -from pathlib import Path from axolotl.cli import load_datasets from axolotl.common.cli import TrainerCliArgs @@ -13,7 +12,7 @@ from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault -from .utils import with_temp_dir +from .utils import check_model_output_exists, with_temp_dir LOG = logging.getLogger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" @@ -64,4 +63,4 @@ def test_pretrain_w_sample_packing(self, temp_dir): dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) - assert (Path(temp_dir) / "model.safetensors").exists() + check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/test_llama_vision.py b/tests/e2e/test_llama_vision.py index 1d583a3267..250cf418c0 100644 --- a/tests/e2e/test_llama_vision.py +++ b/tests/e2e/test_llama_vision.py @@ -5,7 +5,6 @@ import logging 
import os import unittest -from pathlib import Path from axolotl.cli import load_datasets from axolotl.common.cli import TrainerCliArgs @@ -13,7 +12,7 @@ from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault -from .utils import with_temp_dir +from .utils import check_model_output_exists, with_temp_dir LOG = logging.getLogger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" @@ -68,7 +67,7 @@ def test_lora_llama_vision_text_only_dataset(self, temp_dir): dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) - assert (Path(temp_dir) / "adapter_model.safetensors").exists() + check_model_output_exists(temp_dir, cfg) @with_temp_dir def test_lora_llama_vision_multimodal_dataset(self, temp_dir): @@ -113,4 +112,4 @@ def test_lora_llama_vision_multimodal_dataset(self, temp_dir): dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) - assert (Path(temp_dir) / "adapter_model.safetensors").exists() + check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/test_lora_llama.py b/tests/e2e/test_lora_llama.py index d06be60b96..a7ead64a53 100644 --- a/tests/e2e/test_lora_llama.py +++ b/tests/e2e/test_lora_llama.py @@ -5,7 +5,6 @@ import logging import os import unittest -from pathlib import Path from axolotl.cli import load_datasets from axolotl.common.cli import TrainerCliArgs @@ -13,7 +12,7 @@ from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault -from .utils import with_temp_dir +from .utils import check_model_output_exists, with_temp_dir LOG = logging.getLogger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" @@ -65,4 +64,4 @@ def test_lora(self, temp_dir): dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) - assert (Path(temp_dir) / "adapter_model.bin").exists() + check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/test_mistral.py b/tests/e2e/test_mistral.py index 57d85e51eb..d69d56dc37 100644 --- a/tests/e2e/test_mistral.py +++ b/tests/e2e/test_mistral.py @@ -15,7 +15,7 @@ from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault -from .utils import with_temp_dir +from .utils import check_model_output_exists, with_temp_dir LOG = logging.getLogger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" @@ -69,7 +69,7 @@ def test_lora(self, temp_dir): dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) - assert (Path(temp_dir) / "adapter_model.bin").exists() + check_model_output_exists(temp_dir, cfg) @with_temp_dir def test_ft(self, temp_dir): diff --git a/tests/e2e/test_mixtral.py b/tests/e2e/test_mixtral.py index d4dad14ef2..6792d05a67 100644 --- a/tests/e2e/test_mixtral.py +++ b/tests/e2e/test_mixtral.py @@ -5,7 +5,6 @@ import logging import os import unittest -from pathlib import Path import torch from transformers.utils import is_torch_bf16_gpu_available @@ -16,7 +15,7 @@ from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault -from .utils import with_temp_dir +from .utils import check_model_output_exists, with_temp_dir LOG = logging.getLogger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" @@ -79,7 +78,7 @@ def test_qlora_w_fa2(self, temp_dir): model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype == torch.float32 ) - 
assert (Path(temp_dir) / "adapter_model.bin").exists() + check_model_output_exists(temp_dir, cfg) @with_temp_dir def test_qlora_wo_fa2(self, temp_dir): @@ -133,7 +132,7 @@ def test_qlora_wo_fa2(self, temp_dir): model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype == torch.float32 ) - assert (Path(temp_dir) / "adapter_model.bin").exists() + check_model_output_exists(temp_dir, cfg) @with_temp_dir def test_16bit_lora_w_fa2(self, temp_dir): @@ -190,7 +189,7 @@ def test_16bit_lora_w_fa2(self, temp_dir): model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype == torch.float32 ) - assert (Path(temp_dir) / "adapter_model.bin").exists() + check_model_output_exists(temp_dir, cfg) @with_temp_dir def test_16bit_lora_wo_fa2(self, temp_dir): @@ -247,7 +246,7 @@ def test_16bit_lora_wo_fa2(self, temp_dir): model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype == torch.float32 ) - assert (Path(temp_dir) / "adapter_model.bin").exists() + check_model_output_exists(temp_dir, cfg) @with_temp_dir def test_ft(self, temp_dir): @@ -287,4 +286,4 @@ def test_ft(self, temp_dir): dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) - assert (Path(temp_dir) / "pytorch_model.bin").exists() + check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/test_optimizers.py b/tests/e2e/test_optimizers.py index 2317bfb97a..de6156837c 100644 --- a/tests/e2e/test_optimizers.py +++ b/tests/e2e/test_optimizers.py @@ -5,7 +5,6 @@ import logging import os import unittest -from pathlib import Path from axolotl.cli import load_datasets from axolotl.common.cli import TrainerCliArgs @@ -13,7 +12,7 @@ from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault -from .utils import require_torch_2_5_1, with_temp_dir +from .utils import check_model_output_exists, require_torch_2_5_1, with_temp_dir LOG = logging.getLogger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" @@ -65,7 +64,7 @@ def test_optimi_adamw(self, temp_dir): dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) - assert (Path(temp_dir) / "adapter_model.bin").exists() + check_model_output_exists(temp_dir, cfg) @with_temp_dir @require_torch_2_5_1 @@ -109,7 +108,7 @@ def test_adopt_adamw(self, temp_dir): dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) - assert (Path(temp_dir) / "adapter_model.bin").exists() + check_model_output_exists(temp_dir, cfg) @with_temp_dir def test_fft_schedule_free_adamw(self, temp_dir): @@ -144,4 +143,4 @@ def test_fft_schedule_free_adamw(self, temp_dir): dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) - assert (Path(temp_dir) / "model.safetensors").exists() + check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/test_phi.py b/tests/e2e/test_phi.py index 4cc6bcdcc9..4deea63353 100644 --- a/tests/e2e/test_phi.py +++ b/tests/e2e/test_phi.py @@ -13,7 +13,7 @@ from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault -from .utils import with_temp_dir +from .utils import check_model_output_exists, with_temp_dir LOG = logging.getLogger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" @@ -116,4 +116,4 @@ def test_phi_qlora(self, temp_dir): dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, cli_args=cli_args, 
dataset_meta=dataset_meta) - assert (Path(temp_dir) / "adapter_model.bin").exists() + check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/test_reward_model_llama.py b/tests/e2e/test_reward_model_llama.py index 27ac3e25f1..c4cb705ea8 100644 --- a/tests/e2e/test_reward_model_llama.py +++ b/tests/e2e/test_reward_model_llama.py @@ -5,7 +5,6 @@ import logging import os import unittest -from pathlib import Path from axolotl.cli import load_datasets from axolotl.common.cli import TrainerCliArgs @@ -13,7 +12,7 @@ from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault -from .utils import with_temp_dir +from .utils import check_model_output_exists, with_temp_dir LOG = logging.getLogger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" @@ -71,4 +70,4 @@ def test_rm_fft(self, temp_dir): dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) - assert (Path(temp_dir) / "adapter_model.bin").exists() + check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/utils.py b/tests/e2e/utils.py index de5b599a13..9ec2f7f91f 100644 --- a/tests/e2e/utils.py +++ b/tests/e2e/utils.py @@ -14,6 +14,8 @@ from packaging import version from tbparse import SummaryReader +from axolotl.utils.dict import DictDefault + def with_temp_dir(test_func): @wraps(test_func) @@ -81,3 +83,22 @@ def check_tensorboard( df = reader.scalars # pylint: disable=invalid-name df = df[(df.tag == tag)] # pylint: disable=invalid-name assert df.value.values[-1] < lt_val, assertion_err + + +def check_model_output_exists(temp_dir: str, cfg: DictDefault) -> None: + """ + helper function to check if a model output file exists after training + + checks based on adapter or not and if safetensors saves are enabled or not + """ + + if cfg.save_safetensors: + if not cfg.adapter: + assert (Path(temp_dir) / "model.safetensors").exists() + else: + assert (Path(temp_dir) / "adapter_model.safetensors").exists() + else: + if not cfg.adapter: + assert (Path(temp_dir) / "pytorch_model.bin").exists() + else: + assert (Path(temp_dir) / "adapter_model.bin").exists()
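Usage note: with this helper in place, each e2e test passes its resolved config instead of asserting a hard-coded filename, and the expected artifact follows from two settings in the config: save_safetensors with no adapter -> model.safetensors, save_safetensors with an adapter -> adapter_model.safetensors, no save_safetensors and no adapter -> pytorch_model.bin, and no save_safetensors with an adapter -> adapter_model.bin. A minimal sketch of the call pattern follows; the config values and output directory are illustrative only (not taken from any test in this patch), and it assumes the training step that actually writes the files has already run.

    # hypothetical usage sketch -- values are illustrative, not from a specific test
    from axolotl.utils.dict import DictDefault
    from e2e.utils import check_model_output_exists  # import style used by test_llama.py above

    temp_dir = "/tmp/axolotl_e2e_output"  # stand-in for the @with_temp_dir-provided directory
    cfg = DictDefault({"adapter": "qlora", "save_safetensors": True})

    # after train(cfg=cfg, ...) has written its outputs into temp_dir, this asserts that
    # adapter_model.safetensors exists there (adapter set and safetensors enabled)
    check_model_output_exists(temp_dir, cfg)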