diff --git a/examples/hymba/fft-1.5b.yml b/examples/hymba/fft-1.5b.yml
new file mode 100644
index 0000000000..e11a08ae66
--- /dev/null
+++ b/examples/hymba/fft-1.5b.yml
@@ -0,0 +1,58 @@
+base_model: nvidia/Hymba-1.5B-Base
+
+load_in_8bit: false
+load_in_4bit: false
+strict: false
+
+datasets:
+ - path: tatsu-lab/alpaca
+ type: alpaca
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.05
+output_dir: ./outputs/out
+
+sequence_len: 2048
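+# sample packing concatenates multiple short examples into each 2048-token sequence to improve throughput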
+sample_packing: true
+pad_to_sequence_len: true
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 2
+micro_batch_size: 2
+num_epochs: 1
+optimizer: paged_adamw_8bit
+lr_scheduler: cosine
+learning_rate: 2e-5
+
+train_on_inputs: false
+group_by_length: false
+bf16: auto
+fp16:
+tf32: false
+
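+# Hymba's modeling code lives in the Hub repo, so remote code must be trusted to load it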
+trust_remote_code: true
+
+gradient_checkpointing: true
+gradient_checkpointing_kwargs:
+ use_reentrant: false
+early_stopping_patience:
+resume_from_checkpoint:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+
+warmup_steps: 5
+evals_per_epoch: 2
+eval_table_size:
+saves_per_epoch: 1
+debug:
+deepspeed:
+weight_decay: 0.0
+fsdp:
+fsdp_config:
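+# the base tokenizer does not ship a pad token, so padding is mapped to <|end_of_text|>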
+special_tokens:
+ pad_token: <|end_of_text|>
diff --git a/examples/hymba/qlora-1.5b.yml b/examples/hymba/qlora-1.5b.yml
new file mode 100644
index 0000000000..472f8706fb
--- /dev/null
+++ b/examples/hymba/qlora-1.5b.yml
@@ -0,0 +1,73 @@
+base_model: nvidia/Hymba-1.5B-Base
+
+load_in_8bit: false
+load_in_4bit: true
+strict: false
+
+datasets:
+ - path: tatsu-lab/alpaca
+ type: alpaca
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.05
+output_dir: ./outputs/out
+
+sequence_len: 2048
+sample_packing: true
+pad_to_sequence_len: true
+
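+# QLoRA: the base model is loaded in 4-bit and LoRA adapters are trained on the linear projections listed below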
+adapter: qlora
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_linear: true
+lora_fan_in_fan_out:
+lora_target_modules:
+ - gate_proj
+ - down_proj
+ - up_proj
+ - q_proj
+ - v_proj
+ - k_proj
+ - o_proj
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 2
+micro_batch_size: 2
+num_epochs: 1
+optimizer: paged_adamw_8bit
+lr_scheduler: cosine
+learning_rate: 2e-5
+
+train_on_inputs: false
+group_by_length: false
+bf16: auto
+fp16:
+tf32: false
+
+trust_remote_code: true
+
+gradient_checkpointing: true
+gradient_checkpointing_kwargs:
+ use_reentrant: false
+early_stopping_patience:
+resume_from_checkpoint:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+
+warmup_steps: 5
+evals_per_epoch: 2
+eval_table_size:
+saves_per_epoch: 1
+debug:
+deepspeed:
+weight_decay: 0.0
+fsdp:
+fsdp_config:
+special_tokens:
+ pad_token: <|end_of_text|>
diff --git a/src/axolotl/monkeypatch/multipack.py b/src/axolotl/monkeypatch/multipack.py
index 3ee89d2e5c..66e615516c 100644
--- a/src/axolotl/monkeypatch/multipack.py
+++ b/src/axolotl/monkeypatch/multipack.py
@@ -25,6 +25,7 @@
"gemmoe",
"starcoder2",
"deepseek_v2",
+ "hymba",
]
diff --git a/src/axolotl/train.py b/src/axolotl/train.py
index c8576f1b48..6ca5edf985 100644
--- a/src/axolotl/train.py
+++ b/src/axolotl/train.py
@@ -23,7 +23,12 @@
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.freeze import freeze_layers_except
-from axolotl.utils.models import load_model, load_processor, load_tokenizer
+from axolotl.utils.models import (
+ load_model,
+ load_model_config,
+ load_processor,
+ load_tokenizer,
+)
from axolotl.utils.trainer import setup_trainer
try:
@@ -145,7 +150,11 @@ def train(
os.makedirs(cfg.output_dir, exist_ok=True)
tokenizer.save_pretrained(str(Path(cfg.output_dir)))
if hasattr(model, "config"):
- model.config.save_pretrained(str(Path(cfg.output_dir)))
+ try:
+ model.config.save_pretrained(str(Path(cfg.output_dir)))
+ except TypeError: # required to deal with Hymba in its current state
+ model_config = load_model_config(cfg)
+ model_config.save_pretrained(str(Path(cfg.output_dir)))
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
if cfg.local_rank == 0:
diff --git a/src/axolotl/utils/chat_templates.py b/src/axolotl/utils/chat_templates.py
index ffe5e24853..ac923164b5 100644
--- a/src/axolotl/utils/chat_templates.py
+++ b/src/axolotl/utils/chat_templates.py
@@ -31,6 +31,7 @@
"qwen_25": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.' }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within XML tags:\\n\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n\\n\\nFor each function call, return a json object with function name and arguments within XML tags:\\n\\n{\\\"name\\\": , \\\"arguments\\\": }\\n<|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n\\n' }}\n {{- message.content }}\n {{- '\\n' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n",
"exaone": "{% for message in messages %}{% if loop.first and message['role'] != 'system' %}{{ '[|system|][|endofturn|]\n' }}{% endif %}{{ '[|' + message['role'] + '|]' + message['content'] }}{% if message['role'] == 'user' %}{{ '\n' }}{% else %}{{ '[|endofturn|]\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '[|assistant|]' }}{% endif %}",
"metharme": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = 'Enter RP mode. You shall reply to the user while staying in character. Your responses must be detailed, creative, immersive, and drive the scenario forward.' %}{% endif %}{{ '<|system|>' + system_message }}{% for message in loop_messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|user|>' + content.strip() }}{% elif message['role'] == 'assistant' %}{{ '<|model|>' + content.strip() }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|model|>' }}{% else %}{{ eos_token }}{% endif %}",
+ "hymba": "{{'System'}}{% for message in messages %}{% if message['role'] == 'system' %}{{'\n' + message['content'].strip()}}{% if tools or contexts %}{{'\n'}}{% endif %}{% endif %}{% endfor %}{% if tools %}{% for tool in tools %}{{ '\n ' + tool|tojson + ' ' }}{% endfor %}{% endif %}{% if contexts %}{% if tools %}{{'\n'}}{% endif %}{% for context in contexts %}{{ '\n ' + context.strip() + ' ' }}{% endfor %}{% endif %}{{'\n\n'}}{% for message in messages %}{% if message['role'] == 'user' %}{{ 'User\n' + message['content'].strip() + '\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant\n' + message['content'].strip() + '\n' }}{% elif message['role'] == 'tool' %}{{ 'Tool\n' + message['content'].strip() + '\n' }}{% endif %}{% endfor %}{%- if add_generation_prompt %}{{'Assistant\n'}}{%- endif %}",
}
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index 3671e1bb93..1826166864 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -60,6 +60,7 @@ class ChatTemplate(str, Enum):
tokenizer_default = "tokenizer_default" # pylint: disable=invalid-name
exaone = "exaone" # pylint: disable=invalid-name
metharme = "metharme" # pylint: disable=invalid-name
+ hymba = "hymba" # pylint: disable=invalid-name
class DeprecatedParameters(BaseModel):
@@ -1581,3 +1582,19 @@ def check_adopt_torch_version(cls, data):
"ADOPT optimizer is incompatible with torch version < 2.5.1"
)
return data
+
+ @model_validator(mode="before")
+ @classmethod
+ def check_hymba_torch_version(cls, data):
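+        """Ensure torch >= 2.5 is installed when fine-tuning a Hymba model."""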
+        if "hymba" in data.get("base_model", "").lower():
+ env_capabilities = data.get("env_capabilities", {})
+ torch_version = env_capabilities.get("torch_version")
+
+ if torch_version is None:
+ import torch
+
+ torch_version = str(torch.__version__).split("+", maxsplit=1)[0]
+
+ if version.parse(torch_version) < version.parse("2.5.0"):
+ raise ValueError("Hymba requires torch version >= 2.5")
+ return data
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index a350f24295..8c8f5eaaed 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -420,6 +420,7 @@ def apply_patches(self) -> None:
and self.cfg.sample_packing
):
if "auto_map" in self.model_config:
+ # some model config objects are not subscriptable
try:
auto_map_config = self.model_config["auto_map"]
except TypeError:
@@ -427,6 +428,7 @@ def apply_patches(self) -> None:
has_remote_code = "AutoModelForCausalLM" in auto_map_config
else:
has_remote_code = False
+
if has_remote_code and self.cfg.trust_remote_code is False:
# if explicitly set in the YAML, we should prefer that, for example if explicitly disabled
has_remote_code = self.cfg.trust_remote_code
@@ -1155,7 +1157,7 @@ def load_model(self) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
and not skip_move_to_device
):
# TODO revaldate this conditional
- self.model.to(f"{str(get_device_type())}:{self.cfg.local_rank}")
+            self.model.to(f"{str(get_device_type())}:{self.cfg.local_rank}")
if get_device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1:
setattr(self.model, "is_parallelizable", True)
diff --git a/tests/e2e/test_optimizers.py b/tests/e2e/test_optimizers.py
index 2317bfb97a..58e2493c25 100644
--- a/tests/e2e/test_optimizers.py
+++ b/tests/e2e/test_optimizers.py
@@ -67,8 +67,8 @@ def test_optimi_adamw(self, temp_dir):
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
- @with_temp_dir
@require_torch_2_5_1
+ @with_temp_dir
def test_adopt_adamw(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
diff --git a/tests/e2e/test_packing_loss.py b/tests/e2e/test_packing_loss.py
index dd0af32f3c..43f623ca6c 100644
--- a/tests/e2e/test_packing_loss.py
+++ b/tests/e2e/test_packing_loss.py
@@ -14,7 +14,7 @@
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
-from .utils import check_tensorboard, with_temp_dir
+from .utils import check_tensorboard, require_torch_2_5_1, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -68,3 +68,129 @@ def test_loss_packed(self, temp_dir):
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
)
+
+
+class TestPackedHymba(unittest.TestCase):
+ """
+ Test case for Packed training of hymba models
+ """
+
+ @require_torch_2_5_1
+ @with_temp_dir
+ def test_loss_packed(self, temp_dir):
+ # pylint: disable=duplicate-code
+ cfg = DictDefault(
+ {
+ "base_model": "nvidia/Hymba-1.5B-Base",
+ "trust_remote_code": True,
+ "load_in_4bit": True,
+ "adapter": "qlora",
+ "lora_r": 32,
+ "lora_alpha": 16,
+ "lora_dropout": 0.05,
+ "lora_target_modules": [
+ "gate_proj",
+ "down_proj",
+ "up_proj",
+ "q_proj",
+ "v_proj",
+ "k_proj",
+ "o_proj",
+ ],
+ "sequence_len": 1024,
+ "sample_packing": True,
+ "flash_attention": True,
+ "val_set_size": 0.0,
+ "datasets": [
+ {
+ "path": "vicgalle/alpaca-gpt4",
+ "type": "alpaca",
+ },
+ ],
+ "num_epochs": 1,
+ "micro_batch_size": 2,
+ "gradient_accumulation_steps": 4,
+ "output_dir": temp_dir,
+ "learning_rate": 0.00001,
+ "optimizer": "adamw_torch",
+ "lr_scheduler": "cosine",
+ "max_steps": 5,
+ "use_tensorboard": True,
+ }
+ )
+ if is_torch_bf16_gpu_available():
+ cfg.bf16 = True
+ else:
+ cfg.fp16 = True
+ normalize_config(cfg)
+ cli_args = TrainerCliArgs()
+ dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+ train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+
+ check_tensorboard(
+ temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
+ )
+
+
+class TestUnpackedHymba(unittest.TestCase):
+ """
+ Test case for Unpacked training of hymba models
+ """
+
+ @require_torch_2_5_1
+ @with_temp_dir
+ def test_loss_unpacked(self, temp_dir):
+ # pylint: disable=duplicate-code
+ cfg = DictDefault(
+ {
+ "base_model": "nvidia/Hymba-1.5B-Base",
+ "trust_remote_code": True,
+ "load_in_4bit": True,
+ "adapter": "qlora",
+ "lora_r": 32,
+ "lora_alpha": 16,
+ "lora_dropout": 0.05,
+ "lora_target_modules": [
+ "gate_proj",
+ "down_proj",
+ "up_proj",
+ "q_proj",
+ "v_proj",
+ "k_proj",
+ "o_proj",
+ ],
+ "sequence_len": 1024,
+ "sample_packing": False,
+ "flash_attention": True,
+ "val_set_size": 0.0,
+ "datasets": [
+ {
+ "path": "vicgalle/alpaca-gpt4",
+ "type": "alpaca",
+ },
+ ],
+ "num_epochs": 1,
+ "micro_batch_size": 2,
+ "gradient_accumulation_steps": 4,
+ "output_dir": temp_dir,
+ "learning_rate": 0.00001,
+ "optimizer": "adamw_torch",
+ "lr_scheduler": "cosine",
+ "max_steps": 5,
+ "use_tensorboard": True,
+ }
+ )
+ if is_torch_bf16_gpu_available():
+ cfg.bf16 = True
+ else:
+ cfg.fp16 = True
+ normalize_config(cfg)
+ cli_args = TrainerCliArgs()
+ dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+ train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+
+ check_tensorboard(
+ temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
+ )