From a2761fde8b1eee79fc58a05dd59b4107d236eeca Mon Sep 17 00:00:00 2001
From: Harish Subramony
Date: Fri, 4 Aug 2023 00:50:12 +0300
Subject: [PATCH 1/3] revert partially Enable/Optimize flan t5 xxl on deepspeed z3 #257

---
 optimum/habana/transformers/modeling_utils.py |  20 +-
 .../habana/transformers/models/__init__.py    |   9 -
 .../habana/transformers/models/t5/__init__.py |   9 -
 .../transformers/models/t5/modeling_t5.py     | 215 ------------------
 4 files changed, 3 insertions(+), 250 deletions(-)
 delete mode 100644 optimum/habana/transformers/models/t5/__init__.py
 delete mode 100644 optimum/habana/transformers/models/t5/modeling_t5.py

diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py
index 4a40aab015..eece652763 100644
--- a/optimum/habana/transformers/modeling_utils.py
+++ b/optimum/habana/transformers/modeling_utils.py
@@ -26,12 +26,6 @@
     GaudiLlamaForCausalLM,
     GaudiOPTForCausalLM,
     GaudiOPTLearnedPositionalEmbedding,
-    GaudiT5DenseActDense,
-    GaudiT5DenseGatedActDense,
-    GaudiT5LayerCrossAttention,
-    GaudiT5LayerFF,
-    GaudiT5LayerSelfAttention,
-    GaudiT5Stack,
     _gaudi_esmfold_attention_wrap_up,
     gaudi_albert_forward,
     gaudi_bloom_attention_forward,
@@ -64,7 +58,6 @@
     gaudi_opt_model_forward,
     gaudi_rot_matmul,
     gaudi_rot_vec_mul,
-    gaudi_T5Attention_forward,
     gaudi_vit_self_attention_forward,
     gaudi_wav2vec2_forward,
 )
@@ -74,8 +67,10 @@ def adapt_transformers_to_gaudi():
     """
     Replaces some Transformers' methods for equivalent methods optimized for Gaudi.
-    """
+    Args:
+        use_habana_mixed_precision (bool): whether HMP is used or not.
+    """
     # optimize Conv1D
     transformers.pytorch_utils.Conv1D.forward = gaudi_conv1d_forward
@@ -168,12 +163,3 @@ def adapt_transformers_to_gaudi():
     transformers.models.llama.modeling_llama.LlamaModel.forward = gaudi_llama_model_forward
     transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward = gaudi_llama_decoder_layer_forward
     transformers.models.llama.modeling_llama.LlamaAttention.forward = gaudi_llama_attention_forward
-
-    # Dropout kernel improvement for Flan-T5
-    transformers.models.t5.modeling_t5.T5Stack = GaudiT5Stack
-    transformers.models.t5.modeling_t5.T5DenseGatedActDense = GaudiT5DenseGatedActDense
-    transformers.models.t5.modeling_t5.T5LayerFF = GaudiT5LayerFF
-    transformers.models.t5.modeling_t5.T5LayerSelfAttention = GaudiT5LayerSelfAttention
-    transformers.models.t5.modeling_t5.T5LayerCrossAttention = GaudiT5LayerCrossAttention
-    transformers.models.t5.modeling_t5.T5DenseActDense = GaudiT5DenseActDense
-    transformers.models.t5.modeling_t5.T5Attention.forward = gaudi_T5Attention_forward
diff --git a/optimum/habana/transformers/models/__init__.py b/optimum/habana/transformers/models/__init__.py
index e11e091e67..64f978427c 100644
--- a/optimum/habana/transformers/models/__init__.py
+++ b/optimum/habana/transformers/models/__init__.py
@@ -46,15 +46,6 @@
     gaudi_opt_decoder_layer_forward,
     gaudi_opt_model_forward,
 )
-from .t5 import (
-    GaudiT5DenseActDense,
-    GaudiT5DenseGatedActDense,
-    GaudiT5LayerCrossAttention,
-    GaudiT5LayerFF,
-    GaudiT5LayerSelfAttention,
-    GaudiT5Stack,
-    gaudi_T5Attention_forward,
-)
 from .vit import gaudi_vit_self_attention_forward
 from .wav2vec2 import (
     _gaudi_wav2vec2_compute_mask_indices,
diff --git a/optimum/habana/transformers/models/t5/__init__.py b/optimum/habana/transformers/models/t5/__init__.py
deleted file mode 100644
index d659eb4903..0000000000
--- a/optimum/habana/transformers/models/t5/__init__.py
+++ /dev/null
@@ -1,9 +0,0 @@
-from .modeling_t5 import (
-    GaudiT5DenseActDense,
-    GaudiT5DenseGatedActDense,
-    GaudiT5LayerCrossAttention,
-    GaudiT5LayerFF,
-    GaudiT5LayerSelfAttention,
-    GaudiT5Stack,
-    gaudi_T5Attention_forward,
-)
diff --git a/optimum/habana/transformers/models/t5/modeling_t5.py b/optimum/habana/transformers/models/t5/modeling_t5.py
deleted file mode 100644
index 3ec507901e..0000000000
--- a/optimum/habana/transformers/models/t5/modeling_t5.py
+++ /dev/null
@@ -1,215 +0,0 @@
-# coding=utf-8
-# Copyright 2022 The HuggingFace Inc. team.
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-import habana_frameworks.torch.core as htcore
-import torch
-import torch.nn as nn
-from transformers.models.t5.modeling_t5 import (
-    T5DenseActDense,
-    T5DenseGatedActDense,
-    T5LayerCrossAttention,
-    T5LayerFF,
-    T5LayerSelfAttention,
-    T5Stack,
-)
-
-
-def gaudi_T5Attention_forward(
-    self,
-    hidden_states,
-    mask=None,
-    key_value_states=None,
-    position_bias=None,
-    past_key_value=None,
-    layer_head_mask=None,
-    query_length=None,
-    use_cache=False,
-    output_attentions=False,
-):
-    """
-    Copied from T5Attention: https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py#L452
-    The only difference is: wrap `nn.functional.dropout` with `mark_step` for numerical improvement.
-    """
-    # Input is (batch_size, seq_length, dim)
-    # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
-    # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
-    batch_size, seq_length = hidden_states.shape[:2]
-    real_seq_length = seq_length
-
-    if past_key_value is not None:
-        if len(past_key_value) != 2:
-            raise ValueError(
-                f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states"
-            )
-        real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length
-
-    key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]
-
-    def shape(states):
-        """projection"""
-        return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
-
-    def unshape(states):
-        """reshape"""
-        return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
-
-    def project(hidden_states, proj_layer, key_value_states, past_key_value):
-        """projects hidden states correctly to key/query states"""
-        if key_value_states is None:
-            # self-attn
-            # (batch_size, n_heads, seq_length, dim_per_head)
-            hidden_states = shape(proj_layer(hidden_states))
-        elif past_key_value is None:
-            # cross-attn
-            # (batch_size, n_heads, seq_length, dim_per_head)
-            hidden_states = shape(proj_layer(key_value_states))
-
-        if past_key_value is not None:
-            if key_value_states is None:
-                # self-attn
-                # (batch_size, n_heads, key_length, dim_per_head)
-                hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
-            elif past_key_value.shape[2] != key_value_states.shape[1]:
-                # checking that the `sequence_length` of the `past_key_value` is the same as
-                # the provided `key_value_states` to support prefix tuning
-                # cross-attn
-                # (batch_size, n_heads, seq_length, dim_per_head)
-                hidden_states = shape(proj_layer(key_value_states))
-            else:
-                # cross-attn
-                hidden_states = past_key_value
-        return hidden_states
-
-    # get query states
-    query_states = shape(self.q(hidden_states))  # (batch_size, n_heads, seq_length, dim_per_head)
-
-    # get key/value states
-    key_states = project(
-        hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None
-    )
-    value_states = project(
-        hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
-    )
-
-    # compute scores
-    scores = torch.matmul(
-        query_states, key_states.transpose(3, 2)
-    )  # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
-
-    if position_bias is None:
-        if not self.has_relative_attention_bias:
-            position_bias = torch.zeros(
-                (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
-            )
-            if self.gradient_checkpointing and self.training:
-                position_bias.requires_grad = True
-        else:
-            position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device)
-
-        # if key and values are already calculated
-        # we want only the last query position bias
-        if past_key_value is not None:
-            position_bias = position_bias[:, :, -hidden_states.size(1) :, :]
-
-        if mask is not None:
-            position_bias = position_bias + mask  # (batch_size, n_heads, seq_length, key_length)
-
-    if self.pruned_heads:
-        mask = torch.ones(position_bias.shape[1])
-        mask[list(self.pruned_heads)] = 0
-        position_bias_masked = position_bias[:, mask.bool()]
-    else:
-        position_bias_masked = position_bias
-
-    scores += position_bias_masked
-    attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
-        scores
-    )  # (batch_size, n_heads, seq_length, key_length)
-    if self.training:
-        htcore.mark_step()
-    attn_weights = nn.functional.dropout(
-        attn_weights, p=self.dropout, training=self.training
-    )  # (batch_size, n_heads, seq_length, key_length)
-    if self.training:
-        htcore.mark_step()
-
-    # Mask heads if we want to
-    if layer_head_mask is not None:
-        attn_weights = attn_weights * layer_head_mask
-
-    attn_output = unshape(torch.matmul(attn_weights, value_states))  # (batch_size, seq_length, dim)
-    attn_output = self.o(attn_output)
-
-    present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None
-    outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)
-
-    if output_attentions:
-        outputs = outputs + (attn_weights,)
-    return outputs
-
-
-class GaudiDropout(torch.nn.Module):
-    def __init__(self, dropout):
-        super().__init__()
-        self.dropout = dropout
-
-    def forward(self, x):
-        """
-        Avoids dropout kernel fusion with others to stablize training numerical stability.
-        """
-        if self.dropout.training:
-            htcore.mark_step()
-        out = self.dropout(x)
-        if self.dropout.training:
-            htcore.mark_step()
-        return out
-
-
-class GaudiT5DenseActDense(T5DenseActDense):
-    def __init__(self, config):
-        super().__init__(config)
-        self.dropout = GaudiDropout(self.dropout)
-
-
-class GaudiT5DenseGatedActDense(T5DenseGatedActDense):
-    def __init__(self, config):
-        super().__init__(config)
-        self.dropout = GaudiDropout(self.dropout)
-
-
-class GaudiT5LayerFF(T5LayerFF):
-    def __init__(self, config):
-        super().__init__(config)
-        self.dropout = GaudiDropout(self.dropout)
-
-
-class GaudiT5LayerSelfAttention(T5LayerSelfAttention):
-    def __init__(self, config, has_relative_attention_bias=False):
-        super().__init__(config, has_relative_attention_bias)
-        self.dropout = GaudiDropout(self.dropout)
-
-
-class GaudiT5LayerCrossAttention(T5LayerCrossAttention):
-    def __init__(self, config):
-        super().__init__(config)
-        self.dropout = GaudiDropout(self.dropout)
-
-
-class GaudiT5Stack(T5Stack):
-    def __init__(self, config, embed_tokens=None):
-        super().__init__(config, embed_tokens)
-        self.dropout = GaudiDropout(self.dropout)

From 70bed737aaabb4a7898cd2e5041d1d9fdaacd20e Mon Sep 17 00:00:00 2001
From: Harish Subramony
Date: Fri, 4 Aug 2023 02:25:00 +0300
Subject: [PATCH 2/3] remove unwanted comments

---
 optimum/habana/transformers/modeling_utils.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py
index eece652763..899ae22016 100644
--- a/optimum/habana/transformers/modeling_utils.py
+++ b/optimum/habana/transformers/modeling_utils.py
@@ -67,10 +67,8 @@ def adapt_transformers_to_gaudi():
     """
     Replaces some Transformers' methods for equivalent methods optimized for Gaudi.
-
-    Args:
-        use_habana_mixed_precision (bool): whether HMP is used or not.
     """
+
     # optimize Conv1D
     transformers.pytorch_utils.Conv1D.forward = gaudi_conv1d_forward

From ea8ad1cd400308bb93acf96e0a69adedb1c9d5d6 Mon Sep 17 00:00:00 2001
From: Harish Subramony
Date: Wed, 9 Aug 2023 10:04:49 -0700
Subject: [PATCH 3/3] change contiguous_gradients to false for ds_flan_t5_z3_config_bf16

---
 examples/summarization/ds_flan_t5_z3_config_bf16.json | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/summarization/ds_flan_t5_z3_config_bf16.json b/examples/summarization/ds_flan_t5_z3_config_bf16.json
index dfa0fa77b5..b4202acebb 100644
--- a/examples/summarization/ds_flan_t5_z3_config_bf16.json
+++ b/examples/summarization/ds_flan_t5_z3_config_bf16.json
@@ -24,7 +24,7 @@
     "zero_optimization": {
         "stage": 3,
         "overlap_comm": true,
-        "contiguous_gradients": true,
+        "contiguous_gradients": false,
         "sub_group_size": 1e9,
         "reduce_bucket_size": "auto",
         "stage3_prefetch_bucket_size": "auto",