diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py
index 2b7bb32bce..d7f98f8376 100644
--- a/optimum/habana/transformers/modeling_utils.py
+++ b/optimum/habana/transformers/modeling_utils.py
@@ -76,6 +76,8 @@
     GaudiMixtralDecoderLayer,
     GaudiMixtralForCausalLM,
     GaudiMixtralModel,
+    GaudiMptAttention,
+    GaudiMptBlock,
     GaudiMptForCausalLM,
     GaudiMptModel,
     GaudiOPTForCausalLM,
@@ -152,8 +154,6 @@
     gaudi_mistral_rmsnorm_forward,
     gaudi_mixtral_block_sparse_moe_forward,
     gaudi_mixtral_rmsnorm_forward,
-    gaudi_mpt_attention_forward,
-    gaudi_mpt_block_forward,
     gaudi_opt_attention_forward,
     gaudi_opt_decoder_forward,
     gaudi_opt_decoder_layer_forward,
@@ -420,8 +420,8 @@ def adapt_transformers_to_gaudi():
     # Optimization for mpt on Gaudi
     transformers.models.mpt.modeling_mpt.MptForCausalLM = GaudiMptForCausalLM
     transformers.models.mpt.modeling_mpt.MptModel = GaudiMptModel
-    transformers.models.mpt.modeling_mpt.MptAttention.forward = gaudi_mpt_attention_forward
-    transformers.models.mpt.modeling_mpt.MptBlock.forward = gaudi_mpt_block_forward
+    transformers.models.mpt.modeling_mpt.MptAttention = GaudiMptAttention
+    transformers.models.mpt.modeling_mpt.MptBlock = GaudiMptBlock
 
     # Optimization for mistral on Gaudi
     transformers.models.mistral.modeling_mistral.MistralForCausalLM = GaudiMistralForCausalLM
diff --git a/optimum/habana/transformers/models/__init__.py b/optimum/habana/transformers/models/__init__.py
index 99ef65c4e4..5a4861fbdf 100644
--- a/optimum/habana/transformers/models/__init__.py
+++ b/optimum/habana/transformers/models/__init__.py
@@ -138,10 +138,10 @@
     gaudi_invert_attention_mask,
 )
 from .mpt import (
+    GaudiMptAttention,
+    GaudiMptBlock,
     GaudiMptForCausalLM,
     GaudiMptModel,
-    gaudi_mpt_attention_forward,
-    gaudi_mpt_block_forward,
 )
 from .opt import (
     GaudiOPTForCausalLM,
diff --git a/optimum/habana/transformers/models/mpt/__init__.py b/optimum/habana/transformers/models/mpt/__init__.py
index 1ab41c1a80..351152c026 100644
--- a/optimum/habana/transformers/models/mpt/__init__.py
+++ b/optimum/habana/transformers/models/mpt/__init__.py
@@ -1,6 +1,6 @@
 from .modeling_mpt import (
+    GaudiMptAttention,
+    GaudiMptBlock,
     GaudiMptForCausalLM,
     GaudiMptModel,
-    gaudi_mpt_attention_forward,
-    gaudi_mpt_block_forward,
 )
diff --git a/optimum/habana/transformers/models/mpt/modeling_mpt.py b/optimum/habana/transformers/models/mpt/modeling_mpt.py
index 7cefc4e37f..369bae9234 100755
--- a/optimum/habana/transformers/models/mpt/modeling_mpt.py
+++ b/optimum/habana/transformers/models/mpt/modeling_mpt.py
@@ -21,14 +21,18 @@
 from torch import nn
 from torch.nn import CrossEntropyLoss
 from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
-from transformers.models.mpt.modeling_mpt import MptForCausalLM, MptModel
+from transformers.models.mpt.modeling_mpt import (
+    MptAttention,
+    MptBlock,
+    MptConfig,
+    MptForCausalLM,
+    MptModel,
+)
 from transformers.utils import logging
 
 from ...modeling_attn_mask_utils import _gaudi_prepare_4d_causal_attention_mask
 
 
-logger = logging.get_logger(__name__)
-
 try:
     from habana_frameworks.torch.hpex.kernels import FusedSDPA
 except ImportError:
@@ -36,159 +40,178 @@
     FusedSDPA = None
 
 
-def gaudi_mpt_attention_forward(
-    self,
-    hidden_states: torch.Tensor,
-    position_bias: torch.Tensor,
-    past_key_value: Optional[Tuple[torch.Tensor]] = None,
-    attention_mask: Optional[torch.Tensor] = None,
-    token_idx: Optional[torch.Tensor] = None,
-    use_flash_attention: Optional[bool] = False,
-    flash_attention_recompute: Optional[bool] = False,
-):
-    """
-    Copied from MptAttention.forward: https://github.com/huggingface/transformers/blob/v4.32.0/src/transformers/models/mpt/modeling_mpt.py
-    The only differences are:
-    - add new args token_idx
-    - optimize KV cache
-    - add new args use_flash_attention
-    - add new arg flash_attention_recompute
-    """
-
-    batch_size, seq_length = hidden_states.shape[:2]
-
-    mixed_qkv = self.Wqkv(hidden_states)
-    if self.clip_qkv:
-        mixed_qkv = mixed_qkv.clamp(min=-self.clip_qkv, max=self.clip_qkv)
-
-    bs, seq_len, three_times_hidden_size = mixed_qkv.shape
-    mixed_qkv = mixed_qkv.view(bs, seq_len, self.n_heads * 3, self.head_dim)
-    mixed_qkv = mixed_qkv.transpose(1, 2)
-    query_states, key_states, value_states = (
-        mixed_qkv[:, : self.n_heads, ...],
-        mixed_qkv[:, self.n_heads : 2 * self.n_heads, ...],
-        mixed_qkv[:, 2 * self.n_heads :, ...],
-    )
-
-    if past_key_value is not None:
-        if len(past_key_value) != 0:
-            if token_idx is not None:
-                past_key_value[0].index_copy_(2, token_idx - 1, key_states)
-                past_key_value[1].index_copy_(2, token_idx - 1, value_states)
-                key_states = past_key_value[0]
-                value_states = past_key_value[1]
-            else:
-                key_states = torch.cat([past_key_value[0], key_states], dim=2)
-                value_states = torch.cat([past_key_value[1], value_states], dim=2)
-                past_key_value = [key_states, value_states]
-        else:
-            past_key_value = [
-                torch.empty(key_states.shape, dtype=key_states.dtype, device=key_states.device),
-                torch.empty(key_states.shape, dtype=key_states.dtype, device=key_states.device),
-            ]
-            past_key_value[0][:] = key_states[:]
-            past_key_value[1][:] = value_states[:]
-
-    query_length = seq_length if past_key_value is None else seq_length + past_key_value[0].shape[2]
-
-    if position_bias is not None:
-        if len(position_bias.shape) != 3:
-            raise ValueError(f"Expecting position_bias shape to be 3 dimensions, got {len(position_bias.shape)}")
-        key_length = key_states.shape[-2]
-
-        position_bias_query_index = max(0, position_bias.size(1) - query_length)
-        position_bias_key_index = max(0, position_bias.size(2) - key_length)
-
-        position_bias = position_bias[:, position_bias_query_index:, position_bias_key_index:]
-
-    if use_flash_attention and FusedSDPA:
-        import habana_frameworks.torch.hpu as ht
-
-        with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
-            attn_output = FusedSDPA.apply(
-                query_states,
-                key_states,
-                value_states,
-                attention_mask * torch.finfo(query_states.dtype).min + position_bias.to(query_states.dtype),
-                0.0,
-                False,
-                None,
-            )
-            attn_weights = None
-    else:
-        attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) * self.softmax_scale
+logger = logging.get_logger(__name__)
+
+
+class Softmax(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x, dim=None, invAttnHead=None):
+        return torch.nn.functional.softmax(x, dim)
+
+
+class GaudiMptAttention(MptAttention):
+    def __init__(self, config: MptConfig):
+        super().__init__(config)
+
+        self.softmax = Softmax()
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_bias: torch.Tensor,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_idx: Optional[torch.Tensor] = None,
+        use_flash_attention: Optional[bool] = False,
+        flash_attention_recompute: Optional[bool] = False,
+    ):
+        """
+        Copied from MptAttention.forward: https://github.com/huggingface/transformers/blob/v4.44.1/src/transformers/models/mpt/modeling_mpt.py
+        The only differences are:
+        - add new args token_idx
+        - optimize KV cache
+        - add new args use_flash_attention
+        - add new arg flash_attention_recompute
+        """
+
+        batch_size, seq_length = hidden_states.shape[:2]
+
+        mixed_qkv = self.Wqkv(hidden_states)
+        if self.clip_qkv:
+            mixed_qkv = mixed_qkv.clamp(min=-self.clip_qkv, max=self.clip_qkv)
+
+        query_states, key_states, value_states = mixed_qkv.chunk(3, dim=2)
+        query_states = query_states.reshape(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.reshape(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.reshape(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2)
+
+        if past_key_value is not None:
+            if len(past_key_value) != 0:
+                if token_idx is not None:
+                    past_key_value[0].index_copy_(2, token_idx - 1, key_states)
+                    past_key_value[1].index_copy_(2, token_idx - 1, value_states)
+                    key_states = past_key_value[0]
+                    value_states = past_key_value[1]
+                else:
+                    key_states = torch.cat([past_key_value[0], key_states], dim=2)
+                    value_states = torch.cat([past_key_value[1], value_states], dim=2)
+                    past_key_value = [key_states, value_states]
+            else:
+                past_key_value = [
+                    torch.empty(key_states.shape, dtype=key_states.dtype, device=key_states.device),
+                    torch.empty(key_states.shape, dtype=key_states.dtype, device=key_states.device),
+                ]
+                past_key_value[0][:] = key_states[:]
+                past_key_value[1][:] = value_states[:]
+
+        query_length = seq_length if past_key_value is None else seq_length + past_key_value[0].shape[2]
 
         if position_bias is not None:
-            attention_scores = attention_scores + position_bias
-        if attention_mask is not None:
-            attention_scores = attention_scores.masked_fill(attention_mask, torch.finfo(query_states.dtype).min)
-
-        # (batch_size, n_heads, seq_length, key_length)
-        attn_weights = nn.functional.softmax(attention_scores.float(), dim=-1).to(value_states.dtype)
-        attn_weights = nn.functional.dropout(attn_weights, p=self.attn_dropout_p, training=self.training)
-
-        attn_output = torch.matmul(attn_weights, value_states)
-
-    attn_output = attn_output.permute(0, 2, 1, 3).contiguous().view(batch_size, seq_length, -1)
-    attn_output = self.out_proj(attn_output)
-
-    return attn_output, attn_weights, past_key_value
-
-
-def gaudi_mpt_block_forward(
-    self,
-    hidden_states: torch.Tensor,
-    position_bias: torch.Tensor,
-    attention_mask: torch.Tensor,
-    layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-    use_cache: bool = False,
-    output_attentions: bool = False,
-    token_idx: Optional[torch.Tensor] = None,
-    use_flash_attention: Optional[bool] = False,
-    flash_attention_recompute: Optional[bool] = False,
-):
-    """
-    Copied from MptBlock.forward: https://github.com/huggingface/transformers/blob/v4.32.0/src/transformers/models/mpt/modeling_mpt.py
-    The only differences are:
-    - add new args token_idx
-    - add new args use_flash_attention
-    - add new arg flash_attention_recompute
-    """
-    # hidden_states: [batch_size, seq_length, hidden_size]
-    # Layer norm at the beginning of the transformer layer.
-    layernorm_output = self.norm_1(hidden_states)
-
-    residual = hidden_states
-
-    # Self attention.
-    attn_outputs, attn_weights, past_key_value = self.attn(
-        layernorm_output,
-        position_bias=position_bias,
-        attention_mask=attention_mask,
-        past_key_value=layer_past,
-        token_idx=token_idx,
-        use_flash_attention=use_flash_attention,
-        flash_attention_recompute=flash_attention_recompute,
-    )
-
-    hidden_states = self.resid_attn_dropout(attn_outputs) + residual
-
-    layernorm_output = self.norm_2(hidden_states)
-
-    # Get residual
-    residual = hidden_states
-
-    # MLP.
-    output = self.ffn(layernorm_output, residual)
-    outputs = (output,)
-
-    if use_cache:
-        outputs += (past_key_value,)
-
-    if output_attentions:
-        outputs += (attn_weights,)
-
-    return outputs  # hidden_states, present, attentions
+            if len(position_bias.shape) != 3:
+                raise ValueError(f"Expecting position_bias shape to be 3 dimensions, got {len(position_bias.shape)}")
+            key_length = key_states.shape[-2]
+
+            position_bias_query_index = max(0, position_bias.size(1) - query_length)
+            position_bias_key_index = max(0, position_bias.size(2) - key_length)
+
+            position_bias = position_bias[:, position_bias_query_index:, position_bias_key_index:]
+
+        if use_flash_attention and FusedSDPA:
+            import habana_frameworks.torch.hpu as ht
+
+            with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
+                attn_output = FusedSDPA.apply(
+                    query_states,
+                    key_states,
+                    value_states,
+                    attention_mask * torch.finfo(query_states.dtype).min + position_bias.to(query_states.dtype),
+                    0.0,
+                    False,
+                    None,
+                )
+
+            attn_weights = None
+        else:
+            attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) * self.softmax_scale
+
+            if position_bias is not None:
+                attention_scores = attention_scores + position_bias
+            if attention_mask is not None:
+                attention_scores = attention_scores.masked_fill(attention_mask, torch.finfo(query_states.dtype).min)
+
+            # (batch_size, n_heads, seq_length, key_length)
+            attn_weights = self.softmax(attention_scores.bfloat16(), dim=-1)
+            attn_weights = nn.functional.dropout(attn_weights, p=self.attn_dropout_p, training=self.training)
+
+            attn_output = torch.matmul(attn_weights, value_states)
+
+        attn_output = attn_output.permute(0, 2, 1, 3).contiguous().view(batch_size, seq_length, -1)
+        attn_output = self.out_proj(attn_output)
+
+        return attn_output, attn_weights, past_key_value
+
+
+class GaudiMptBlock(MptBlock):
+    def __init__(self, config: MptConfig):
+        super().__init__(config)
+        self.attn = GaudiMptAttention(config)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_bias: torch.Tensor,
+        attention_mask: torch.Tensor,
+        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        use_cache: bool = False,
+        output_attentions: bool = False,
+        token_idx: Optional[torch.Tensor] = None,
+        use_flash_attention: Optional[bool] = False,
+        flash_attention_recompute: Optional[bool] = False,
+    ):
+        """
+        Copied from MptBlock.forward: https://github.com/huggingface/transformers/blob/v4.32.0/src/transformers/models/mpt/modeling_mpt.py
+        The only differences are:
+        - add new args token_idx
+        - add new args use_flash_attention
+        - add new arg flash_attention_recompute
+        """
+        # hidden_states: [batch_size, seq_length, hidden_size]
+        # Layer norm at the beginning of the transformer layer.
+        layernorm_output = self.norm_1(hidden_states)
+
+        residual = hidden_states
+
+        # Self attention.
+        attn_outputs, attn_weights, past_key_value = self.attn(
+            layernorm_output,
+            position_bias=position_bias,
+            attention_mask=attention_mask,
+            past_key_value=layer_past,
+            token_idx=token_idx,
+            use_flash_attention=use_flash_attention,
+            flash_attention_recompute=flash_attention_recompute,
+        )
+
+        hidden_states = self.resid_attn_dropout(attn_outputs) + residual
+
+        layernorm_output = self.norm_2(hidden_states)
+
+        # Get residual
+        residual = hidden_states
+
+        # MLP.
+        output = self.ffn(layernorm_output, residual)
+        outputs = (output,)
+
+        if use_cache:
+            outputs += (past_key_value,)
+
+        if output_attentions:
+            outputs += (attn_weights,)
+
+        return outputs  # hidden_states, present, attentions
 
 
 class GaudiMptModel(MptModel):
@@ -280,8 +303,6 @@ def forward(
                     use_cache,
                     output_attentions,
                     None,
-                    use_flash_attention,
-                    flash_attention_recompute,
                 )
             else:
                 outputs = block(
@@ -340,6 +361,8 @@ def prepare_inputs_for_generation(
         - support for internal bucketing
         """
         bucket_internal = kwargs.get("bucket_internal")
+        use_flash_attention = kwargs.get("use_flash_attention", False)
+        flash_attention_recompute = kwargs.get("flash_attention_recompute", False)
         # only last tokens for input_ids if past is not None
         if past_key_values is not None:
             if token_idx is None:
@@ -375,8 +398,8 @@ def prepare_inputs_for_generation(
                 "use_cache": use_cache,
                 "attention_mask": attention_mask,
                 "token_idx": token_idx,
-                "use_flash_attention": kwargs.get("use_flash_attention"),
-                "flash_attention_recompute": kwargs.get("flash_attention_recompute"),
+                "use_flash_attention": use_flash_attention,
+                "flash_attention_recompute": flash_attention_recompute,
             }
         )
         return model_inputs
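
Usage sketch (not part of the diff above): a minimal example of how the patched MPT classes and the new generation kwargs might be exercised once this change is applied. The checkpoint name, the HPU device handling, and the assumption that `use_flash_attention` / `flash_attention_recompute` can be passed directly to `generate()` and reach `prepare_inputs_for_generation` as model kwargs are illustrative, not part of the PR.

# Hypothetical usage on a Gaudi/HPU machine with optimum-habana and SynapseAI installed.
import habana_frameworks.torch.core as htcore  # noqa: F401  # loads the HPU backend
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

# Monkey-patch transformers so MptAttention/MptBlock resolve to GaudiMptAttention/GaudiMptBlock.
adapt_transformers_to_gaudi()

model_id = "mosaicml/mpt-7b"  # example checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).to("hpu")

inputs = tokenizer("Deep learning on Gaudi is", return_tensors="pt").to("hpu")

# Assumed plumbing: these kwargs end up in prepare_inputs_for_generation, which now
# forwards them to GaudiMptAttention.forward to select the FusedSDPA path.
outputs = model.generate(
    **inputs,
    max_new_tokens=32,
    use_flash_attention=True,
    flash_attention_recompute=True,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))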