Add TransformerEngine support #481

Draft · wants to merge 1 commit into base: main
1 change: 1 addition & 0 deletions src/llama_recipes/configs/training.py
@@ -26,6 +26,7 @@ class train_config:
gamma: float= 0.85
seed: int=42
use_fp16: bool=False
use_te: bool=False
mixed_precision: bool=True
val_batch_size: int=1
dataset = "samsum_dataset"
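
A minimal sketch of turning the new flag on through the recipe's existing config-override path; the model path is a placeholder and the fire-style CLI form is an assumption about the entry point, not something this diff adds:

# Hypothetical usage sketch (not part of this PR).
# Assumed CLI form: python -m llama_recipes.finetuning --use_te True --model_name <path>
from llama_recipes.finetuning import main

main(
    use_te=True,                              # new train_config field added above
    model_name="<path-to-llama-checkpoint>",  # placeholder
)
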
4 changes: 4 additions & 0 deletions src/llama_recipes/finetuning.py
@@ -66,6 +66,10 @@ def setup_wandb(train_config, fsdp_config, **kwargs):
run.config.update(fsdp_config, allow_val_change=True)
return run

def setup_te(train_config):
    # Return the TransformerEngine-backed model class so callers can use it in place of LlamaForCausalLM.
    if train_config.use_te:
        from llama_recipes.utils.te_utils import TELlamaForCausalLM
        return TELlamaForCausalLM
    return None


def main(**kwargs):
# Update the configuration for the training and sharding process
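
A rough sketch of how the helper above could be consumed inside `main()`; the surrounding variables (`train_config`, `llama_config`) and the bfloat16 choice are assumptions for illustration, not code from this diff:

# Hypothetical wiring inside main() (assumptions, not part of this PR).
import torch
from transformers import LlamaForCausalLM

model_cls = setup_te(train_config) or LlamaForCausalLM
model = model_cls.from_pretrained(
    train_config.model_name,
    config=llama_config,          # assumed: an AutoConfig loaded earlier in main()
    torch_dtype=torch.bfloat16,   # TELlamaForCausalLM.from_pretrained reads torch_dtype
)
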
179 changes: 179 additions & 0 deletions src/llama_recipes/utils/te_utils.py
@@ -0,0 +1,179 @@
# This code is adapted from https://github.com/NVIDIA/TransformerEngine/blob/16a469df6bbc77e1c32e48e8e5fd3082dbc2d18e/docs/examples/te_llama/te_llama.py
import os
import re
import gc
from contextlib import contextmanager

import torch
from torch import nn

import transformer_engine as te
from transformer_engine.pytorch.attention import RotaryPositionEmbedding
from transformer_engine.pytorch.fp8 import fp8_model_init

import transformers
from transformers.models.llama.modeling_llama import LlamaModel, LlamaForCausalLM, LlamaRMSNorm, LlamaConfig
from transformers.modeling_utils import _add_variant, load_state_dict, _load_state_dict_into_model
from transformers.utils import WEIGHTS_INDEX_NAME
from transformers.utils.hub import get_checkpoint_shard_files

@contextmanager
def replace_decoder(te_decoder_cls):
"""
Replace `LlamaDecoderLayer` with custom `TELlamaDecoderLayer`.
"""
original_llama_decoder_cls = transformers.models.llama.modeling_llama.LlamaDecoderLayer
transformers.models.llama.modeling_llama.LlamaDecoderLayer = te_decoder_cls
try:
yield
finally:
transformers.models.llama.modeling_llama.LlamaDecoderLayer = original_llama_decoder_cls


class TELlamaDecoderLayer(te.pytorch.TransformerLayer):
"""
    Wrapper class over TE's `TransformerLayer`. This keeps the wrapper close to
    HF's `LlamaDecoderLayer`, so it can be swapped in with minimal changes.

Args:
config: LlamaConfig
args: positional args (for compatibility with `LlamaDecoderLayer`)
kwargs: keyword args (for compatibility with `LlamaDecoderLayer`)
"""
def __init__(self, config, *args, **kwargs):
super().__init__(
hidden_size=config.hidden_size,
ffn_hidden_size=config.intermediate_size,
num_attention_heads=config.num_attention_heads,
bias=False,
layernorm_epsilon=config.rms_norm_eps,
hidden_dropout=0,
attention_dropout=0,
fuse_qkv_params=False,
normalization="RMSNorm",
activation="swiglu",
attn_input_format="bshd",
num_gqa_groups=config.num_key_value_heads
)
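        # Precompute the rotary position embedding table once; forward() passes it to TE's TransformerLayer as rotary_pos_emb.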
te_rope = RotaryPositionEmbedding(config.hidden_size//config.num_attention_heads)
self.te_rope_emb = te_rope(max_seq_len=config.max_position_embeddings).cuda()

def forward(self,
hidden_states,
*args,
attention_mask,
**kwargs):
"""
Custom forward to make sure we only pass relevant arguments to the
forward pass of the `TransformerLayer`. Also, make sure the output
        format matches the output of HF's `LlamaDecoderLayer`.
"""
return (super().forward(hidden_states, attention_mask=attention_mask, rotary_pos_emb=self.te_rope_emb),)


class TELlamaForCausalLM:
"""
Causal LM created with `LlamaModel`. The underlying `LlamaDecoderLayer`
class is monkey-patched with `TELlamaDecoderLayer` class before
initializing the causal LM with `LlamaForCausalLM`.

Args:
config: LlamaConfig
"""

def __new__(cls, config: LlamaConfig):
with replace_decoder(te_decoder_cls=TELlamaDecoderLayer):
llama_for_causal_lm = LlamaForCausalLM(config)
return llama_for_causal_lm

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *args, config, **kwargs):
"""
Custom method adapted from `from_pretrained` method in HuggingFace
Transformers repo: https://github.com/huggingface/transformers/blob/f497f564bb76697edab09184a252fc1b1a326d1e/src/transformers/modeling_utils.py#L2579
Checkpoints have to be stored locally in sharded format.
"""
vanilla_model = cls(config).to(kwargs['torch_dtype'])
is_local = os.path.isdir(pretrained_model_name_or_path)
subfolder = ""
variant = None
if os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant))
):
# Load from a sharded PyTorch checkpoint
archive_file = os.path.join(
pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant)
)
is_sharded = True
else:
raise AssertionError("Only sharded PyTorch ckpt format supported at the moment")


resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
pretrained_model_name_or_path,
archive_file,
)

# If the checkpoint is not sharded, it's a trivial sharding case
if not is_sharded:
assert not isinstance(resolved_archive_file, list)
resolved_archive_file = [resolved_archive_file]

for shard_file in resolved_archive_file:
state_dict = load_state_dict(shard_file)
# replace_params copies parameters relevant only to TransformerEngine
replace_params(state_dict, vanilla_model.state_dict(), config)
# _load_state_dict_into_model copies parameters other than those in TransformerEngine
_load_state_dict_into_model(vanilla_model, state_dict, start_prefix="")

# Force mem release. Taken from huggingface code
del state_dict
gc.collect()

return vanilla_model

def replace_params(hf_state_dict, te_state_dict, config):
# collect all layer prefixes to update
all_layer_prefixes = set()
for param_key in hf_state_dict.keys():
        layer_prefix_pat = r'model\.layers\.\d+\.'
m = re.match(layer_prefix_pat, param_key)
if m is not None:
all_layer_prefixes.add(m.group())

for layer_prefix in all_layer_prefixes:
        # When loading weights into a model with fewer layers, skip the copy
        # if the corresponding weights are missing from the HF state dict
if layer_prefix + 'input_layernorm.weight' in hf_state_dict:
te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.layer_norm_weight'].data[:] = hf_state_dict[layer_prefix + 'input_layernorm.weight'].data[:]

if layer_prefix + 'self_attn.q_proj.weight' in hf_state_dict:
te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.query_weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.q_proj.weight'].data[:]

if layer_prefix + 'self_attn.k_proj.weight' in hf_state_dict:
te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.key_weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.k_proj.weight'].data[:]

if layer_prefix + 'self_attn.v_proj.weight' in hf_state_dict:
te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.value_weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.v_proj.weight'].data[:]

if layer_prefix + 'self_attn.o_proj.weight' in hf_state_dict:
te_state_dict[layer_prefix + 'self_attention.proj.weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.o_proj.weight'].data[:]

if layer_prefix + 'post_attention_layernorm.weight' in hf_state_dict:
te_state_dict[layer_prefix + 'layernorm_mlp.layer_norm_weight'].data[:] = hf_state_dict[layer_prefix + 'post_attention_layernorm.weight'].data[:]

        # gate_proj.weight and up_proj.weight may live in different shard files,
        # so load them separately.
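        # TE's fused layernorm_mlp.fc1_weight stacks the SwiGLU projections along dim 0:
        # gate_proj fills the first intermediate_size rows, up_proj the remaining rows.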
if layer_prefix + 'mlp.gate_proj.weight' in hf_state_dict:
te_state_dict[layer_prefix + 'layernorm_mlp.fc1_weight'].data[:config.intermediate_size] = \
hf_state_dict[layer_prefix + 'mlp.gate_proj.weight'].data

if layer_prefix + 'mlp.up_proj.weight' in hf_state_dict:
te_state_dict[layer_prefix + 'layernorm_mlp.fc1_weight'].data[config.intermediate_size:] = \
hf_state_dict[layer_prefix + 'mlp.up_proj.weight'].data

if layer_prefix + 'mlp.down_proj.weight' in hf_state_dict:
te_state_dict[layer_prefix + 'layernorm_mlp.fc2_weight'].data[:] = hf_state_dict[layer_prefix + 'mlp.down_proj.weight'].data[:]
return all_layer_prefixes
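
To make the expected call pattern concrete, a hedged usage sketch for `TELlamaForCausalLM.from_pretrained`, following the NVIDIA te_llama tutorial this file is adapted from; the checkpoint path is a placeholder and must point to a locally stored, sharded PyTorch checkpoint (pytorch_model.bin.index.json plus shards), since that is the only format the method accepts:

# Usage sketch (paths and dtype are assumptions; not part of this PR).
import torch
from transformers import AutoConfig, AutoTokenizer
from llama_recipes.utils.te_utils import TELlamaForCausalLM

ckpt_dir = "/path/to/llama-2-7b-hf"   # placeholder: local, sharded HF checkpoint
config = AutoConfig.from_pretrained(ckpt_dir)
tokenizer = AutoTokenizer.from_pretrained(ckpt_dir)

# Builds the TE-patched model first, then copies HF weights shard by shard
# via replace_params() and _load_state_dict_into_model().
model = TELlamaForCausalLM.from_pretrained(ckpt_dir, config=config, torch_dtype=torch.bfloat16).cuda()

inputs = tokenizer("Hello", return_tensors="pt").to("cuda")
with torch.no_grad():
    out = model(**inputs)   # standard HF CausalLM output; out.logits, etc.
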
22 changes: 19 additions & 3 deletions src/llama_recipes/utils/train_utils.py
@@ -9,7 +9,6 @@
from pkg_resources import packaging
from datetime import datetime


import torch
import torch.cuda.nccl as nccl
import torch.distributed as dist
@@ -33,6 +32,17 @@ def set_tokenizer_params(tokenizer: LlamaTokenizer):
def byte2mb(x):
return int(x / 2**20)

def setup_te():
# To avoid https://github.com/NVIDIA/TransformerEngine/issues/115
import transformer_engine
import transformer_engine_extensions
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling
fp8_format = Format.HYBRID # E4M3 during forward pass, E5M2 during backward pass
fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max")
return te, fp8_recipe


def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_scheduler, gradient_accumulation_steps, train_config, fsdp_config=None, local_rank=None, rank=None, wandb_run=None):
"""
Trains the model on the given dataloader
Expand All @@ -58,6 +68,8 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
scaler = torch.cuda.amp.GradScaler()
if train_config.enable_fsdp:
world_size = int(os.environ["WORLD_SIZE"])
if train_config.use_te:
te, fp8_recipe = setup_te()



@@ -112,8 +124,12 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
batch[key] = batch[key].to('xpu:0')
else:
batch[key] = batch[key].to('cuda:0')
with autocast():
loss = model(**batch).loss
if not train_config.use_te:
with autocast():
loss = model(**batch).loss
else:
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
loss = model(**batch).loss
loss = loss / gradient_accumulation_steps
if train_config.save_metrics:
train_step_loss.append(loss.detach().float().item())
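
For reference, the FP8 pieces used above in isolation: a small, self-contained sketch of the DelayedScaling recipe plus fp8_autocast around a plain TE layer. The layer sizes are arbitrary, and running it requires an FP8-capable GPU (e.g. Hopper or Ada):

# Standalone FP8 sketch (illustrative sizes; not part of this PR).
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling

# HYBRID = E4M3 in the forward pass, E5M2 for gradients in the backward pass.
fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16, amax_compute_algo="max")

layer = te.Linear(4096, 4096, bias=False).cuda()
x = torch.randn(8, 4096, device="cuda", requires_grad=True)

with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    y = layer(x)       # the GEMM runs in FP8 under the recipe's scaling policy
y.sum().backward()     # backward stays outside the context, as in the training loop above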