diff --git a/src/instructlab/training/async_logger.py b/src/instructlab/training/async_logger.py
index 6c66a452..4ea59d93 100644
--- a/src/instructlab/training/async_logger.py
+++ b/src/instructlab/training/async_logger.py
@@ -1,9 +1,11 @@
-import json
-import asyncio
+# Standard
 from datetime import datetime
-import aiofiles
+import asyncio
+import json
 import threading
-import os
+
+# Third Party
+import aiofiles
 
 
 class AsyncStructuredLogger:
@@ -39,7 +41,7 @@ async def log(self, data):
         data["timestamp"] = datetime.now().isoformat()
         self.logs.append(data)
         await self._write_logs_to_file(data)
-        {{print(f"\033[92m{json.dumps(data, indent=4)}\033[0m")}}
+        print(f"\033[92m{json.dumps(data, indent=4)}\033[0m")
 
     async def _write_logs_to_file(self, data):
         """appends to the log instead of writing the whole log each time"""
diff --git a/src/instructlab/training/config.py b/src/instructlab/training/config.py
index 69a61910..6f7c9124 100644
--- a/src/instructlab/training/config.py
+++ b/src/instructlab/training/config.py
@@ -74,10 +74,10 @@ class LoraOptions(BaseModel):
     target_modules: list[str] = Field(
         default_factory=lambda: ["q_proj", "k_proj", "v_proj", "o_proj"]
     )
-    
+
     quantize_data_type: QuantizeDataType = QuantizeDataType.NONE
 
-    class Config: 
+    class Config:
         use_enum_values = True
 
 
diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py
index fcd6e6d5..69bb5efb 100644
--- a/src/instructlab/training/main_ds.py
+++ b/src/instructlab/training/main_ds.py
@@ -11,6 +11,8 @@
 # Third Party
 from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
 from deepspeed.runtime.zero.utils import ZeRORuntimeException
+
+# pylint: disable=no-name-in-module
 from instructlab.dolomite.hf_models import GPTDolomiteForCausalLM
 from torch.distributed import ReduceOp, all_reduce
 from tqdm import tqdm
@@ -20,6 +22,7 @@
 
 # First Party
 from instructlab.training import config
+from instructlab.training.async_logger import AsyncStructuredLogger
 from instructlab.training.config import (
     DataProcessArgs,
     DeepSpeedOptions,
@@ -31,12 +34,12 @@
 )
 from instructlab.training.token_dataset import setup_dataloader, setup_dataset
 from instructlab.training.tokenizer_utils import setup_tokenizer
-from instructlab.training.async_logger import AsyncStructuredLogger
 from instructlab.training.utils import (
     StreamablePopen,
     add_noisy_embeddings,
     apply_gradient_checkpointing,
     convert_loss_to_reduce_sum,
+    ensure_loadable_granite_checkpoint,
     patch_target_module,
     prepare_peft_model,
     prepare_universal_checkpoint_from_latest,
@@ -92,13 +95,14 @@ def setup_model(args, tokenizer, train_loader, grad_accum):
     )
 
     if args.is_granite:
-        model = GPTDolomiteForCausalLM.from_pretrained(
-            args.model_name_or_path,
-            attn_implementation="flash_attention_2",
-            torch_dtype=torch.bfloat16,
-            use_padding_free_transformer=True,
-            quantization_config=bnb_config,
-        )
+        with ensure_loadable_granite_checkpoint(args.model_name_or_path) as path:
+            model = GPTDolomiteForCausalLM.from_pretrained(
+                path,
+                attn_implementation="flash_attention_2",
+                torch_dtype=torch.bfloat16,
+                use_padding_free_transformer=True,
+                quantization_config=bnb_config,
+            )
     else:
         model = AutoModelForCausalLM.from_pretrained(
             args.model_name_or_path,
diff --git a/src/instructlab/training/multipack_sampler.py b/src/instructlab/training/multipack_sampler.py
index 56245bd6..254dafb2 100644
--- a/src/instructlab/training/multipack_sampler.py
+++ b/src/instructlab/training/multipack_sampler.py
@@ -440,7 +440,9 @@ def generate_batches(self, set_stats=False):
         # remove indices where the entries are longer than batch max length
         indices = indices[self.lengths[indices] <= self.batch_max_length]
         if len(indices) < len(self.lengths):
-            print(f"\033[33mDropping {len(self.lengths) - len(indices)} samples longer than batch_max_length. Ensure that the right max_batch_length is used during data processing.\033[0m")
+            print(
+                f"\033[33mDropping {len(self.lengths) - len(indices)} samples longer than batch_max_length. Ensure that the right max_batch_length is used during data processing.\033[0m"
+            )
 
         lengths = self.lengths[indices]
         lengths_cumsum = np.cumsum(lengths)
diff --git a/src/instructlab/training/utils.py b/src/instructlab/training/utils.py
index dd50f6d1..b90f6aa3 100644
--- a/src/instructlab/training/utils.py
+++ b/src/instructlab/training/utils.py
@@ -1,6 +1,8 @@
 # Standard
+from contextlib import contextmanager
 from functools import partial
 from pathlib import Path
+from tempfile import TemporaryDirectory, mktemp
 from typing import Any, List, Optional
 import importlib
 import inspect
@@ -14,8 +16,14 @@
 import warnings
 
 # Third Party
-from instructlab.dolomite.hf_models import export_to_huggingface
+# pylint: disable=no-name-in-module
+from instructlab.dolomite.hf_models import (
+    GPTDolomiteConfig,
+    export_to_huggingface,
+    import_from_huggingface,
+)
 from rich.logging import RichHandler
+from safetensors.torch import save_file
 from torch import distributed as dist
 from torch.distributed import get_rank, is_initialized
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
@@ -39,7 +47,7 @@ def retrieve_chat_template(chat_tmpl_path):
         spec.loader.exec_module(module)
         SPECIAL_TOKENS = module.SPECIAL_TOKENS
         CHAT_TEMPLATE = module.CHAT_TEMPLATE
-    except:
+    except:  # pylint: disable=bare-except
         sys.exit(f"Invalid chat template path: {chat_tmpl_path}")
     return CHAT_TEMPLATE, SPECIAL_TOKENS
 
@@ -473,6 +481,30 @@ class UniversalCheckpointArgs:
     log_rank_0(f"Preparing universal checkpoint took {time.time() - start} seconds")
 
 
+@contextmanager
+def ensure_loadable_granite_checkpoint(model_name_or_path: str):
+    if not dist.is_initialized() or dist.get_rank() == 0:
+        try:
+            GPTDolomiteConfig.from_pretrained(model_name_or_path)
+            yield model_name_or_path
+        except:  # pylint: disable=bare-except
+            log_rank_0(
+                f"\033[93mModel saved in {model_name_or_path} requires conversion \033[0m",
+                to_print=True,
+            )
+            # if the load failed then it must not be a granite
+            # for now just assume its a llama
+            # with TemporaryDirectory("w") as tmpdir:
+            # make a temp directory name, but do not create it
+            tmpdir = mktemp()
+            import_from_huggingface(model_name_or_path, tmpdir)
+            yield tmpdir
+            shutil.rmtree(tmpdir, ignore_errors=True)
+
+    if dist.is_initialized():
+        dist.barrier()
+
+
 # this function is for supporting gradient checkpointing for padding free
 # dolomite
 def apply_gradient_checkpointing(
@@ -608,13 +640,6 @@ def save_hf_format_ds(args, model, tokenizer, samples_seen, convert_granite=True
     output_config_file = output_dir / CONFIG_NAME
 
     if args.is_granite and convert_granite:
-        # guarded import
-        # Standard
-        from tempfile import TemporaryDirectory
-
-        # Third Party
-        from safetensors.torch import save_file
-
         with TemporaryDirectory("w") as tmpdir:
             save_file(model_state, Path(tmpdir) / WEIGHTS_NAME)
             model_to_save.config.to_json_file(Path(tmpdir) / CONFIG_NAME)
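
Usage sketch (illustrative, not part of the patch): the snippet below mirrors how setup_model in main_ds.py consumes the new ensure_loadable_granite_checkpoint context manager added to utils.py. The checkpoint path "/path/to/checkpoint" is a placeholder, and quantization_config is omitted for brevity; the remaining arguments follow the call shown in the diff.

    # Minimal sketch, assuming "/path/to/checkpoint" stands in for a real model directory.
    # Third Party
    from instructlab.dolomite.hf_models import GPTDolomiteForCausalLM
    import torch

    # First Party
    from instructlab.training.utils import ensure_loadable_granite_checkpoint

    # `path` is either the original directory (already loadable as a granite/Dolomite
    # checkpoint) or a temporary directory holding the converted weights, which the
    # context manager removes once the block exits.
    with ensure_loadable_granite_checkpoint("/path/to/checkpoint") as path:
        model = GPTDolomiteForCausalLM.from_pretrained(
            path,
            attn_implementation="flash_attention_2",
            torch_dtype=torch.bfloat16,
            use_padding_free_transformer=True,
        )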