From 0c8b1d824aeaa41f27b4ed4e73e8906751180322 Mon Sep 17 00:00:00 2001
From: Chirag Jain
Date: Sat, 16 Nov 2024 07:05:50 +0530
Subject: [PATCH] Update `get_unpad_data` patching for multipack (#2013)

* Update `get_unpad_data` patching for multipack

* Update src/axolotl/utils/models.py

* Update src/axolotl/utils/models.py

* Add test case

---------

Co-authored-by: Wing Lian
Co-authored-by: Wing Lian
---
 src/axolotl/monkeypatch/multipack.py | 72 ++++++----------------------
 src/axolotl/utils/models.py          |  9 +++-
 tests/e2e/test_llama.py              | 66 +++++++++++++++++++++++++
 3 files changed, 89 insertions(+), 58 deletions(-)
 create mode 100644 tests/e2e/test_llama.py

diff --git a/src/axolotl/monkeypatch/multipack.py b/src/axolotl/monkeypatch/multipack.py
index 85101cd3c4..3ee89d2e5c 100644
--- a/src/axolotl/monkeypatch/multipack.py
+++ b/src/axolotl/monkeypatch/multipack.py
@@ -1,4 +1,5 @@
 """multipack patching for v2 of sample packing"""
+
 import importlib
 
 import transformers
@@ -27,71 +28,28 @@
 ]
 
 
-def patch_for_multipack(model_type, model_name=None, is_remote_code=False):
-    if model_type == "gemmoe":
-        patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe")
-    elif model_type == "deepseek_v2":
-        patch_remote(model_name, ".configuration_deepseek", ".modeling_deepseek")
-    elif hasattr(transformers, "modeling_flash_attention_utils") and not is_remote_code:
+def patch_for_multipack(model_type, model_name=None, has_remote_code=False):
+    if has_remote_code:
+        patch_remote(model_name)
+    elif hasattr(transformers, "modeling_flash_attention_utils"):
         transformers.modeling_flash_attention_utils._get_unpad_data = (  # pylint: disable=protected-access
             get_unpad_data
         )
-        if model_type == "mixtral" and is_deepspeed_zero3_enabled():
-            patch_mixtral_moe_forward_zero3()
-        return
-
-    # retain for legacy
-    if model_type == "mixtral":
-        transformers.models.mixtral.modeling_mixtral._get_unpad_data = (  # pylint: disable=protected-access
-            get_unpad_data
-        )
-        if is_deepspeed_zero3_enabled():
-            patch_mixtral_moe_forward_zero3()
-    elif model_type == "llama":
-        if hasattr(transformers.models.llama.modeling_llama, "_get_unpad_data"):
-            transformers.models.llama.modeling_llama._get_unpad_data = (  # pylint: disable=protected-access
-                get_unpad_data
-            )
-    elif model_type == "mistral":
-        if hasattr(transformers.models.mistral.modeling_mistral, "_get_unpad_data"):
-            transformers.models.llama.modeling_llama._get_unpad_data = (  # pylint: disable=protected-access
-                get_unpad_data
-            )
-    elif model_type == "qwen2":
-        transformers.models.qwen2.modeling_qwen2._get_unpad_data = (  # pylint: disable=protected-access
-            get_unpad_data
-        )
-    elif model_type == "qwen2_moe":
-        transformers.models.qwen2_moe.modeling_qwen2_moe._get_unpad_data = (  # pylint: disable=protected-access
-            get_unpad_data
-        )
-    elif model_type == "falcon":
-        transformers.models.falcon.modeling_falcon._get_unpad_data = (  # pylint: disable=protected-access
-            get_unpad_data
-        )
-    elif model_type == "phi":
-        transformers.models.phi.modeling_phi._get_unpad_data = (  # pylint: disable=protected-access
-            get_unpad_data
-        )
-    elif model_type == "gemma":
-        transformers.models.gemma.modeling_gemma._get_unpad_data = (  # pylint: disable=protected-access
-            get_unpad_data
-        )
-    elif model_type == "gemma2":
-        transformers.models.gemma2.modeling_gemma2._get_unpad_data = (  # pylint: disable=protected-access
-            get_unpad_data
-        )
-    elif model_type == "starcoder2":
-        transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = (  # pylint: disable=protected-access
-            get_unpad_data
-        )
+
+    if model_type == "mixtral" and is_deepspeed_zero3_enabled():
+        patch_mixtral_moe_forward_zero3()
 
 
-def patch_remote(model_name, config_name, modeling_name):
+def patch_remote(model_name):
     model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
     # we need to load the model here in order for modeling_* to be available
     with init_empty_weights():
         AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
-    module_name = model_config.__class__.__module__.replace(config_name, modeling_name)
+    parts = model_config.__class__.__module__.split(".")
+    parts[-1] = parts[-1].replace("configuration_", "modeling_", 1)
+    module_name = ".".join(parts)
     modeling_arch = importlib.import_module(module_name)
-    modeling_arch._get_unpad_data = get_unpad_data  # pylint: disable=protected-access
+    if hasattr(modeling_arch, "_get_unpad_data"):
+        modeling_arch._get_unpad_data = (  # pylint: disable=protected-access
+            get_unpad_data
+        )
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index db66c65f25..75c93fa2a5 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -395,10 +395,17 @@ def apply_patches(self) -> None:
             and self.cfg.flash_attention
             and self.cfg.sample_packing
         ):
+            has_remote_code = (
+                "auto_map" in self.model_config
+                and "AutoModelForCausalLM" in self.model_config["auto_map"]
+            )
+            if has_remote_code and self.cfg.trust_remote_code is False:
+                # if explicitly set in the YAML, we should prefer that, for example if explicitly disabled
+                has_remote_code = self.cfg.trust_remote_code
             patch_for_multipack(
                 self.cfg.model_config_type,
                 model_name=self.cfg.base_model,
-                is_remote_code=self.cfg.trust_remote_code,
+                has_remote_code=has_remote_code,
             )
 
         if self.cfg.is_llama_derived_model:
diff --git a/tests/e2e/test_llama.py b/tests/e2e/test_llama.py
new file mode 100644
index 0000000000..4e885a76db
--- /dev/null
+++ b/tests/e2e/test_llama.py
@@ -0,0 +1,66 @@
+"""
+E2E tests for llama
+"""
+
+import logging
+import os
+import unittest
+from pathlib import Path
+
+from axolotl.cli import load_datasets
+from axolotl.common.cli import TrainerCliArgs
+from axolotl.train import train
+from axolotl.utils.config import normalize_config
+from axolotl.utils.dict import DictDefault
+
+from .utils import with_temp_dir
+
+LOG = logging.getLogger("axolotl.tests.e2e")
+os.environ["WANDB_DISABLED"] = "true"
+
+
+class TestLlama(unittest.TestCase):
+    """
+    Test case for Llama models
+    """
+
+    @with_temp_dir
+    def test_fft_trust_remote_code(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "JackFram/llama-68m",
+                "tokenizer_type": "LlamaTokenizer",
+                "trust_remote_code": True,
+                "sequence_len": 512,
+                "val_set_size": 0.1,
+                "special_tokens": {
+                    "unk_token": "<unk>",
+                    "bos_token": "<s>",
+                    "eos_token": "</s>",
+                },
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 1,
+                "micro_batch_size": 8,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_bnb_8bit",
+                "lr_scheduler": "cosine",
+                "flash_attention": True,
+                "sample_packing": True,
+                "bf16": True,
+                "save_safetensors": True,
+            }
+        )
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "model.safetensors").exists()
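
A note on the module-name derivation in the new `patch_remote`: it no longer takes explicit `config_name`/`modeling_name` arguments, and instead derives the modeling module by renaming only the last component of the configuration class's module path. Below is a minimal sketch of that derivation; the module path shown is hypothetical, since real trust-remote-code models are loaded under a dynamic `transformers_modules.*` namespace whose path depends on the hub repo:

    # Hypothetical module path for a trust-remote-code model; transformers
    # places such models under a dynamic "transformers_modules.*" namespace.
    config_module = "transformers_modules.example_org.example_model.configuration_example"
    parts = config_module.split(".")
    # Rename only the final component, so a "configuration_" substring that
    # happens to appear earlier in the path (e.g. in a repo name) is untouched.
    parts[-1] = parts[-1].replace("configuration_", "modeling_", 1)
    module_name = ".".join(parts)
    assert module_name == "transformers_modules.example_org.example_model.modeling_example"

Splitting on "." before the rename is what makes this safe for arbitrary repo names, where the old whole-string `replace(config_name, modeling_name)` could match in the wrong place. The `hasattr` guard added in the patch then makes the monkeypatch a no-op for remote architectures that never define `_get_unpad_data`.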