Convert To Dolomite Checkpoint If Needed #67

Merged
merged 3 commits on Jun 25, 2024
12 changes: 7 additions & 5 deletions src/instructlab/training/async_logger.py
@@ -1,9 +1,11 @@
import json
import asyncio
# Standard
from datetime import datetime
import aiofiles
import asyncio
import json
import threading
import os

# Third Party
import aiofiles


class AsyncStructuredLogger:
@@ -39,7 +41,7 @@ async def log(self, data):
data["timestamp"] = datetime.now().isoformat()
self.logs.append(data)
await self._write_logs_to_file(data)
print(f"\033[92m{json.dumps(data, indent=4)}\033[0m")
print(f"\033[92m{json.dumps(data, indent=4)}\033[0m")

async def _write_logs_to_file(self, data):
"""appends to the log instead of writing the whole log each time"""
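
As context for the docstring above ("appends to the log instead of writing the whole log each time"), here is a minimal standalone sketch of that append-only pattern with aiofiles; the file name and payload are illustrative and not part of this diff:

# Standard
import asyncio
import json

# Third Party
import aiofiles


async def append_log_entry(path: str, data: dict) -> None:
    # mode="a" appends one JSON line per entry instead of rewriting the whole file
    async with aiofiles.open(path, mode="a") as f:
        await f.write(json.dumps(data) + "\n")


asyncio.run(append_log_entry("training_log.jsonl", {"step": 1, "loss": 0.42}))
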
4 changes: 2 additions & 2 deletions src/instructlab/training/config.py
@@ -74,10 +74,10 @@ class LoraOptions(BaseModel):
target_modules: list[str] = Field(
default_factory=lambda: ["q_proj", "k_proj", "v_proj", "o_proj"]
)

quantize_data_type: QuantizeDataType = QuantizeDataType.NONE

class Config:
class Config:
use_enum_values = True


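
For reference, a minimal standalone sketch (not from the repo; the enum values are illustrative) of what use_enum_values = True in the class Config above provides: validated enum fields are stored as their plain values rather than as enum members.

# Standard
from enum import Enum

# Third Party
from pydantic import BaseModel


class QuantizeDataType(Enum):
    NONE = None
    NF4 = "nf4"


class LoraOptionsSketch(BaseModel):
    quantize_data_type: QuantizeDataType = QuantizeDataType.NONE

    class Config:
        use_enum_values = True


print(LoraOptionsSketch(quantize_data_type="nf4").quantize_data_type)  # prints "nf4"
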
20 changes: 12 additions & 8 deletions src/instructlab/training/main_ds.py
@@ -11,6 +11,8 @@
# Third Party
from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
from deepspeed.runtime.zero.utils import ZeRORuntimeException

# pylint: disable=no-name-in-module
from instructlab.dolomite.hf_models import GPTDolomiteForCausalLM
from torch.distributed import ReduceOp, all_reduce
from tqdm import tqdm
@@ -20,6 +22,7 @@

# First Party
from instructlab.training import config
from instructlab.training.async_logger import AsyncStructuredLogger
from instructlab.training.config import (
DataProcessArgs,
DeepSpeedOptions,
@@ -31,12 +34,12 @@
)
from instructlab.training.token_dataset import setup_dataloader, setup_dataset
from instructlab.training.tokenizer_utils import setup_tokenizer
from instructlab.training.async_logger import AsyncStructuredLogger
from instructlab.training.utils import (
StreamablePopen,
add_noisy_embeddings,
apply_gradient_checkpointing,
convert_loss_to_reduce_sum,
ensure_loadable_granite_checkpoint,
patch_target_module,
prepare_peft_model,
prepare_universal_checkpoint_from_latest,
@@ -92,13 +95,14 @@ def setup_model(args, tokenizer, train_loader, grad_accum):
)

if args.is_granite:
model = GPTDolomiteForCausalLM.from_pretrained(
args.model_name_or_path,
attn_implementation="flash_attention_2",
torch_dtype=torch.bfloat16,
use_padding_free_transformer=True,
quantization_config=bnb_config,
)
with ensure_loadable_granite_checkpoint(args.model_name_or_path) as path:
model = GPTDolomiteForCausalLM.from_pretrained(
path,
attn_implementation="flash_attention_2",
torch_dtype=torch.bfloat16,
use_padding_free_transformer=True,
quantization_config=bnb_config,
)
else:
model = AutoModelForCausalLM.from_pretrained(
args.model_name_or_path,
4 changes: 3 additions & 1 deletion src/instructlab/training/multipack_sampler.py
@@ -440,7 +440,9 @@ def generate_batches(self, set_stats=False):
# remove indices where the entries are longer than batch max length
indices = indices[self.lengths[indices] <= self.batch_max_length]
if len(indices) < len(self.lengths):
print(f"\033[33mDropping {len(self.lengths) - len(indices)} samples longer than batch_max_length. Ensure that the right max_batch_length is used during data processing.\033[0m")
print(
f"\033[33mDropping {len(self.lengths) - len(indices)} samples longer than batch_max_length. Ensure that the right max_batch_length is used during data processing.\033[0m"
)

lengths = self.lengths[indices]
lengths_cumsum = np.cumsum(lengths)
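
A minimal standalone numpy sketch of the over-length filter this hunk reformats; the sample lengths and batch_max_length below are illustrative, not values from the repo:

# Third Party
import numpy as np

lengths = np.array([128, 4096, 512, 9000])  # per-sample token counts (illustrative)
batch_max_length = 8192
indices = np.arange(len(lengths))

# keep only the samples that fit within batch_max_length
indices = indices[lengths[indices] <= batch_max_length]
if len(indices) < len(lengths):
    print(f"Dropping {len(lengths) - len(indices)} samples longer than batch_max_length.")

lengths = lengths[indices]
lengths_cumsum = np.cumsum(lengths)  # consumed downstream when packing batches
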
43 changes: 34 additions & 9 deletions src/instructlab/training/utils.py
@@ -1,6 +1,8 @@
# Standard
from contextlib import contextmanager
from functools import partial
from pathlib import Path
from tempfile import TemporaryDirectory, mktemp
from typing import Any, List, Optional
import importlib
import inspect
@@ -14,8 +16,14 @@
import warnings

# Third Party
from instructlab.dolomite.hf_models import export_to_huggingface
# pylint: disable=no-name-in-module
from instructlab.dolomite.hf_models import (
GPTDolomiteConfig,
export_to_huggingface,
import_from_huggingface,
)
from rich.logging import RichHandler
from safetensors.torch import save_file
from torch import distributed as dist
from torch.distributed import get_rank, is_initialized
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
@@ -39,7 +47,7 @@ def retrieve_chat_template(chat_tmpl_path):
spec.loader.exec_module(module)
SPECIAL_TOKENS = module.SPECIAL_TOKENS
CHAT_TEMPLATE = module.CHAT_TEMPLATE
except:
except: # pylint: disable=bare-except
sys.exit(f"Invalid chat template path: {chat_tmpl_path}")
return CHAT_TEMPLATE, SPECIAL_TOKENS

@@ -473,6 +481,30 @@ class UniversalCheckpointArgs:
log_rank_0(f"Preparing universal checkpoint took {time.time() - start} seconds")


@contextmanager
def ensure_loadable_granite_checkpoint(model_name_or_path: str):
if not dist.is_initialized() or dist.get_rank() == 0:
try:
GPTDolomiteConfig.from_pretrained(model_name_or_path)
yield model_name_or_path
except: # pylint: disable=bare-except
log_rank_0(
f"\033[93mModel saved in {model_name_or_path} requires conversion \033[0m",
to_print=True,
)
# if the load failed then it must not be a granite
# for now just assume its a llama
# with TemporaryDirectory("w") as tmpdir:
# make a temp directory name, but do not create it
tmpdir = mktemp()
import_from_huggingface(model_name_or_path, tmpdir)
yield tmpdir
shutil.rmtree(tmpdir, ignore_errors=True)

if dist.is_initialized():
dist.barrier()


# this function is for supporting gradient checkpointing for padding free
# dolomite
def apply_gradient_checkpointing(
@@ -608,13 +640,6 @@ def save_hf_format_ds(args, model, tokenizer, samples_seen, convert_granite=True
output_config_file = output_dir / CONFIG_NAME

if args.is_granite and convert_granite:
# guarded import
# Standard
from tempfile import TemporaryDirectory

# Third Party
from safetensors.torch import save_file

with TemporaryDirectory("w") as tmpdir:
save_file(model_state, Path(tmpdir) / WEIGHTS_NAME)
model_to_save.config.to_json_file(Path(tmpdir) / CONFIG_NAME)
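
Putting the pieces together, a minimal usage sketch of ensure_loadable_granite_checkpoint as setup_model consumes it in this PR; the checkpoint path is illustrative and the quantization config is omitted here:

# Third Party
from instructlab.dolomite.hf_models import GPTDolomiteForCausalLM
import torch

# First Party
from instructlab.training.utils import ensure_loadable_granite_checkpoint

with ensure_loadable_granite_checkpoint("/models/granite-7b-base") as path:
    # `path` is either the original directory (already in Dolomite format) or a
    # temporary directory holding the converted checkpoint, cleaned up on exit
    model = GPTDolomiteForCausalLM.from_pretrained(
        path,
        attn_implementation="flash_attention_2",
        torch_dtype=torch.bfloat16,
        use_padding_free_transformer=True,
    )
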