Commit

Merge branch 'main' into pixtral_integration2
bursteratom authored Dec 27, 2024
2 parents ce91855 + 7a38dbe commit ead8d21
Showing 9 changed files with 308 additions and 218 deletions.
27 changes: 27 additions & 0 deletions deepspeed_configs/zero1_torch_compile.json
@@ -0,0 +1,27 @@
+{
+  "zero_optimization": {
+    "stage": 1,
+    "overlap_comm": true
+  },
+  "bf16": {
+    "enabled": "auto"
+  },
+  "fp16": {
+    "enabled": "auto",
+    "auto_cast": false,
+    "loss_scale": 0,
+    "initial_scale_power": 32,
+    "loss_scale_window": 1000,
+    "hysteresis": 2,
+    "min_loss_scale": 1
+  },
+  "compile": {
+    "disable": false,
+    "backend": "inductor"
+  },
+  "gradient_accumulation_steps": "auto",
+  "gradient_clipping": "auto",
+  "train_batch_size": "auto",
+  "train_micro_batch_size_per_gpu": "auto",
+  "wall_clock_breakdown": false
+}
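
This new config enables ZeRO stage 1 with communication overlap plus DeepSpeed's torch.compile integration (inductor backend), leaving batch sizes and precision on "auto" so axolotl can fill them in. To try it, point a training config at the file via axolotl's `deepspeed` key — a minimal sketch, with an illustrative config path:

    # excerpt from an axolotl YAML training config
    deepspeed: deepspeed_configs/zero1_torch_compile.json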
2 changes: 1 addition & 1 deletion requirements.txt
@@ -61,4 +61,4 @@ antlr4-python3-runtime==4.13.2
 torchao==0.7.0
 schedulefree==1.3.0
 
-axolotl-contribs-lgpl==0.0.1b2
+axolotl-contribs-lgpl==0.0.2
4 changes: 2 additions & 2 deletions src/axolotl/cli/main.py
@@ -93,7 +93,7 @@ def evaluate(config: str, accelerate: bool, **kwargs):
 @click.argument("config", type=click.Path(exists=True, path_type=str))
 @click.option(
     "--accelerate/--no-accelerate",
-    default=True,
+    default=False,
     help="Use accelerate launch for multi-GPU inference",
 )
 @click.option(
@@ -124,7 +124,7 @@ def inference(
     if lora_model_dir:
         kwargs["lora_model_dir"] = lora_model_dir
     if base_model:
-        kwargs["output_dir"] = base_model
+        kwargs["base_model"] = base_model
 
     if accelerate:
         base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.inference"]
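
Two fixes land here: `axolotl inference` no longer wraps itself in `accelerate launch` unless asked, and `--base-model` is forwarded under its own key instead of clobbering `output_dir`. Hedged usage examples (paths are illustrative):

    # single-process inference, the new default
    axolotl inference config.yml --base-model ./models/my-base

    # opt back in to accelerate for multi-GPU inference
    axolotl inference config.yml --accelerate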
9 changes: 6 additions & 3 deletions src/axolotl/core/trainer_builder.py
@@ -57,6 +57,7 @@
 from axolotl.utils import is_comet_available, is_mlflow_available
 from axolotl.utils.callbacks import (
     EvalFirstStepCallback,
+    GCCallback,
     GPUStatsCallback,
     LossWatchDogCallback,
     SaveAxolotlConfigtoWandBCallback,
@@ -68,7 +69,7 @@
 )
 from axolotl.utils.callbacks.lisa import lisa_callback_factory
 from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
-from axolotl.utils.chat_templates import get_chat_template
+from axolotl.utils.chat_templates import get_chat_template_from_config
 from axolotl.utils.collators import (
     BatchSamplerDataCollatorForSeq2Seq,
     DataCollatorForSeq2Seq,
@@ -1453,6 +1454,8 @@ def get_callbacks(self):
         if self.cfg.loss_watchdog_threshold is not None:
             callbacks.append(LossWatchDogCallback(self.cfg))
 
+        if self.cfg.gc_steps:
+            callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps))
         callbacks.append(SaveModelCallback())
 
         return callbacks
@@ -1832,8 +1835,8 @@ def build(self, total_num_steps):
         training_arguments_kwargs["model_type"] = self.cfg.model_config_type
         training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)
         if self.cfg.chat_template:
-            training_arguments_kwargs["chat_template"] = get_chat_template(
-                self.cfg.chat_template,
+            training_arguments_kwargs["chat_template"] = get_chat_template_from_config(
+                cfg=self.cfg,
                 tokenizer=self.tokenizer,
             )
 
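
Two things change in the trainer builder: a `GCCallback` is registered whenever `gc_steps` is set, and the chat template is now resolved from the whole config rather than from a single template name. A minimal sketch of what a config-aware resolver might look like — the real `get_chat_template_from_config` lives in `axolotl.utils.chat_templates` and may differ in detail:

    # sketch only, assuming a "jinja" sentinel selects a raw template
    def get_chat_template_from_config(cfg, tokenizer=None):
        if cfg.chat_template == "jinja" and cfg.chat_template_jinja:
            # a raw Jinja template supplied directly in the YAML config
            return cfg.chat_template_jinja
        # otherwise resolve a registered template by name, as before
        return get_chat_template(cfg.chat_template, tokenizer=tokenizer)

Passing `cfg=self.cfg` at the call site is what makes an override like `chat_template_jinja` reachable.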
16 changes: 15 additions & 1 deletion src/axolotl/train.py
@@ -1,5 +1,6 @@
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""

import inspect
import os
import signal
import sys
Expand Down Expand Up @@ -126,7 +127,20 @@ def train(
)

if cfg.fix_untrained_tokens:
fix_untrained_tokens(model, tokenizer, train_dataset)
# check if the `token_ids_to_fix` kwarg exists in the fix_untrained_tokens args
sig = inspect.signature(fix_untrained_tokens)
# if the function has the `token_ids_to_fix` arg, and fix_untrained_tokens is a list
if "token_ids_to_fix" in sig.parameters and isinstance(
cfg.fix_untrained_tokens, list
):
fix_untrained_tokens(
model,
tokenizer,
train_dataset,
token_ids_to_fix=cfg.fix_untrained_tokens,
)
else:
fix_untrained_tokens(model, tokenizer, train_dataset)
if cfg.local_rank == 0:
model.save_pretrained(
str(Path(cfg.output_dir)), safe_serialization=safe_serialization
Expand Down
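
The `inspect.signature` check keeps `train()` compatible with older `fix_untrained_tokens` implementations that don't accept `token_ids_to_fix`. The pattern generalizes to any optional kwarg; a self-contained sketch with a hypothetical helper name:

    import inspect

    def call_with_supported_kwargs(fn, *args, **optional):
        """Call fn, forwarding only the optional kwargs its signature accepts.

        Hypothetical helper illustrating the compatibility pattern above;
        note it will not forward kwargs absorbed by a bare **kwargs parameter.
        """
        params = inspect.signature(fn).parameters
        supported = {k: v for k, v in optional.items() if k in params}
        return fn(*args, **supported)

    # e.g. forwards token_ids_to_fix only when the installed version supports it:
    # call_with_supported_kwargs(fix_untrained_tokens, model, tokenizer,
    #                            train_dataset, token_ids_to_fix=[0, 1, 2])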
15 changes: 15 additions & 0 deletions src/axolotl/utils/callbacks/__init__.py
@@ -2,6 +2,7 @@
 
 from __future__ import annotations
 
+import gc
 import logging
 import math
 import os
@@ -842,3 +843,17 @@ def on_train_end(  # pylint: disable=unused-argument
     ):
         control.should_save = True
         return control
+
+
+class GCCallback(TrainerCallback):
+    """Callback to garbage collect torch cache"""
+
+    def __init__(self, gc_steps=None):
+        self.gc_steps = gc_steps
+
+    def on_step_end(
+        self, args, state, control, **kwargs  # pylint: disable=unused-argument
+    ):
+        if state.global_step % self.gc_steps == 0:
+            torch.cuda.empty_cache()
+            gc.collect()
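
Every `gc_steps` optimizer steps the callback releases cached CUDA blocks and runs Python's garbage collector, trading a small pause for lower peak memory. `self.gc_steps` is never `None` in practice, because the callback is only registered when `cfg.gc_steps` is set (see `get_callbacks` above). A hedged sketch of wiring it into a plain `transformers` Trainer outside axolotl (model and dataset are assumed to exist):

    from transformers import Trainer, TrainingArguments
    from axolotl.utils.callbacks import GCCallback  # as added above

    trainer = Trainer(
        model=model,                            # your prepared model
        args=TrainingArguments(output_dir="out"),
        train_dataset=train_dataset,            # your tokenized dataset
        callbacks=[GCCallback(gc_steps=50)],    # collect every 50 steps
    )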
4 changes: 3 additions & 1 deletion src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -669,6 +669,8 @@ class Config:
     loss_watchdog_threshold: Optional[float] = None
     loss_watchdog_patience: Optional[int] = None
 
+    gc_steps: Optional[int] = None
+
     bf16: Optional[Union[Literal["auto"], bool]] = "auto"
     fp16: Optional[bool] = None
     bfloat16: Optional[bool] = None  # for non-AMP cases
@@ -795,7 +797,7 @@ class Config:
     chat_template_jinja: Optional[str] = None
     default_system_message: Optional[str] = None
 
-    fix_untrained_tokens: Optional[bool] = None
+    fix_untrained_tokens: Optional[Union[int, List[int]]] = None
 
     # INTERNALS - document for now, generally not set externally
     is_preprocess: Optional[bool] = None
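
These schema additions expose the new knobs in the YAML config: `gc_steps` for the periodic cache cleanup, and `fix_untrained_tokens`, which now accepts an explicit list of token IDs as well as the old boolean. An illustrative snippet (values are examples, not recommendations):

    # axolotl YAML config excerpt
    gc_steps: 50                            # empty CUDA cache + gc every 50 steps
    fix_untrained_tokens: [128002, 128003]  # fix only these token IDs
    # fix_untrained_tokens: true            # or keep the old behaviour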
(Diffs for the remaining two changed files are not shown.)
