From 4520bf11a9c54f6bc6ab0eca89944b63eea27374 Mon Sep 17 00:00:00 2001
From: zzhai
Date: Sun, 25 Jun 2023 15:36:07 +0800
Subject: [PATCH 1/5] gaudi2 support opt-66b

---
 examples/text-generation/run_generation.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py
index 65dce6d613..1bc7f647c0 100644
--- a/examples/text-generation/run_generation.py
+++ b/examples/text-generation/run_generation.py
@@ -180,7 +180,7 @@ def main():
             with deepspeed.OnDevice(dtype=model_dtype, device="meta"):
                 model = AutoModelForCausalLM.from_config(config, torch_dtype=model_dtype)
         else:
-            with deepspeed.OnDevice(dtype=model_dtype, device=args.device):
+            with deepspeed.OnDevice(dtype=model_dtype, device="cpu"):
                 model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, torch_dtype=model_dtype)
         model = model.eval()
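
Why this patch stages the weights on the host rather than on the device: OPT-66B in bfloat16 needs roughly 132 GB just for parameters, more than the 96 GB of HBM on a single Gaudi2, so the checkpoint has to live in host RAM (or on the meta device) until DeepSpeed shards it. A back-of-the-envelope check (numbers mine, not from the patch):

import math

params = 66_000_000_000
bytes_per_param = 2  # bfloat16
print(f"~{params * bytes_per_param / 1e9:.0f} GB of weights")  # ~132 GB vs 96 GB of HBM per Gaudi2
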
From dd095c91dc3a154e9303d065993c37184e8da796 Mon Sep 17 00:00:00 2001
From: regisss <15324346+regisss@users.noreply.github.com>
Date: Fri, 30 Jun 2023 17:52:48 +0000
Subject: [PATCH 2/5] Adapt DeepSpeed module injection for OPT

---
 examples/text-generation/checkpoint_utils.py  |  7 ---
 examples/text-generation/run_generation.py    | 12 ++--
 optimum/habana/transformers/modeling_utils.py | 11 ++++
 .../transformers/models/opt/modeling_opt.py   | 62 +++++++++++++++++++
 4 files changed, 78 insertions(+), 14 deletions(-)

diff --git a/examples/text-generation/checkpoint_utils.py b/examples/text-generation/checkpoint_utils.py
index b4f48d3cd8..1e2531f0e9 100644
--- a/examples/text-generation/checkpoint_utils.py
+++ b/examples/text-generation/checkpoint_utils.py
@@ -59,13 +59,6 @@ def write_checkpoints_json(model_name_or_path, local_rank, checkpoints_json):
         json.dump(data, fp)


-def model_is_bloom(config):
-    """
-    Checks if the given config belongs to a BLOOM-like model.
-    """
-    return config.model_type == "bloom"
-
-
 def get_optimized_model_name(config):
     model_names = ["bloom", "gpt2", "opt", "gptj", "gpt_neox"]
     for model_name in model_names:
diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py
index 1bc7f647c0..8e230b8fd9 100644
--- a/examples/text-generation/run_generation.py
+++ b/examples/text-generation/run_generation.py
@@ -25,7 +25,7 @@
 import torch
 import torch.nn.functional as F

-from checkpoint_utils import get_ds_injection_policy, model_is_bloom, model_is_optimized, write_checkpoints_json
+from checkpoint_utils import get_ds_injection_policy, model_is_optimized, write_checkpoints_json
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
 from transformers.generation import GenerationConfig

@@ -173,14 +173,13 @@ def main():
     if use_deepspeed:
         config = AutoConfig.from_pretrained(args.model_name_or_path)
         is_optimized = model_is_optimized(config)
-        is_bloom = model_is_bloom(config)

-        if is_bloom:
+        if config.model_type in ["bloom", "opt"]:
             # Construct model with fake meta tensors, later will be replaced on devices during ds-inference ckpt load
             with deepspeed.OnDevice(dtype=model_dtype, device="meta"):
                 model = AutoModelForCausalLM.from_config(config, torch_dtype=model_dtype)
         else:
-            with deepspeed.OnDevice(dtype=model_dtype, device="cpu"):
+            with deepspeed.OnDevice(dtype=model_dtype, device=args.device):
                 model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, torch_dtype=model_dtype)
         model = model.eval()

@@ -189,7 +188,7 @@ def main():
         ds_inference_kwargs["tensor_parallel"] = {"tp_size": world_size}
         ds_inference_kwargs["enable_cuda_graph"] = args.use_hpu_graphs

-        if is_bloom:
+        if config.model_type in ["bloom", "opt"]:
             # BLOOM is managed differently
             checkpoints_json = "checkpoints.json"
             write_checkpoints_json(args.model_name_or_path, args.local_rank, checkpoints_json)
@@ -198,7 +197,7 @@ def main():
             torch.distributed.barrier()

         ds_inference_kwargs["injection_policy"] = get_ds_injection_policy(config)
-        if is_bloom:
+        if config.model_type in ["bloom", "opt"]:
             ds_inference_kwargs["checkpoint"] = checkpoints_json

         model = deepspeed.init_inference(model, **ds_inference_kwargs)
@@ -206,7 +205,6 @@ def main():
     else:
         model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, torch_dtype=model_dtype)
         model = model.eval().to(args.device)
-        is_bloom = model_is_bloom(model.config)
         is_optimized = model_is_optimized(model.config)

     if args.use_hpu_graphs:
diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py
index 8f9ba9101e..6bc55e1014 100644
--- a/optimum/habana/transformers/modeling_utils.py
+++ b/optimum/habana/transformers/modeling_utils.py
@@ -175,3 +175,14 @@ def adapt_transformers_to_gaudi():
     modeling_t5.T5LayerCrossAttention = GaudiT5LayerCrossAttention
     modeling_t5.T5DenseActDense = GaudiT5DenseActDense
     modeling_t5.T5Attention.forward = gaudi_T5Attention_forward
+
+    from transformers.deepspeed import is_deepspeed_available
+
+    if is_deepspeed_available():
+        import deepspeed
+
+        from .models.opt.modeling_opt import GaudiEmbeddingLayer, GaudiLinearLayer, GaudiOPTEmbedding
+
+        deepspeed.module_inject.layers.EmbeddingLayer = GaudiEmbeddingLayer
+        deepspeed.module_inject.layers.LinearLayer = GaudiLinearLayer
+        deepspeed.module_inject.layers.OPTEmbedding = GaudiOPTEmbedding
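
A note on the assignments above: rebinding a class on deepspeed.module_inject.layers only affects code that resolves the name through the module at call time. A minimal sketch of the pitfall, with a hypothetical MyLayer standing in for the Gaudi classes:

import deepspeed.module_inject.layers
import torch

class MyLayer(torch.nn.Module):  # hypothetical stand-in, not part of the patch
    def forward(self, input):
        return input

# Module-attribute lookups performed after this line return MyLayer...
deepspeed.module_inject.layers.LinearLayer = MyLayer
# ...but any module that already ran `from deepspeed.module_inject.layers
# import LinearLayer` still holds a reference to the original class. This is
# the gap that [PATCH 5/5] closes by patching every importing module and
# evicting cached DeepSpeed modules from sys.modules.
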
diff --git a/optimum/habana/transformers/models/opt/modeling_opt.py b/optimum/habana/transformers/models/opt/modeling_opt.py
index c8cfa9ffa0..b20fbc28f6 100644
--- a/optimum/habana/transformers/models/opt/modeling_opt.py
+++ b/optimum/habana/transformers/models/opt/modeling_opt.py
@@ -3,6 +3,7 @@
 import torch
 from torch.nn import CrossEntropyLoss

+from transformers.deepspeed import is_deepspeed_available
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from transformers.models.opt.modeling_opt import OPTForCausalLM, OPTLearnedPositionalEmbedding, logger

@@ -525,3 +526,64 @@ def prepare_inputs_for_generation(
             }
         )
         return model_inputs
+
+
+if is_deepspeed_available():
+    from deepspeed.module_inject.layers import EmbeddingLayer, LinearLayer, OPTEmbedding
+    from torch.nn.parameter import Parameter
+
+    class GaudiEmbeddingLayer(EmbeddingLayer):
+        def __init__(self, weight_shape, dtype=torch.bfloat16, device=None):
+            super(EmbeddingLayer, self).__init__()
+            self.weight = Parameter(
+                torch.empty(
+                    weight_shape[0],
+                    weight_shape[1],
+                    dtype=dtype,
+                    device=device,
+                )
+            )
+
+    class GaudiLinearLayer(LinearLayer):
+        def __init__(self, weight_shape=None, dtype=torch.bfloat16, weight=None, bias=None, device=None):
+            super(LinearLayer, self).__init__()
+            if weight is not None:
+                self.weight = weight
+                self.bias = bias
+            else:
+                self.weight = Parameter(
+                    torch.empty(
+                        weight_shape,
+                        dtype=dtype,
+                        device=device,
+                    )
+                )
+                self.bias = Parameter(
+                    torch.empty(
+                        weight_shape[0],
+                        dtype=dtype,
+                        device=device,
+                    )
+                )
+
+    class GaudiOPTEmbedding(OPTEmbedding):
+        """
+        Adapted from deepspeed.module_inject.layers.OPTEmbedding: https://github.com/HabanaAI/DeepSpeed/blob/410b3dbb74c7d266eba71aadf26ee616c2882040/deepspeed/module_inject/layers.py#L74
+        Differences are:
+        - add `token_idx` to manage static shapes
+        """
+
+        def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0, token_idx=None):
+            """`input_ids_shape` is expected to be [bsz x seqlen]."""
+            attention_mask = attention_mask.long()
+
+            if past_key_values_length == 0:
+                # create positions depending on attention_mask
+                positions = (torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask).long() - 1
+
+                # cut positions if `past_key_values_length` is > 0
+                positions = positions[:, past_key_values_length:]
+
+                return super().forward(positions + self.offset)
+            else:
+                return super().forward(token_idx + self.offset)
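
The cumulative-sum expression in GaudiOPTEmbedding.forward is how OPT derives position ids from a (possibly left-padded) attention mask, while token_idx short-circuits it during decoding so tensor shapes stay static for HPU graphs. A standalone illustration (input values chosen for the example):

import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1]])  # one left-padded sequence
positions = (torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask).long() - 1
print(positions)  # tensor([[-1, -1,  0,  1,  2]]): padding gets -1, real tokens count from 0

OPT then adds self.offset (2) before the embedding lookup, mirroring OPTLearnedPositionalEmbedding in transformers.
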
From a5f1cc78c2e1c1ced428216920475ec79da9d760 Mon Sep 17 00:00:00 2001
From: regisss <15324346+regisss@users.noreply.github.com>
Date: Fri, 30 Jun 2023 18:17:42 +0000
Subject: [PATCH 3/5] Fix super

---
 .../transformers/models/opt/modeling_opt.py | 39 ++++++------------
 1 file changed, 11 insertions(+), 28 deletions(-)

diff --git a/optimum/habana/transformers/models/opt/modeling_opt.py b/optimum/habana/transformers/models/opt/modeling_opt.py
index b20fbc28f6..38fcd2b86a 100644
--- a/optimum/habana/transformers/models/opt/modeling_opt.py
+++ b/optimum/habana/transformers/models/opt/modeling_opt.py
@@ -530,41 +530,24 @@ def prepare_inputs_for_generation(

 if is_deepspeed_available():
     from deepspeed.module_inject.layers import EmbeddingLayer, LinearLayer, OPTEmbedding
-    from torch.nn.parameter import Parameter

     class GaudiEmbeddingLayer(EmbeddingLayer):
         def __init__(self, weight_shape, dtype=torch.bfloat16, device=None):
-            super(EmbeddingLayer, self).__init__()
-            self.weight = Parameter(
-                torch.empty(
-                    weight_shape[0],
-                    weight_shape[1],
-                    dtype=dtype,
-                    device=device,
-                )
+            super().__init__(
+                weight_shape,
+                dtype=torch.bfloat16,
+                device=None,
             )

     class GaudiLinearLayer(LinearLayer):
         def __init__(self, weight_shape=None, dtype=torch.bfloat16, weight=None, bias=None, device=None):
-            super(LinearLayer, self).__init__()
-            if weight is not None:
-                self.weight = weight
-                self.bias = bias
-            else:
-                self.weight = Parameter(
-                    torch.empty(
-                        weight_shape,
-                        dtype=dtype,
-                        device=device,
-                    )
-                )
-                self.bias = Parameter(
-                    torch.empty(
-                        weight_shape[0],
-                        dtype=dtype,
-                        device=device,
-                    )
-                )
+            super().__init__(
+                weight_shape,
+                dtype,
+                weight,
+                bias,
+                device,
+            )

     class GaudiOPTEmbedding(OPTEmbedding):
         """

From ccea6f802b536601bd0906cd079694ca30b88d7d Mon Sep 17 00:00:00 2001
From: regisss <15324346+regisss@users.noreply.github.com>
Date: Fri, 30 Jun 2023 18:29:29 +0000
Subject: [PATCH 4/5] Fix

---
 examples/text-generation/run_generation.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py
index 939a5a9147..02aa988f1a 100644
--- a/examples/text-generation/run_generation.py
+++ b/examples/text-generation/run_generation.py
@@ -151,6 +151,11 @@ def main():

     world_size, rank, args.local_rank = initialize_distributed_hpu()

+    # Tweak generation so that it runs faster on Gaudi
+    from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
+
+    adapt_transformers_to_gaudi()
+
     if use_deepspeed:
         # Check if DeepSpeed is installed
         from transformers.deepspeed import is_deepspeed_available
@@ -168,11 +173,6 @@ def main():
     else:
         logger.info("Single-device run.")

-    # Tweak generation so that it runs faster on Gaudi
-    from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
-
-    adapt_transformers_to_gaudi()
-
     # Set seed before initializing model.
     from optimum.habana.utils import set_seed
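
Patches 4 and 5 move the adapt_transformers_to_gaudi() call around because monkey patches only reach names that are looked up after the patch runs; a name bound by an earlier `from module import X` is frozen. A self-contained toy version of that ordering constraint:

import types

toy = types.ModuleType("toy")
toy.Layer = type("Layer", (), {})

frozen = toy.Layer                      # like `from toy import Layer` in an early import
toy.Layer = type("GaudiLayer", (), {})  # the monkey patch

print(frozen.__name__)     # Layer      -> early references are unaffected
print(toy.Layer.__name__)  # GaudiLayer -> later attribute lookups see the patch
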
From 0032f86d4eb2ef4f89df08972d056e15f97cdde1 Mon Sep 17 00:00:00 2001
From: regisss <15324346+regisss@users.noreply.github.com>
Date: Mon, 3 Jul 2023 09:35:53 +0000
Subject: [PATCH 5/5] Override DS OPT layers

---
 examples/text-generation/run_generation.py    | 10 +--
 optimum/habana/transformers/modeling_utils.py | 26 +++++++-
 .../transformers/models/opt/modeling_opt.py   | 61 ++++++++++++++-----
 3 files changed, 76 insertions(+), 21 deletions(-)

diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py
index 02aa988f1a..939a5a9147 100644
--- a/examples/text-generation/run_generation.py
+++ b/examples/text-generation/run_generation.py
@@ -151,11 +151,6 @@ def main():

     world_size, rank, args.local_rank = initialize_distributed_hpu()

-    # Tweak generation so that it runs faster on Gaudi
-    from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
-
-    adapt_transformers_to_gaudi()
-
     if use_deepspeed:
         # Check if DeepSpeed is installed
         from transformers.deepspeed import is_deepspeed_available
@@ -168,6 +163,11 @@ def main():
     else:
         logger.info("Single-device run.")

+    # Tweak generation so that it runs faster on Gaudi
+    from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
+
+    adapt_transformers_to_gaudi()
+
     # Set seed before initializing model.
     from optimum.habana.utils import set_seed
diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py
index 6bc55e1014..104b04e6ed 100644
--- a/optimum/habana/transformers/modeling_utils.py
+++ b/optimum/habana/transformers/modeling_utils.py
@@ -183,6 +183,30 @@ def adapt_transformers_to_gaudi():

         from .models.opt.modeling_opt import GaudiEmbeddingLayer, GaudiLinearLayer, GaudiOPTEmbedding

-        deepspeed.module_inject.layers.EmbeddingLayer = GaudiEmbeddingLayer
         deepspeed.module_inject.layers.LinearLayer = GaudiLinearLayer
+        deepspeed.module_inject.layers.EmbeddingLayer = GaudiEmbeddingLayer
         deepspeed.module_inject.layers.OPTEmbedding = GaudiOPTEmbedding
+        deepspeed.module_inject.LinearLayer = GaudiLinearLayer
+        deepspeed.module_inject.EmbeddingLayer = GaudiEmbeddingLayer
+        deepspeed.module_inject.load_checkpoint.LinearLayer = GaudiLinearLayer
+        deepspeed.module_inject.replace_module.LinearLayer = GaudiLinearLayer
+        deepspeed.inference.engine.LinearLayer = GaudiLinearLayer
+        deepspeed.module_inject.load_checkpoint.EmbeddingLayer = GaudiEmbeddingLayer
+        deepspeed.module_inject.load_checkpoint.OPTEmbedding = GaudiOPTEmbedding
+
+        # For monkey patching to work with DeepSpeed, we need to uncache all DS modules
+        # so that they are reloaded with updated code at the next import
+        import sys
+
+        to_uncache = []
+        for mod in sys.modules:
+            if "deepspeed" in mod:
+                to_uncache.append(mod)
+
+        for mod in to_uncache:
+            del sys.modules[mod]
+
+        # All Pydantic DS class validators have to be cleared so that they are not declared twice
+        import pydantic
+
+        pydantic.class_validators._FUNCS.clear()
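
The Pydantic cleanup above exists because evicting DeepSpeed from sys.modules forces its next import to re-execute, and its pydantic models then re-register their validators; pydantic v1 (which DeepSpeed depended on at the time) raises a ConfigError for a duplicate validator of the same qualified name. A standalone illustration of that behavior (toy model, pydantic v1 semantics):

import pydantic  # pydantic v1

def define_config():
    class Config(pydantic.BaseModel):
        value: int

        @pydantic.validator("value")
        def check_value(cls, v):
            return v

define_config()
pydantic.class_validators._FUNCS.clear()  # the same call the patch adds
define_config()  # without the clear() above, this second definition raises
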
diff --git a/optimum/habana/transformers/models/opt/modeling_opt.py b/optimum/habana/transformers/models/opt/modeling_opt.py
index 38fcd2b86a..600eaaca0f 100644
--- a/optimum/habana/transformers/models/opt/modeling_opt.py
+++ b/optimum/habana/transformers/models/opt/modeling_opt.py
@@ -529,33 +529,64 @@ def prepare_inputs_for_generation(


 if is_deepspeed_available():
-    from deepspeed.module_inject.layers import EmbeddingLayer, LinearLayer, OPTEmbedding
+    from torch.nn.parameter import Parameter

-    class GaudiEmbeddingLayer(EmbeddingLayer):
+    class GaudiEmbeddingLayer(torch.nn.Module):
         def __init__(self, weight_shape, dtype=torch.bfloat16, device=None):
-            super().__init__(
-                weight_shape,
-                dtype=torch.bfloat16,
-                device=None,
+            super().__init__()
+            self.weight = Parameter(
+                torch.empty(
+                    weight_shape[0],
+                    weight_shape[1],
+                    dtype=dtype,
+                    device=device,
+                )
             )

-    class GaudiLinearLayer(LinearLayer):
+        def forward(self, input):
+            return torch.nn.functional.embedding(input, self.weight)
+
+    class GaudiLinearLayer(torch.nn.Module):
         def __init__(self, weight_shape=None, dtype=torch.bfloat16, weight=None, bias=None, device=None):
-            super().__init__(
-                weight_shape,
-                dtype,
-                weight,
-                bias,
-                device,
-            )
+            super().__init__()
+            if weight is not None:
+                self.weight = weight
+                self.bias = bias
+            else:
+                self.weight = Parameter(
+                    torch.empty(
+                        weight_shape,
+                        dtype=dtype,
+                        device=device,
+                    )
+                )
+                self.bias = Parameter(
+                    torch.empty(
+                        weight_shape[0],
+                        dtype=dtype,
+                        device=device,
+                    )
+                )

-    class GaudiOPTEmbedding(OPTEmbedding):
+        def forward(self, input):
+            output = torch.matmul(input, self.weight.transpose(-1, -2))
+            if self.bias is not None:
+                output += self.bias
+            return output
+
+    class GaudiOPTEmbedding(GaudiEmbeddingLayer):
         """
         Adapted from deepspeed.module_inject.layers.OPTEmbedding: https://github.com/HabanaAI/DeepSpeed/blob/410b3dbb74c7d266eba71aadf26ee616c2882040/deepspeed/module_inject/layers.py#L74
         Differences are:
         - add `token_idx` to manage static shapes
         """

+        def __init__(self, weight_shape, device=None):
+            # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2
+            # and adjust num_embeddings appropriately. Other models don't have this hack
+            self.offset = 2
+            super().__init__(weight_shape, device=device)
+
         def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0, token_idx=None):
             """`input_ids_shape` is expected to be [bsz x seqlen]."""
             attention_mask = attention_mask.long()
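
For reference, the matmul in the GaudiLinearLayer.forward added above follows the [out_features, in_features] weight layout (the same convention the bias shape weight_shape[0] implies), which makes it equivalent to torch.nn.functional.linear. A quick standalone check:

import torch

x = torch.randn(2, 4)
w = torch.randn(3, 4)  # [out_features, in_features]
b = torch.randn(3)

manual = torch.matmul(x, w.transpose(-1, -2)) + b  # what GaudiLinearLayer.forward computes
builtin = torch.nn.functional.linear(x, w, b)
print(torch.allclose(manual, builtin))  # True
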