add helper to verify the correct model output file exists #2245

Merged: 6 commits, Jan 13, 2025
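The helper itself is added in tests/e2e/utils.py, whose portion of the diff did not load in this view. A minimal sketch of what such a helper might look like, assuming it derives the expected weights filename from the config's adapter and safetensors settings; the filenames are inferred from the asserts this PR replaces, and the actual implementation may differ:

```python
# Hypothetical sketch only; the real helper lives in tests/e2e/utils.py.
from pathlib import Path


def check_model_output_exists(temp_dir, cfg):
    """Assert that training wrote the expected weights file into temp_dir."""
    out_dir = Path(temp_dir)
    if cfg.adapter in ("lora", "qlora"):
        # LoRA/QLoRA runs save adapter weights rather than full model weights.
        candidates = ("adapter_model.bin", "adapter_model.safetensors")
    elif cfg.save_safetensors:
        candidates = ("model.safetensors",)
    else:
        candidates = ("pytorch_model.bin",)
    assert any(
        (out_dir / name).exists() for name in candidates
    ), f"expected one of {candidates} in {out_dir}"
```

Centralizing the check lets each test call check_model_output_exists(temp_dir, cfg) instead of hard-coding a filename that changes with the adapter type and serialization format.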
1 change: 0 additions & 1 deletion src/axolotl/cli/utils.py
@@ -27,7 +27,6 @@ def decorator(function):
field_type = next(
t for t in get_args(field_type) if not isinstance(t, NoneType)
)

if field_type == bool:
field_name = field.name.replace("_", "-")
option_name = f"--{field_name}/--no-{field_name}"
8 changes: 4 additions & 4 deletions tests/e2e/integrations/test_cut_cross_entropy.py
@@ -2,8 +2,6 @@
Simple end-to-end test for Cut Cross Entropy integration
"""

from pathlib import Path

import pytest

from axolotl.cli import load_datasets
@@ -13,6 +11,8 @@
from axolotl.utils.config import normalize_config, prepare_plugins
from axolotl.utils.dict import DictDefault

from ..utils import check_model_output_exists

# pylint: disable=duplicate-code


@@ -67,7 +67,7 @@ def test_llama_w_cce(self, min_cfg, temp_dir):
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
else:
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()
check_model_output_exists(temp_dir, cfg)

@pytest.mark.parametrize(
"attention_type",
@@ -95,4 +95,4 @@ def test_llama_w_cce_and_attention(self, min_cfg, temp_dir, attention_type):
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
else:
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()
check_model_output_exists(temp_dir, cfg)
7 changes: 4 additions & 3 deletions tests/e2e/integrations/test_liger.py
@@ -1,7 +1,6 @@
"""
Simple end-to-end test for Liger integration
"""
from pathlib import Path

from e2e.utils import require_torch_2_4_1

@@ -11,6 +10,8 @@
from axolotl.utils.config import normalize_config, prepare_plugins
from axolotl.utils.dict import DictDefault

from ..utils import check_model_output_exists


class LigerIntegrationTestCase:
"""
@@ -60,7 +61,7 @@ def test_llama_wo_flce(self, temp_dir):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()
check_model_output_exists(temp_dir, cfg)

@require_torch_2_4_1
def test_llama_w_flce(self, temp_dir):
@@ -105,4 +106,4 @@ def test_llama_w_flce(self, temp_dir):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()
check_model_output_exists(temp_dir, cfg)
7 changes: 3 additions & 4 deletions tests/e2e/patched/test_4d_multipack_llama.py
@@ -5,15 +5,14 @@
import logging
import os
import unittest
from pathlib import Path

from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

from ..utils import require_torch_2_3_1, with_temp_dir
from ..utils import check_model_output_exists, require_torch_2_3_1, with_temp_dir

LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -67,7 +66,7 @@ def test_sdp_lora_packing(self, temp_dir):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
check_model_output_exists(temp_dir, cfg)

@with_temp_dir
def test_torch_lora_packing(self, temp_dir):
@@ -111,4 +110,4 @@ def test_torch_lora_packing(self, temp_dir):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
check_model_output_exists(temp_dir, cfg)
5 changes: 2 additions & 3 deletions tests/e2e/patched/test_fa_xentropy.py
@@ -4,7 +4,6 @@

import logging
import os
from pathlib import Path

import pytest
from transformers.utils import is_torch_bf16_gpu_available
@@ -15,7 +14,7 @@
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

from ..utils import check_tensorboard
from ..utils import check_model_output_exists, check_tensorboard

LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -82,7 +81,7 @@ def test_lora_packing_fa_cross_entropy(self, temp_dir, gradient_accumulation_ste
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
check_model_output_exists(temp_dir, cfg)

check_tensorboard(
temp_dir + "/runs", "train/train_loss", 1.5, "Train Loss is too high"
7 changes: 3 additions & 4 deletions tests/e2e/patched/test_falcon_samplepack.py
@@ -5,15 +5,14 @@
import logging
import os
import unittest
from pathlib import Path

from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

from ..utils import with_temp_dir
from ..utils import check_model_output_exists, with_temp_dir

LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -69,7 +68,7 @@ def test_qlora(self, temp_dir):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
check_model_output_exists(temp_dir, cfg)

@with_temp_dir
def test_ft(self, temp_dir):
@@ -109,4 +108,4 @@ def test_ft(self, temp_dir):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "pytorch_model.bin").exists()
check_model_output_exists(temp_dir, cfg)
5 changes: 2 additions & 3 deletions tests/e2e/patched/test_fused_llama.py
@@ -5,7 +5,6 @@
import logging
import os
import unittest
from pathlib import Path

import pytest
from transformers.utils import is_torch_bf16_gpu_available
@@ -16,7 +15,7 @@
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

from ..utils import with_temp_dir
from ..utils import check_model_output_exists, with_temp_dir

LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -73,4 +72,4 @@ def test_fft_packing(self, temp_dir):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "pytorch_model.bin").exists()
check_model_output_exists(temp_dir, cfg)
7 changes: 3 additions & 4 deletions tests/e2e/patched/test_llama_s2_attention.py
@@ -5,7 +5,6 @@
import logging
import os
import unittest
from pathlib import Path

import pytest

@@ -15,7 +14,7 @@
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

from ..utils import with_temp_dir
from ..utils import check_model_output_exists, with_temp_dir

LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -71,7 +70,7 @@ def test_lora_s2_attn(self, temp_dir):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
check_model_output_exists(temp_dir, cfg)

@with_temp_dir
def test_fft_s2_attn(self, temp_dir):
@@ -111,4 +110,4 @@ def test_fft_s2_attn(self, temp_dir):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "pytorch_model.bin").exists()
check_model_output_exists(temp_dir, cfg)
7 changes: 3 additions & 4 deletions tests/e2e/patched/test_lora_llama_multipack.py
@@ -5,7 +5,6 @@
import logging
import os
import unittest
from pathlib import Path

import pytest
from transformers.utils import is_auto_gptq_available, is_torch_bf16_gpu_available
@@ -16,7 +15,7 @@
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

from ..utils import with_temp_dir
from ..utils import check_model_output_exists, with_temp_dir

LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -76,7 +75,7 @@ def test_lora_packing(self, temp_dir):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
check_model_output_exists(temp_dir, cfg)

@pytest.mark.skipif(not is_auto_gptq_available(), reason="auto-gptq not available")
@with_temp_dir
@@ -126,4 +125,4 @@ def test_lora_gptq_packed(self, temp_dir):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
check_model_output_exists(temp_dir, cfg)
7 changes: 3 additions & 4 deletions tests/e2e/patched/test_mistral_samplepack.py
@@ -5,15 +5,14 @@
import logging
import os
import unittest
from pathlib import Path

from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

from ..utils import with_temp_dir
from ..utils import check_model_output_exists, with_temp_dir

LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -69,7 +68,7 @@ def test_lora_packing(self, temp_dir):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
check_model_output_exists(temp_dir, cfg)

@with_temp_dir
def test_ft_packing(self, temp_dir):
@@ -110,4 +109,4 @@ def test_ft_packing(self, temp_dir):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "pytorch_model.bin").exists()
check_model_output_exists(temp_dir, cfg)
7 changes: 3 additions & 4 deletions tests/e2e/patched/test_mixtral_samplepack.py
@@ -5,15 +5,14 @@
import logging
import os
import unittest
from pathlib import Path

from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

from ..utils import with_temp_dir
from ..utils import check_model_output_exists, with_temp_dir

LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -66,7 +65,7 @@ def test_qlora(self, temp_dir):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
check_model_output_exists(temp_dir, cfg)

@with_temp_dir
def test_ft(self, temp_dir):
Expand Down Expand Up @@ -108,4 +107,4 @@ def test_ft(self, temp_dir):
"MixtralFlashAttention2"
in model.model.layers[0].self_attn.__class__.__name__
)
assert (Path(temp_dir) / "pytorch_model.bin").exists()
check_model_output_exists(temp_dir, cfg)
7 changes: 3 additions & 4 deletions tests/e2e/patched/test_phi_multipack.py
@@ -5,15 +5,14 @@
import logging
import os
import unittest
from pathlib import Path

from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

from ..utils import with_temp_dir
from ..utils import check_model_output_exists, with_temp_dir

LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -69,7 +68,7 @@ def test_ft_packed(self, temp_dir):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "pytorch_model.bin").exists()
check_model_output_exists(temp_dir, cfg)

@with_temp_dir
def test_qlora_packed(self, temp_dir):
@@ -120,4 +119,4 @@ def test_qlora_packed(self, temp_dir):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
check_model_output_exists(temp_dir, cfg)
5 changes: 2 additions & 3 deletions tests/e2e/patched/test_resume.py
@@ -6,7 +6,6 @@
import os
import re
import subprocess
from pathlib import Path

from transformers.utils import is_torch_bf16_gpu_available

@@ -16,7 +15,7 @@
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

from ..utils import most_recent_subdir
from ..utils import check_model_output_exists, most_recent_subdir

LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -83,7 +82,7 @@ def test_resume_lora_packed(self, temp_dir):
cli_args = TrainerCliArgs()

train(cfg=resume_cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
check_model_output_exists(temp_dir, cfg)

tb_log_path_1 = most_recent_subdir(temp_dir + "/runs")
cmd = f"tensorboard --inspect --logdir {tb_log_path_1}"