diff --git a/examples/hymba/fft-1.5b.yml b/examples/hymba/fft-1.5b.yml new file mode 100644 index 0000000000..e11a08ae66 --- /dev/null +++ b/examples/hymba/fft-1.5b.yml @@ -0,0 +1,58 @@ +base_model: nvidia/Hymba-1.5B-Base + +load_in_8bit: false +load_in_4bit: false +strict: false + +datasets: + - path: tatsu-lab/alpaca + type: alpaca +dataset_prepared_path: last_run_prepared +val_set_size: 0.05 +output_dir: ./outputs/out + +sequence_len: 2048 +sample_packing: true +pad_to_sequence_len: true + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 2 +micro_batch_size: 2 +num_epochs: 1 +optimizer: paged_adamw_8bit +lr_scheduler: cosine +learning_rate: 2e-5 + +train_on_inputs: false +group_by_length: false +bf16: auto +fp16: +tf32: false + +trust_remote_code: true + +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: false +early_stopping_patience: +resume_from_checkpoint: +logging_steps: 1 +xformers_attention: +flash_attention: true + +warmup_steps: 5 +evals_per_epoch: 2 +eval_table_size: +saves_per_epoch: 1 +debug: +deepspeed: +weight_decay: 0.0 +fsdp: +fsdp_config: +special_tokens: + pad_token: <|end_of_text|> diff --git a/examples/hymba/qlora-1.5b.yml b/examples/hymba/qlora-1.5b.yml new file mode 100644 index 0000000000..472f8706fb --- /dev/null +++ b/examples/hymba/qlora-1.5b.yml @@ -0,0 +1,73 @@ +base_model: nvidia/Hymba-1.5B-Base + +load_in_8bit: false +load_in_4bit: True +strict: false + +datasets: + - path: tatsu-lab/alpaca + type: alpaca +dataset_prepared_path: last_run_prepared +val_set_size: 0.05 +output_dir: ./outputs/out + +sequence_len: 2048 +sample_packing: true +pad_to_sequence_len: true + +adapter: qlora +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_linear: true +lora_fan_in_fan_out: +lora_target_modules: + - gate_proj + - down_proj + - up_proj + - q_proj + - v_proj + - k_proj + - o_proj + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 2 +micro_batch_size: 2 +num_epochs: 1 +optimizer: paged_adamw_8bit +lr_scheduler: cosine +learning_rate: 2e-5 + +train_on_inputs: false +group_by_length: false +bf16: auto +fp16: +tf32: false + +trust_remote_code: true + +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: false +early_stopping_patience: +resume_from_checkpoint: +logging_steps: 1 +xformers_attention: +flash_attention: true + +warmup_steps: 5 +evals_per_epoch: 2 +eval_table_size: +saves_per_epoch: 1 +debug: +deepspeed: +weight_decay: 0.0 +fsdp: +fsdp_config: +special_tokens: + pad_token: <|end_of_text|> diff --git a/src/axolotl/monkeypatch/multipack.py b/src/axolotl/monkeypatch/multipack.py index 3ee89d2e5c..66e615516c 100644 --- a/src/axolotl/monkeypatch/multipack.py +++ b/src/axolotl/monkeypatch/multipack.py @@ -25,6 +25,7 @@ "gemmoe", "starcoder2", "deepseek_v2", + "hymba", ] diff --git a/src/axolotl/train.py b/src/axolotl/train.py index c8576f1b48..6ca5edf985 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -23,7 +23,12 @@ from axolotl.logging_config import configure_logging from axolotl.utils.dict import DictDefault from axolotl.utils.freeze import freeze_layers_except -from axolotl.utils.models import load_model, load_processor, load_tokenizer +from axolotl.utils.models import ( + load_model, + load_model_config, + load_processor, + load_tokenizer, +) from axolotl.utils.trainer import setup_trainer try: @@ -145,7 +150,11 @@ def train( os.makedirs(cfg.output_dir, exist_ok=True) tokenizer.save_pretrained(str(Path(cfg.output_dir))) if hasattr(model, "config"): - model.config.save_pretrained(str(Path(cfg.output_dir))) + try: + model.config.save_pretrained(str(Path(cfg.output_dir))) + except TypeError: # required to deal with Hymba in its current state + model_config = load_model_config(cfg) + model_config.save_pretrained(str(Path(cfg.output_dir))) # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model if cfg.local_rank == 0: diff --git a/src/axolotl/utils/chat_templates.py b/src/axolotl/utils/chat_templates.py index ffe5e24853..ac923164b5 100644 --- a/src/axolotl/utils/chat_templates.py +++ b/src/axolotl/utils/chat_templates.py @@ -31,6 +31,7 @@ "qwen_25": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.' }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within XML tags:\\n\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n\\n\\nFor each function call, return a json object with function name and arguments within XML tags:\\n\\n{\\\"name\\\": , \\\"arguments\\\": }\\n<|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n\\n' }}\n {{- message.content }}\n {{- '\\n' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n", "exaone": "{% for message in messages %}{% if loop.first and message['role'] != 'system' %}{{ '[|system|][|endofturn|]\n' }}{% endif %}{{ '[|' + message['role'] + '|]' + message['content'] }}{% if message['role'] == 'user' %}{{ '\n' }}{% else %}{{ '[|endofturn|]\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '[|assistant|]' }}{% endif %}", "metharme": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = 'Enter RP mode. You shall reply to the user while staying in character. Your responses must be detailed, creative, immersive, and drive the scenario forward.' %}{% endif %}{{ '<|system|>' + system_message }}{% for message in loop_messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|user|>' + content.strip() }}{% elif message['role'] == 'assistant' %}{{ '<|model|>' + content.strip() }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|model|>' }}{% else %}{{ eos_token }}{% endif %}", + "hymba": "{{'System'}}{% for message in messages %}{% if message['role'] == 'system' %}{{'\n' + message['content'].strip()}}{% if tools or contexts %}{{'\n'}}{% endif %}{% endif %}{% endfor %}{% if tools %}{% for tool in tools %}{{ '\n ' + tool|tojson + ' ' }}{% endfor %}{% endif %}{% if contexts %}{% if tools %}{{'\n'}}{% endif %}{% for context in contexts %}{{ '\n ' + context.strip() + ' ' }}{% endfor %}{% endif %}{{'\n\n'}}{% for message in messages %}{% if message['role'] == 'user' %}{{ 'User\n' + message['content'].strip() + '\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant\n' + message['content'].strip() + '\n' }}{% elif message['role'] == 'tool' %}{{ 'Tool\n' + message['content'].strip() + '\n' }}{% endif %}{% endfor %}{%- if add_generation_prompt %}{{'Assistant\n'}}{%- endif %}", } diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 3671e1bb93..1826166864 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -60,6 +60,7 @@ class ChatTemplate(str, Enum): tokenizer_default = "tokenizer_default" # pylint: disable=invalid-name exaone = "exaone" # pylint: disable=invalid-name metharme = "metharme" # pylint: disable=invalid-name + hymba = "hymba" # pylint: disable=invalid-name class DeprecatedParameters(BaseModel): @@ -1581,3 +1582,19 @@ def check_adopt_torch_version(cls, data): "ADOPT optimizer is incompatible with torch version < 2.5.1" ) return data + + @model_validator(mode="before") + @classmethod + def check_hymba_torch_version(cls, data): + if "hymba" in data.get("base_model", {}).lower(): + env_capabilities = data.get("env_capabilities", {}) + torch_version = env_capabilities.get("torch_version") + + if torch_version is None: + import torch + + torch_version = str(torch.__version__).split("+", maxsplit=1)[0] + + if version.parse(torch_version) < version.parse("2.5.0"): + raise ValueError("Hymba requires torch version >= 2.5") + return data diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index a350f24295..8c8f5eaaed 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -420,6 +420,7 @@ def apply_patches(self) -> None: and self.cfg.sample_packing ): if "auto_map" in self.model_config: + # some model config objects are not subscriptable try: auto_map_config = self.model_config["auto_map"] except TypeError: @@ -427,6 +428,7 @@ def apply_patches(self) -> None: has_remote_code = "AutoModelForCausalLM" in auto_map_config else: has_remote_code = False + if has_remote_code and self.cfg.trust_remote_code is False: # if explicitly set in the YAML, we should prefer that, for example if explicitly disabled has_remote_code = self.cfg.trust_remote_code @@ -1155,7 +1157,7 @@ def load_model(self) -> Tuple[PreTrainedModel, Optional[PeftConfig]]: and not skip_move_to_device ): # TODO revaldate this conditional - self.model.to(f"{str(get_device_type())}:{self.cfg.local_rank}") + self.model.to(f"{str(get_device_type())}: {self.cfg.local_rank}") if get_device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1: setattr(self.model, "is_parallelizable", True) diff --git a/tests/e2e/test_optimizers.py b/tests/e2e/test_optimizers.py index 2317bfb97a..58e2493c25 100644 --- a/tests/e2e/test_optimizers.py +++ b/tests/e2e/test_optimizers.py @@ -67,8 +67,8 @@ def test_optimi_adamw(self, temp_dir): train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) assert (Path(temp_dir) / "adapter_model.bin").exists() - @with_temp_dir @require_torch_2_5_1 + @with_temp_dir def test_adopt_adamw(self, temp_dir): # pylint: disable=duplicate-code cfg = DictDefault( diff --git a/tests/e2e/test_packing_loss.py b/tests/e2e/test_packing_loss.py index dd0af32f3c..43f623ca6c 100644 --- a/tests/e2e/test_packing_loss.py +++ b/tests/e2e/test_packing_loss.py @@ -14,7 +14,7 @@ from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault -from .utils import check_tensorboard, with_temp_dir +from .utils import check_tensorboard, require_torch_2_5_1, with_temp_dir LOG = logging.getLogger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" @@ -68,3 +68,129 @@ def test_loss_packed(self, temp_dir): check_tensorboard( temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high" ) + + +class TestPackedHymba(unittest.TestCase): + """ + Test case for Packed training of hymba models + """ + + @require_torch_2_5_1 + @with_temp_dir + def test_loss_packed(self, temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "nvidia/Hymba-1.5B-Base", + "trust_remote_code": True, + "load_in_4bit": True, + "adapter": "qlora", + "lora_r": 32, + "lora_alpha": 16, + "lora_dropout": 0.05, + "lora_target_modules": [ + "gate_proj", + "down_proj", + "up_proj", + "q_proj", + "v_proj", + "k_proj", + "o_proj", + ], + "sequence_len": 1024, + "sample_packing": True, + "flash_attention": True, + "val_set_size": 0.0, + "datasets": [ + { + "path": "vicgalle/alpaca-gpt4", + "type": "alpaca", + }, + ], + "num_epochs": 1, + "micro_batch_size": 2, + "gradient_accumulation_steps": 4, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_torch", + "lr_scheduler": "cosine", + "max_steps": 5, + "use_tensorboard": True, + } + ) + if is_torch_bf16_gpu_available(): + cfg.bf16 = True + else: + cfg.fp16 = True + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + + train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + + check_tensorboard( + temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high" + ) + + +class TestUnpackedHymba(unittest.TestCase): + """ + Test case for Unpacked training of hymba models + """ + + @require_torch_2_5_1 + @with_temp_dir + def test_loss_unpacked(self, temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "nvidia/Hymba-1.5B-Base", + "trust_remote_code": True, + "load_in_4bit": True, + "adapter": "qlora", + "lora_r": 32, + "lora_alpha": 16, + "lora_dropout": 0.05, + "lora_target_modules": [ + "gate_proj", + "down_proj", + "up_proj", + "q_proj", + "v_proj", + "k_proj", + "o_proj", + ], + "sequence_len": 1024, + "sample_packing": False, + "flash_attention": True, + "val_set_size": 0.0, + "datasets": [ + { + "path": "vicgalle/alpaca-gpt4", + "type": "alpaca", + }, + ], + "num_epochs": 1, + "micro_batch_size": 2, + "gradient_accumulation_steps": 4, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_torch", + "lr_scheduler": "cosine", + "max_steps": 5, + "use_tensorboard": True, + } + ) + if is_torch_bf16_gpu_available(): + cfg.bf16 = True + else: + cfg.fp16 = True + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + + train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + + check_tensorboard( + temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high" + )