Gaudi2 support opt-66b(DS mode) #279

Closed · wants to merge 7 commits
7 changes: 0 additions & 7 deletions examples/text-generation/checkpoint_utils.py
@@ -75,13 +75,6 @@ def write_checkpoints_json(model_name_or_path, local_rank, checkpoints_json):
             json.dump(data, fp)
 
 
-def model_is_bloom(config):
-    """
-    Checks if the given config belongs to a BLOOM-like model.
-    """
-    return config.model_type == "bloom"
-
-
 def get_optimized_model_name(config):
     from optimum.habana.transformers.generation import MODELS_OPTIMIZED_WITH_STATIC_SHAPES
 
8 changes: 3 additions & 5 deletions examples/text-generation/run_generation.py
@@ -28,7 +28,6 @@
 from checkpoint_utils import (
     get_ds_injection_policy,
     get_repo_root,
-    model_is_bloom,
     model_is_optimized,
     write_checkpoints_json,
 )
@@ -196,9 +195,8 @@ def main():
     if use_deepspeed:
         config = AutoConfig.from_pretrained(args.model_name_or_path)
         is_optimized = model_is_optimized(config)
-        is_bloom = model_is_bloom(config)
 
-        if is_bloom:
+        if config.model_type in ["bloom", "opt"]:
             # Construct model with fake meta tensors, later will be replaced on devices during ds-inference ckpt load
             with deepspeed.OnDevice(dtype=model_dtype, device="meta"):
                 model = AutoModelForCausalLM.from_config(config, torch_dtype=model_dtype)
@@ -214,7 +212,7 @@ def main():
         ds_inference_kwargs["tensor_parallel"] = {"tp_size": world_size}
         ds_inference_kwargs["enable_cuda_graph"] = args.use_hpu_graphs
 
-        if is_bloom:
+        if config.model_type in ["bloom", "opt"]:
             # BLOOM is managed differently
             checkpoints_json = "checkpoints.json"
             write_checkpoints_json(args.model_name_or_path, args.local_rank, checkpoints_json)
@@ -223,7 +221,7 @@ def main():
             torch.distributed.barrier()
 
         ds_inference_kwargs["injection_policy"] = get_ds_injection_policy(config)
-        if is_bloom:
+        if config.model_type in ["bloom", "opt"]:
             ds_inference_kwargs["checkpoint"] = checkpoints_json
 
         model = deepspeed.init_inference(model, **ds_inference_kwargs)
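
Taken together, the four hunks above route OPT checkpoints through the same DeepSpeed meta-tensor loading path previously reserved for BLOOM. The condensed sketch below shows that flow in one place; it is illustrative only, with the initial contents of ds_inference_kwargs abbreviated, and it assumes args, model_dtype and world_size are prepared earlier in main() exactly as in the unchanged parts of the script.

# Condensed sketch of the "bloom"/"opt" branch after this change (illustrative only;
# surrounding setup from run_generation.py is omitted).
config = AutoConfig.from_pretrained(args.model_name_or_path)

if config.model_type in ["bloom", "opt"]:
    # Build the model with meta tensors; real weights are materialized by DeepSpeed at init time
    with deepspeed.OnDevice(dtype=model_dtype, device="meta"):
        model = AutoModelForCausalLM.from_config(config, torch_dtype=model_dtype)

ds_inference_kwargs = {"dtype": model_dtype}
ds_inference_kwargs["tensor_parallel"] = {"tp_size": world_size}
ds_inference_kwargs["injection_policy"] = get_ds_injection_policy(config)

if config.model_type in ["bloom", "opt"]:
    # Point DeepSpeed at the sharded checkpoint files through a small JSON descriptor
    checkpoints_json = "checkpoints.json"
    write_checkpoints_json(args.model_name_or_path, args.local_rank, checkpoints_json)
    ds_inference_kwargs["checkpoint"] = checkpoints_json

model = deepspeed.init_inference(model, **ds_inference_kwargs)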
35 changes: 35 additions & 0 deletions optimum/habana/transformers/modeling_utils.py
@@ -177,3 +177,38 @@ def adapt_transformers_to_gaudi():
     transformers.models.t5.modeling_t5.T5LayerCrossAttention = GaudiT5LayerCrossAttention
     transformers.models.t5.modeling_t5.T5DenseActDense = GaudiT5DenseActDense
     transformers.models.t5.modeling_t5.T5Attention.forward = gaudi_T5Attention_forward
+
+    from transformers.deepspeed import is_deepspeed_available
+
+    if is_deepspeed_available():
+        import deepspeed
+
+        from .models.opt.modeling_opt import GaudiEmbeddingLayer, GaudiLinearLayer, GaudiOPTEmbedding
+
+        deepspeed.module_inject.layers.LinearLayer = GaudiLinearLayer
+        deepspeed.module_inject.layers.EmbeddingLayer = GaudiEmbeddingLayer
+        deepspeed.module_inject.layers.OPTEmbedding = GaudiOPTEmbedding
+        deepspeed.module_inject.LinearLayer = GaudiLinearLayer
+        deepspeed.module_inject.EmbeddingLayer = GaudiEmbeddingLayer
+        deepspeed.module_inject.load_checkpoint.LinearLayer = GaudiLinearLayer
+        deepspeed.module_inject.replace_module.LinearLayer = GaudiLinearLayer
+        deepspeed.inference.engine.LinearLayer = GaudiLinearLayer
+        deepspeed.module_inject.load_checkpoint.EmbeddingLayer = GaudiEmbeddingLayer
+        deepspeed.module_inject.load_checkpoint.OPTEmbedding = GaudiOPTEmbedding
+
+        # For monkey patching to work with DeepSpeed, we need to uncache all DS modules
+        # so that they are reloaded with updated code at the next import
+        import sys
+
+        to_uncache = []
+        for mod in sys.modules:
+            if "deepspeed" in mod:
+                to_uncache.append(mod)
+
+        for mod in to_uncache:
+            del sys.modules[mod]
+
+        # All Pydantic DS class validators have to be cleared so that they are not declared twice
+        import pydantic
+
+        pydantic.class_validators._FUNCS.clear()
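
The two trailing steps above (dropping DeepSpeed entries from sys.modules and clearing pydantic's validator registry) rely on Python re-executing a module body once it has been uncached. A small, self-contained illustration of that mechanism, using the standard-library json module as a stand-in for the DeepSpeed modules:

import sys
import json  # stands in for an already-imported DeepSpeed module

old_module = sys.modules["json"]
del sys.modules["json"]  # uncache it, like the loop over "deepspeed" entries above

import json  # re-imported: the module body runs again and a fresh module object is bound

assert json is not old_module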
76 changes: 76 additions & 0 deletions optimum/habana/transformers/models/opt/modeling_opt.py
@@ -3,6 +3,7 @@
 
 import torch
 from torch.nn import CrossEntropyLoss
+from transformers.deepspeed import is_deepspeed_available
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from transformers.models.opt.modeling_opt import OPTForCausalLM, OPTLearnedPositionalEmbedding, logger
 
@@ -525,3 +526,78 @@ def prepare_inputs_for_generation(
             }
         )
         return model_inputs
+
+
+if is_deepspeed_available():
+    from torch.nn.parameter import Parameter
+
+    class GaudiEmbeddingLayer(torch.nn.Module):
+        def __init__(self, weight_shape, dtype=torch.bfloat16, device=None):
+            super().__init__()
+            self.weight = Parameter(
+                torch.empty(
+                    weight_shape[0],
+                    weight_shape[1],
+                    dtype=dtype,
+                    device=device,
+                )
+            )
+
+        def forward(self, input):
+            return torch.nn.functional.embedding(input, self.weight)
+
+    class GaudiLinearLayer(torch.nn.Module):
+        def __init__(self, weight_shape=None, dtype=torch.bfloat16, weight=None, bias=None, device=None):
+            super().__init__()
+            if weight is not None:
+                self.weight = weight
+                self.bias = bias
+            else:
+                self.weight = Parameter(
+                    torch.empty(
+                        weight_shape,
+                        dtype=dtype,
+                        device=device,
+                    )
+                )
+                self.bias = Parameter(
+                    torch.empty(
+                        weight_shape[0],
+                        dtype=dtype,
+                        device=device,
+                    )
+                )
+
+        def forward(self, input):
+            output = torch.matmul(input, self.weight.transpose(-1, -2))
+            if self.bias is not None:
+                output += self.bias
+            return output
+
+    class GaudiOPTEmbedding(GaudiEmbeddingLayer):
+        """
+        Adapted from deepspeed.module_inject.layers.OPTEmbedding: https://github.com/HabanaAI/DeepSpeed/blob/410b3dbb74c7d266eba71aadf26ee616c2882040/deepspeed/module_inject/layers.py#L74
+        Differences are:
+        - add `token_idx` to manage static shapes
+        """
+
+        def __init__(self, weight_shape, device=None):
+            # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2
+            # and adjust num_embeddings appropriately. Other models don't have this hack
+            self.offset = 2
+            super().__init__(weight_shape, device=device)
+
+        def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0, token_idx=None):
+            """`input_ids_shape` is expected to be [bsz x seqlen]."""
+            attention_mask = attention_mask.long()
+
+            if past_key_values_length == 0:
+                # create positions depending on attention_mask
+                positions = (torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask).long() - 1
+
+                # cut positions if `past_key_values_length` is > 0
+                positions = positions[:, past_key_values_length:]
+
+                return super().forward(positions + self.offset)
+            else:
+                return super().forward(token_idx + self.offset)
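
To make the two branches of GaudiOPTEmbedding.forward concrete, here is a small illustrative snippet (toy shapes, CPU device, uninitialized weights; it assumes the class defined in the diff above is in scope, and only the indexing behaviour matters):

import torch

# Toy instance of the class defined above: 10 positions plus the OPT offset of 2, hidden size 4.
emb = GaudiOPTEmbedding(weight_shape=(12, 4), device="cpu")

attention_mask = torch.tensor([[1, 1, 1, 0]])  # batch of 1, sequence length 4, last token is padding

# Prefill (past_key_values_length == 0): positions are derived from the attention mask,
# i.e. [0, 1, 2, -1] before the +2 offset is applied.
prefill_embeddings = emb(attention_mask)  # shape: (1, 4, 4)

# Decode step with static shapes: the caller passes the current position as `token_idx`,
# so only that single position embedding is looked up.
token_idx = torch.tensor([3])
decode_embedding = emb(attention_mask, past_key_values_length=3, token_idx=token_idx)  # shape: (1, 4)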