
Commit

add helper to verify the correct model output file exists
winglian committed Jan 9, 2025
1 parent c1b920f commit 1d9d237
Showing 24 changed files with 89 additions and 85 deletions.
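The helper itself lives in tests/e2e/utils.py, which is not among the hunks shown below. As a rough sketch of what such a helper could look like, assuming it branches on the config's adapter and save_safetensors settings (the file names and branching logic here are assumptions, not taken from this commit):

```python
# Hypothetical sketch only -- the real check_model_output_exists in
# tests/e2e/utils.py is not shown in the hunks below, so the branching
# logic and candidate file names here are assumptions.
from pathlib import Path


def check_model_output_exists(temp_dir, cfg):
    """Assert that training wrote the expected model artifact to temp_dir."""
    if cfg.adapter in ("lora", "qlora"):
        # PEFT runs save adapter weights rather than full model weights.
        candidates = ("adapter_model.bin", "adapter_model.safetensors")
    elif cfg.save_safetensors:
        candidates = ("model.safetensors",)
    else:
        candidates = ("pytorch_model.bin",)
    assert any(
        (Path(temp_dir) / name).exists() for name in candidates
    ), f"expected one of {candidates} in {temp_dir}"
```

With this in place, each test calls check_model_output_exists(temp_dir, cfg) after train(), as the diffs below show, instead of asserting a hard-coded artifact name per test.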
7 changes: 3 additions & 4 deletions tests/e2e/integrations/liger.py
@@ -2,15 +2,14 @@
Simple end-to-end test for Liger integration
"""
import unittest
from pathlib import Path

from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins
from axolotl.utils.dict import DictDefault

from ..utils import with_temp_dir
from ..utils import check_model_output_exists, with_temp_dir


class LigerIntegrationTestCase(unittest.TestCase):
@@ -63,7 +62,7 @@ def test_llama_wo_flce(self, temp_dir):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()
check_model_output_exists(temp_dir, cfg)

@with_temp_dir
def test_llama_w_flce(self, temp_dir):
@@ -110,4 +109,4 @@ def test_llama_w_flce(self, temp_dir):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()
check_model_output_exists(temp_dir, cfg)
7 changes: 3 additions & 4 deletions tests/e2e/integrations/test_cut_cross_entropy.py
@@ -2,9 +2,8 @@
Simple end-to-end test for Cut Cross Entropy integration
"""

from pathlib import Path

import pytest
from e2e.utils import check_model_output_exists

from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
@@ -67,7 +66,7 @@ def test_llama_w_cce(self, min_cfg, temp_dir):
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
else:
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()
check_model_output_exists(temp_dir, cfg)

@pytest.mark.parametrize(
"attention_type",
@@ -95,4 +94,4 @@ def test_llama_w_cce_and_attention(self, min_cfg, temp_dir, attention_type):
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
else:
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()
check_model_output_exists(temp_dir, cfg)
7 changes: 3 additions & 4 deletions tests/e2e/patched/test_4d_multipack_llama.py
@@ -5,15 +5,14 @@
import logging
import os
import unittest
from pathlib import Path

from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

from ..utils import require_torch_2_3_1, with_temp_dir
from ..utils import check_model_output_exists, require_torch_2_3_1, with_temp_dir

LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -67,7 +66,7 @@ def test_sdp_lora_packing(self, temp_dir):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
check_model_output_exists(temp_dir, cfg)

@with_temp_dir
def test_torch_lora_packing(self, temp_dir):
@@ -111,4 +110,4 @@ def test_torch_lora_packing(self, temp_dir):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
check_model_output_exists(temp_dir, cfg)
5 changes: 2 additions & 3 deletions tests/e2e/patched/test_fa_xentropy.py
@@ -4,7 +4,6 @@

import logging
import os
from pathlib import Path

import pytest
from transformers.utils import is_torch_bf16_gpu_available
@@ -15,7 +14,7 @@
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

from ..utils import check_tensorboard
from ..utils import check_model_output_exists, check_tensorboard

LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -82,7 +81,7 @@ def test_lora_packing_fa_cross_entropy(self, temp_dir, gradient_accumulation_ste
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
check_model_output_exists(temp_dir, cfg)

check_tensorboard(
temp_dir + "/runs", "train/train_loss", 1.5, "Train Loss is too high"
7 changes: 3 additions & 4 deletions tests/e2e/patched/test_falcon_samplepack.py
@@ -5,15 +5,14 @@
import logging
import os
import unittest
from pathlib import Path

from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

from ..utils import with_temp_dir
from ..utils import check_model_output_exists, with_temp_dir

LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -69,7 +68,7 @@ def test_qlora(self, temp_dir):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
check_model_output_exists(temp_dir, cfg)

@with_temp_dir
def test_ft(self, temp_dir):
@@ -109,4 +108,4 @@ def test_ft(self, temp_dir):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "pytorch_model.bin").exists()
check_model_output_exists(temp_dir, cfg)
7 changes: 3 additions & 4 deletions tests/e2e/patched/test_llama_s2_attention.py
@@ -5,7 +5,6 @@
import logging
import os
import unittest
from pathlib import Path

import pytest

@@ -15,7 +14,7 @@
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

from ..utils import with_temp_dir
from ..utils import check_model_output_exists, with_temp_dir

LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -71,7 +70,7 @@ def test_lora_s2_attn(self, temp_dir):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
check_model_output_exists(temp_dir, cfg)

@with_temp_dir
def test_fft_s2_attn(self, temp_dir):
@@ -111,4 +110,4 @@ def test_fft_s2_attn(self, temp_dir):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "pytorch_model.bin").exists()
check_model_output_exists(temp_dir, cfg)
7 changes: 3 additions & 4 deletions tests/e2e/patched/test_lora_llama_multipack.py
@@ -5,7 +5,6 @@
import logging
import os
import unittest
from pathlib import Path

import pytest
from transformers.utils import is_auto_gptq_available, is_torch_bf16_gpu_available
@@ -16,7 +15,7 @@
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

from ..utils import with_temp_dir
from ..utils import check_model_output_exists, with_temp_dir

LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -76,7 +75,7 @@ def test_lora_packing(self, temp_dir):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
check_model_output_exists(temp_dir, cfg)

@pytest.mark.skipif(not is_auto_gptq_available(), reason="auto-gptq not available")
@with_temp_dir
@@ -126,4 +125,4 @@ def test_lora_gptq_packed(self, temp_dir):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
check_model_output_exists(temp_dir, cfg)
7 changes: 3 additions & 4 deletions tests/e2e/patched/test_mistral_samplepack.py
@@ -5,15 +5,14 @@
import logging
import os
import unittest
from pathlib import Path

from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

from ..utils import with_temp_dir
from ..utils import check_model_output_exists, with_temp_dir

LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -69,7 +68,7 @@ def test_lora_packing(self, temp_dir):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
check_model_output_exists(temp_dir, cfg)

@with_temp_dir
def test_ft_packing(self, temp_dir):
@@ -110,4 +109,4 @@ def test_ft_packing(self, temp_dir):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "pytorch_model.bin").exists()
check_model_output_exists(temp_dir, cfg)
7 changes: 3 additions & 4 deletions tests/e2e/patched/test_mixtral_samplepack.py
@@ -5,15 +5,14 @@
import logging
import os
import unittest
from pathlib import Path

from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

from ..utils import with_temp_dir
from ..utils import check_model_output_exists, with_temp_dir

LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -66,7 +65,7 @@ def test_qlora(self, temp_dir):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
check_model_output_exists(temp_dir, cfg)

@with_temp_dir
def test_ft(self, temp_dir):
@@ -108,4 +107,4 @@ def test_ft(self, temp_dir):
"MixtralFlashAttention2"
in model.model.layers[0].self_attn.__class__.__name__
)
assert (Path(temp_dir) / "pytorch_model.bin").exists()
check_model_output_exists(temp_dir, cfg)
4 changes: 2 additions & 2 deletions tests/e2e/patched/test_phi_multipack.py
@@ -13,7 +13,7 @@
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

from ..utils import with_temp_dir
from ..utils import check_model_output_exists, with_temp_dir

LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -120,4 +120,4 @@ def test_qlora_packed(self, temp_dir):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
check_model_output_exists(temp_dir, cfg)
5 changes: 2 additions & 3 deletions tests/e2e/patched/test_resume.py
@@ -6,7 +6,6 @@
import os
import re
import subprocess
from pathlib import Path

from transformers.utils import is_torch_bf16_gpu_available

@@ -16,7 +15,7 @@
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

from ..utils import most_recent_subdir
from ..utils import check_model_output_exists, most_recent_subdir

LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -83,7 +82,7 @@ def test_resume_lora_packed(self, temp_dir):
cli_args = TrainerCliArgs()

train(cfg=resume_cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
check_model_output_exists(temp_dir, cfg)

tb_log_path_1 = most_recent_subdir(temp_dir + "/runs")
cmd = f"tensorboard --inspect --logdir {tb_log_path_1}"
9 changes: 4 additions & 5 deletions tests/e2e/patched/test_unsloth_qlora.py
@@ -3,7 +3,6 @@
"""
import logging
import os
from pathlib import Path

import pytest

@@ -13,7 +12,7 @@
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

from ..utils import check_tensorboard
from ..utils import check_model_output_exists, check_tensorboard

LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -74,7 +73,7 @@ def test_unsloth_llama_qlora_fa2(self, temp_dir, sample_packing):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
check_model_output_exists(temp_dir, cfg)

check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
@@ -124,7 +123,7 @@ def test_unsloth_llama_qlora_unpacked(self, temp_dir):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
check_model_output_exists(temp_dir, cfg)

check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
@@ -179,7 +178,7 @@ def test_unsloth_llama_qlora_unpacked_no_fa2_fp16(self, temp_dir, sdp_attention)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
check_model_output_exists(temp_dir, cfg)

check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
7 changes: 3 additions & 4 deletions tests/e2e/test_embeddings_lr.py
@@ -5,15 +5,14 @@
import logging
import os
import unittest
from pathlib import Path

from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

from .utils import check_tensorboard, with_temp_dir
from .utils import check_model_output_exists, check_tensorboard, with_temp_dir

LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -62,7 +61,7 @@ def test_train_w_embedding_lr_scale(self, temp_dir):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()
check_model_output_exists(temp_dir, cfg)

check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Loss is too high"
@@ -106,7 +105,7 @@ def test_train_w_embedding_lr(self, temp_dir):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()
check_model_output_exists(temp_dir, cfg)

check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Loss is too high"
