Switch To InstructLab Dolomite Repo #55

Merged
merged 5 commits on Jun 23, 2024
3 changes: 2 additions & 1 deletion .pylintrc
@@ -471,7 +471,8 @@ disable=raw-checker-failed,
dangerous-default-value,
consider-using-generator,
broad-exception-caught,
-super-init-not-called
+super-init-not-called,
+duplicate-code

# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifier separated by comma (,) or put this option
2 changes: 1 addition & 1 deletion requirements.txt
@@ -9,7 +9,7 @@ datasets>=2.15.0
numba
numpy
rich
-dolomite-engine @ git+https://github.com/ibm-granite/dolomite-engine.git@main
+instructlab-dolomite
trl>=0.9.4
peft
pydantic>=2.7.0
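The dependency swap above also changes the Python import root from dolomite_engine to instructlab.dolomite, and the remaining diffs in this PR update those call sites. A short before/after sketch of the import paths, assembled only from the changes shown in this PR:

# Import-path change implied by the requirements swap; paths are taken from
# the diffs in this PR and are not an exhaustive list.

# Old package (dolomite-engine, installed from the ibm-granite git repo):
#   from dolomite_engine.hf_models.models import GPTDolomiteForCausalLM
#   from dolomite_engine.hf_models import export_to_huggingface
#   from dolomite_engine.gradient_checkpointing import apply_gradient_checkpointing

# New package (instructlab-dolomite):
from instructlab.dolomite.hf_models import (
    GPTDolomiteForCausalLM,
    export_to_huggingface,
)

# gradient checkpointing now comes from the local helper added in
# src/instructlab/training/utils.py rather than from the dolomite package
from instructlab.training.utils import apply_gradient_checkpointing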
10 changes: 2 additions & 8 deletions src/instructlab/training/main_ds.py
@@ -11,6 +11,7 @@
# Third Party
from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
from deepspeed.runtime.zero.utils import ZeRORuntimeException
+from instructlab.dolomite.hf_models import GPTDolomiteForCausalLM
from torch.distributed import ReduceOp, all_reduce
from tqdm import tqdm
from transformers import AutoModelForCausalLM, get_scheduler
@@ -33,6 +34,7 @@
from instructlab.training.utils import (
    StreamablePopen,
    add_noisy_embeddings,
+    apply_gradient_checkpointing,
    convert_loss_to_reduce_sum,
    patch_target_module,
    prepare_peft_model,
@@ -88,9 +90,6 @@ def setup_model(args, tokenizer, train_loader, grad_accum):
    )

    if args.is_granite:
-        # Third Party
-        from dolomite_engine.hf_models.models import GPTDolomiteForCausalLM
-
        model = GPTDolomiteForCausalLM.from_pretrained(
            args.model_name_or_path,
            attn_implementation="flash_attention_2",
@@ -201,14 +200,9 @@ def setup_model(args, tokenizer, train_loader, grad_accum):
    # granite gradient checkpointing is handled uniformly
    # for both lora and full here
    if args.is_granite:
-        # Third Party
-        from dolomite_engine.enums import GradientCheckpointingMethod
-        from dolomite_engine.gradient_checkpointing import apply_gradient_checkpointing
-
        block_name = model._no_split_modules[0]
        apply_gradient_checkpointing(
            model,
-            GradientCheckpointingMethod.block,
            block_name=block_name,
            use_reentrant=True,  # this should be the HF default mode
        )
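With these changes, the granite branch of setup_model reduces to the top-level import plus the local checkpointing helper. A minimal sketch of the resulting call pattern, using a placeholder model path and only the arguments visible in this diff (the real call site passes additional options not shown here):

# Hedged sketch of the post-change granite path; the checkpoint path is a
# placeholder and only arguments visible in this diff are included.
from instructlab.dolomite.hf_models import GPTDolomiteForCausalLM
from instructlab.training.utils import apply_gradient_checkpointing

model = GPTDolomiteForCausalLM.from_pretrained(
    "/path/to/granite-checkpoint",  # placeholder path
    attn_implementation="flash_attention_2",
)

# wrap every transformer block for gradient checkpointing, as setup_model does
block_name = model._no_split_modules[0]
apply_gradient_checkpointing(
    model,
    block_name=block_name,
    use_reentrant=True,  # this should be the HF default mode
)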
62 changes: 61 additions & 1 deletion src/instructlab/training/utils.py
@@ -1,4 +1,5 @@
# Standard
+from functools import partial
from pathlib import Path
from typing import Any, List, Optional
import importlib
@@ -13,9 +14,15 @@
import warnings

# Third Party
+from instructlab.dolomite.hf_models import export_to_huggingface
from rich.logging import RichHandler
from torch import distributed as dist
from torch.distributed import get_rank, is_initialized
+from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
+    CheckpointImpl,
+    apply_activation_checkpointing,
+    checkpoint_wrapper,
+)
from torch.distributed.fsdp import FullStateDictConfig
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType
@@ -32,7 +39,7 @@
        def new_func(x):
            if model.training:
                embed_init = orig_embed(x)
                dims = torch.tensor(torch.numel(x))
                mag_norm = noise_alpha / torch.sqrt(dims)
                return embed_init + torch.zeros_like(embed_init).uniform_(
                    -mag_norm, mag_norm

(GitHub Actions / lint annotation on line 42 of src/instructlab/training/utils.py: W0702: No exception type(s) specified (bare-except))
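For context, new_func above implements NEFTune-style noisy embeddings: the noise magnitude is noise_alpha divided by the square root of the number of input elements, and uniform noise in [-mag_norm, mag_norm] is added only during training. A standalone sketch of that scaling with arbitrary example shapes:

# Illustration of the noise scaling used in new_func; shapes and noise_alpha
# are arbitrary example values, not values taken from this repository.
import torch
from torch import nn

noise_alpha = 5.0
embedding = nn.Embedding(32000, 256)
x = torch.randint(0, 32000, (4, 512))      # (batch, seq_len) token ids
embed_init = embedding(x)                  # (4, 512, 256)
dims = torch.tensor(torch.numel(x))        # 4 * 512 = 2048
mag_norm = noise_alpha / torch.sqrt(dims)  # 5 / sqrt(2048) ≈ 0.11
noised = embed_init + torch.zeros_like(embed_init).uniform_(-mag_norm, mag_norm)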
@@ -453,6 +460,60 @@
log_rank_0(f"Preparing universal checkpoint took {time.time() - start} seconds")


+# this function is for supporting gradient checkpointing for padding free
+# dolomite
+def apply_gradient_checkpointing(
+    model: torch.nn.Module,
+    **kwargs,
+) -> None:
+    def get_module_class_from_name(
+        model: torch.nn.Module, name: str
+    ) -> List[torch.nn.Module]:
+        modules_children = list(model.children())
+
+        if model.__class__.__name__ == name:
+            return model.__class__
+        elif len(modules_children) == 0:
+            return
+        else:
+            for child_module in modules_children:
+                module_class = get_module_class_from_name(child_module, name)
+                if module_class is not None:
+                    return module_class
+
+    def block_checkpointing(
+        model: torch.nn.Module,
+        block_name: str,
+        checkpoint_every: int = 1,
+        use_reentrant: bool = False,
+    ) -> None:
+        block_class = get_module_class_from_name(model, block_name)
+        block_idx = 0
+
+        def _whether_to_checkpoint(submodule: torch.nn.Module) -> bool:
+            nonlocal block_idx
+
+            if isinstance(submodule, block_class):
+                block_idx += 1
+                if (block_idx - 1) % checkpoint_every == 0:
+                    return True
+            return False
+
+        checkpoint_wrapper_function = checkpoint_wrapper
+        if use_reentrant:
+            checkpoint_wrapper_function = partial(
+                checkpoint_wrapper, checkpoint_impl=CheckpointImpl.REENTRANT
+            )
+
+        apply_activation_checkpointing(
+            model,
+            checkpoint_wrapper_fn=checkpoint_wrapper_function,
+            check_fn=_whether_to_checkpoint,
+        )
+
+    block_checkpointing(model, **kwargs)


def setup_logger(level="DEBUG"):
    logging.basicConfig(
        level=level, format="%(message)s", datefmt="[%X]", handlers=[RichHandler()]
@@ -539,7 +600,6 @@
    from tempfile import TemporaryDirectory

    # Third Party
-    from dolomite_engine.hf_models import export_to_huggingface
    from safetensors.torch import save_file

    with TemporaryDirectory("w") as tmpdir:
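The new utils.py helper replaces dolomite_engine's block-level gradient checkpointing with a local implementation built on torch's activation-checkpointing utilities. Below is a self-contained toy of the same check_fn pattern, using hypothetical TinyBlock/TinyModel classes that are not part of this PR, checkpointing every second block:

# Toy demonstration of the check_fn pattern used by the new
# apply_gradient_checkpointing helper; TinyBlock and TinyModel are
# hypothetical stand-ins, not code from this repository.
import torch
from torch import nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
    checkpoint_wrapper,
)


class TinyBlock(nn.Module):
    def __init__(self, dim: int = 16):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, x):
        return torch.relu(self.linear(x))


class TinyModel(nn.Module):
    def __init__(self, n_blocks: int = 4, dim: int = 16):
        super().__init__()
        self.blocks = nn.ModuleList(TinyBlock(dim) for _ in range(n_blocks))

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x


model = TinyModel()
seen = {"count": 0}


def check_fn(submodule: nn.Module) -> bool:
    # checkpoint every second TinyBlock, i.e. checkpoint_every=2
    if isinstance(submodule, TinyBlock):
        seen["count"] += 1
        return (seen["count"] - 1) % 2 == 0
    return False


apply_activation_checkpointing(
    model,
    checkpoint_wrapper_fn=checkpoint_wrapper,
    check_fn=check_fn,
)
print(model)  # blocks 0 and 2 are now wrapped in CheckpointWrapper

In the PR itself, block_checkpointing applies this same pattern after resolving the block class from model._no_split_modules[0], and optionally switches to the reentrant checkpoint implementation via use_reentrant=True.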