diff --git a/examples/text-generation/run_lm_eval.py b/examples/text-generation/run_lm_eval.py index 4d9b69bdee..fb579be19e 100644 --- a/examples/text-generation/run_lm_eval.py +++ b/examples/text-generation/run_lm_eval.py @@ -75,7 +75,7 @@ def __init__(self, tokenizer, model, args, options): self.options = options self._device = args.device self.model_inputs = {"use_cache": self.options.use_cache} - if self.model.config.model_type in ["llama", "mistral", "falcon", "phi", "mixtral", "qwen2"]: + if self.model.config.model_type in ["llama", "mistral", "falcon", "phi", "mixtral", "qwen2", "gptj"]: self.model_inputs.update( { "reuse_cache": self.options.reuse_cache, diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index a43ebf6375..f35f95dd65 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -858,7 +858,8 @@ def generate( "mixtral", "phi", "qwen2", - ], "reuse_cache only supported by llama, mistral, falcon, mixtral, phi and qwen2 at the moment" + "gptj", + ], "reuse_cache only supported by llama, mistral, falcon, mixtral, phi, qwen2 and gptj at the moment" if not generation_config.bucket_internal: assert ( generation_config.bucket_size <= 0 @@ -1014,7 +1015,7 @@ def generate( model_kwargs["kv_cache_len"] = calculated_max_length model_kwargs["kv_cache_pad_len"] = generation_config.max_new_tokens - if self.config.model_type in ["llama", "falcon", "mistral", "qwen2"]: + if self.config.model_type in ["llama", "falcon", "mistral", "qwen2", "gptj"]: if self.config.max_position_embeddings < calculated_max_length: unwrap_deepspeed_model(self).update_sincos_cache(seq_len=calculated_max_length) diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py index aaa13df18f..34559c8351 100644 --- a/optimum/habana/transformers/modeling_utils.py +++ b/optimum/habana/transformers/modeling_utils.py @@ -50,6 +50,7 @@ GaudiGPTJAttention, GaudiGPTJBlock, GaudiGPTJForCausalLM, + GaudiGPTJModel, GaudiGPTNeoXForCausalLM, GaudiLlamaAttention, GaudiLlamaDecoderLayer, @@ -137,7 +138,6 @@ gaudi_gpt_neox_layer_forward, gaudi_gpt_neox_model_forward, gaudi_gpt_neox_rotary_embedding_set_cos_sin_cache, - gaudi_gptj_model_forward, gaudi_invert_attention_mask, gaudi_llama_rmsnorm_forward, gaudi_MambaForCausalLM_prepare_inputs_for_generation, @@ -341,7 +341,7 @@ def adapt_transformers_to_gaudi(): transformers.models.gptj.modeling_gptj.GPTJAttention = GaudiGPTJAttention transformers.models.gptj.modeling_gptj.GPTJForCausalLM = GaudiGPTJForCausalLM transformers.models.gptj.modeling_gptj.GPTJBlock = GaudiGPTJBlock - transformers.models.gptj.modeling_gptj.GPTJModel.forward = gaudi_gptj_model_forward + transformers.models.gptj.modeling_gptj.GPTJModel = GaudiGPTJModel # Optimization for GPTBigCode on Gaudi transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeAttention.forward = ( diff --git a/optimum/habana/transformers/models/__init__.py b/optimum/habana/transformers/models/__init__.py index 8a5f9ebf92..d6c0cc2fce 100644 --- a/optimum/habana/transformers/models/__init__.py +++ b/optimum/habana/transformers/models/__init__.py @@ -84,7 +84,7 @@ GaudiGPTJAttention, GaudiGPTJBlock, GaudiGPTJForCausalLM, - gaudi_gptj_model_forward, + GaudiGPTJModel, ) from .llama import ( GaudiLlamaAttention, diff --git a/optimum/habana/transformers/models/gptj/__init__.py b/optimum/habana/transformers/models/gptj/__init__.py index 23a1d6971b..e1f1c22f9d 100644 --- 
a/optimum/habana/transformers/models/gptj/__init__.py +++ b/optimum/habana/transformers/models/gptj/__init__.py @@ -1,6 +1,8 @@ +from transformers.models.gptj.configuration_gptj import GPTJConfig + from .modeling_gptj import ( GaudiGPTJAttention, GaudiGPTJBlock, GaudiGPTJForCausalLM, - gaudi_gptj_model_forward, + GaudiGPTJModel, ) diff --git a/optimum/habana/transformers/models/gptj/modeling_gptj.py b/optimum/habana/transformers/models/gptj/modeling_gptj.py index 0fae0f0467..a7d715b147 100644 --- a/optimum/habana/transformers/models/gptj/modeling_gptj.py +++ b/optimum/habana/transformers/models/gptj/modeling_gptj.py @@ -1,20 +1,115 @@ from typing import Optional, Tuple, Union +import habana_frameworks.torch.core as htcore import torch from torch import nn from torch.nn import CrossEntropyLoss from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.models.gptj.configuration_gptj import GPTJConfig from transformers.models.gptj.modeling_gptj import ( GPTJMLP, GPTJAttention, + GPTJBlock, GPTJForCausalLM, + GPTJModel, apply_rotary_pos_emb, create_sinusoidal_positions, logger, ) +class Matmul(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, *args, **kwargs): + return torch.matmul(*args, **kwargs) + + +class KVCache(torch.nn.Module): + def __init__(self): + super(KVCache, self).__init__() + self.cache = None + self.inp_seq_len = -1 + + def allocate(self, inp_seq_len, dtype, device, shape): + if self.cache is None or self.cache.shape != shape: + self.inp_seq_len = inp_seq_len + self.cache = torch.zeros(shape, dtype=dtype, device=device) + else: + assert ( + self.inp_seq_len == inp_seq_len + ), f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}" + self.cache.fill_(0) + + def update(self, prev, cur, dim, idx, inp_seq_len): + orig_cur = cur + if prev.shape == cur.shape: + prev.copy_(cur) + return orig_cur + if cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]: + # Initialize + prev[:, :, :inp_seq_len, :].copy_(cur) + return orig_cur + assert cur.shape[2] == 1, f"Cannot update kv-cache. Unsupported shapes. 
prev:{prev.shape} cur:{cur.shape}" + if idx is not None: + prev.index_copy_(dim, idx - 1, cur) + return prev + else: + return torch.cat((prev, cur), dim=dim) + + def get_shape(self): + if self.cache is None: + return None + return self.cache.shape + + def forward(self, cur, dim, idx): + return self.update(self.cache, cur, dim, idx, self.inp_seq_len) + + class GaudiGPTJAttention(GPTJAttention): + def __init__(self, config: GPTJConfig): + super().__init__(config) + self.config = config + + self.matmul_qk = Matmul() + self.matmul_av = Matmul() + self.k_cache = KVCache() + self.v_cache = KVCache() + self.inp_seq_len = -1 + self.max_position_embeddings = config.max_position_embeddings + + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + cache_shape = (batch_size, self.num_attention_heads, max_seq_len, self.head_dim) + device = self.k_proj.weight.device + dtype = self.config.torch_dtype + self.k_cache.allocate(inp_seq_len, dtype, device, cache_shape) + self.v_cache.allocate(inp_seq_len, dtype, device, cache_shape) + + def reorder(self, tensor, beam_idx): + updated = tensor.index_select(0, beam_idx) + tensor.copy_(updated) + + def reorder_kv_cache(self, beam_idx: torch.LongTensor): + if self.k_cache.cache is None: + return (None, None) + + self.reorder(self.k_cache.cache, beam_idx) + self.reorder(self.v_cache.cache, beam_idx) + + return (self.k_cache.cache.shape, self.v_cache.cache.shape) + + def update_sincos_cache(self, seq_len): + # Call rotary emb forward() to update cos/sin cache when infering more than self.max_position_embeddings + # This helps in avoiding creation of these caches during actual model forward pass and + # reduce memory consumption and improve performance. + if seq_len > self.max_position_embeddings: + self.max_position_embeddings = seq_len + # Update register 'bias' buffer size + self.bias = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool)).view(1, 1, seq_len, seq_len) + # TODO: implement rotary_emb() + # _, _ = self.rotary_emb(self.k_proj.weight, seq_len=seq_len) + def _attn( self, query, @@ -31,7 +126,7 @@ def _attn( key = key.contiguous() value = value.contiguous() - attn_weights = torch.matmul(query, key.transpose(-1, -2)) + attn_weights = self.matmul_qk(query, key.transpose(-1, -2)) mask_value = torch.finfo(attn_weights.dtype).min # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. 
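The KVCache module introduced above is the core of the reuse_cache path: instead of growing the key/value tensors with torch.cat at every step, a buffer of the final size is allocated once and written in place, so tensor shapes stay constant across decode steps (which keeps the compiled HPU graphs static). A minimal, self-contained sketch of the same prefill/decode update pattern, independent of Habana and transformers (class and argument names below are illustrative only, not the library API):

import torch


class StaticKVCache(torch.nn.Module):
    """Pre-allocated key/value buffer that is updated in place (illustrative sketch)."""

    def __init__(self):
        super().__init__()
        self.cache = None

    def allocate(self, shape, dtype, device):
        # Allocate the full-size buffer once; re-allocating with the same shape just zeroes it.
        if self.cache is None or self.cache.shape != shape:
            self.cache = torch.zeros(shape, dtype=dtype, device=device)
        else:
            self.cache.fill_(0)

    def update(self, cur, token_idx):
        if cur.shape[2] > 1:
            # Prefill: copy the whole prompt into the first slots of the buffer.
            self.cache[:, :, : cur.shape[2], :].copy_(cur)
            return cur
        # Decode: write the single new token at slot token_idx - 1 instead of torch.cat,
        # so the cached tensor never changes shape.
        self.cache.index_copy_(2, token_idx - 1, cur)
        return self.cache


if __name__ == "__main__":
    bs, heads, max_len, head_dim = 1, 2, 8, 4
    cache = StaticKVCache()
    cache.allocate((bs, heads, max_len, head_dim), torch.float32, "cpu")
    prompt_keys = torch.randn(bs, heads, 3, head_dim)
    cache.update(prompt_keys, token_idx=None)                  # prefill fills slots 0..2
    new_key = torch.randn(bs, heads, 1, head_dim)
    keys = cache.update(new_key, token_idx=torch.tensor([4]))  # first decode step writes slot 3
    print(keys.shape)                                          # torch.Size([1, 2, 8, 4]) -- never grows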
@@ -53,7 +148,7 @@ def _attn( if head_mask is not None: attn_weights = attn_weights * head_mask - attn_output = torch.matmul(attn_weights, value) + attn_output = self.matmul_av(attn_weights, value) return attn_output, attn_weights @@ -67,6 +162,8 @@ def forward( use_cache: Optional[bool] = False, output_attentions: Optional[bool] = False, token_idx: Optional[torch.Tensor] = None, + reuse_cache: Optional[bool] = False, + cache_idx: Optional[int] = None, sin: Optional[torch.Tensor] = None, cos: Optional[torch.Tensor] = None, ) -> Union[ @@ -74,13 +171,16 @@ def forward( Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]], ]: """ - Copied from GPTJAttention.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gptj/modeling_gptj.py + Copied from GPTJAttention.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gptj/modeling_gptj.py#194 The only differences are: - add new args token_idx + - add new args reuse_cache + - add new args cache_idx - remove is_torch_fx_proxy - optimize KV cache - pass sin and cos from upper level as they are identical for each attn block """ + _, q_len, _ = hidden_states.size() query = self.q_proj(hidden_states) key = self.k_proj(hidden_states) value = self.v_proj(hidden_states) @@ -108,7 +208,7 @@ def forward( key = key.permute(0, 2, 1, 3).contiguous() query = query.permute(0, 2, 1, 3).contiguous() - if layer_past is not None: + if layer_past is not None and not reuse_cache: past_key = layer_past[0] past_value = layer_past[1] @@ -121,10 +221,25 @@ def forward( key = torch.cat([past_key, key], dim=-2) value = torch.cat([past_value, value], dim=-2) - if use_cache is True: - # Note that this cast is quite ugly, but is not implemented before ROPE as the original codebase keeps the key in float32 all along the computation. 
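The Matmul wrappers used in _attn above (matmul_qk, matmul_av) route plain torch.matmul calls through nn.Module instances. A plausible reason, stated here as an assumption rather than something the diff documents, is that module boundaries are what hooks, profilers and quantization tooling can target, while purely functional calls are invisible to them. A small standalone sketch of that pattern (the toy module and hook below are illustrative only):

import torch
from torch import nn


class Matmul(nn.Module):
    """nn.Module wrapper so each matmul shows up as a distinct, hook-able submodule."""

    def forward(self, a, b):
        return torch.matmul(a, b)


class ToyAttentionScores(nn.Module):
    def __init__(self):
        super().__init__()
        self.matmul_qk = Matmul()

    def forward(self, q, k):
        return self.matmul_qk(q, k.transpose(-1, -2))


if __name__ == "__main__":
    attn = ToyAttentionScores()

    # Because the matmul is a submodule, generic tooling can attach hooks to it
    # (e.g. for shape logging, profiling, or swapping in a lower-precision kernel).
    def log_shapes(module, inputs, output):
        print("matmul_qk:", [tuple(t.shape) for t in inputs], "->", tuple(output.shape))

    attn.matmul_qk.register_forward_hook(log_shapes)
    q = torch.randn(1, 2, 5, 4)
    k = torch.randn(1, 2, 5, 4)
    scores = attn(q, k)  # the hook prints the operand and result shapes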
- # Reference: https://github.com/kingoflolz/mesh-transformer-jax/blob/f8315e3003033b23f21d78361b288953064e0e76/mesh_transformer/layers.py#L128 - present = (key.to(hidden_states.dtype), value) + if use_cache is True and token_idx is not None: + if reuse_cache: + key = self.k_cache(key, 2, token_idx) + value = self.v_cache(value, 2, token_idx) + present = (self.k_cache.get_shape(), self.v_cache.get_shape()) + else: + if layer_past is None: + past_key = torch.zeros(key.shape, dtype=self.k_proj.weight.dtype, device=key.device) + past_value = torch.zeros(key.shape, dtype=self.k_proj.weight.dtype, device=key.device) + layer_past = (past_key, past_value) + key = self.k_cache.update(layer_past[0], key, 2, token_idx, self.inp_seq_len) + value = self.v_cache.update(layer_past[1], value, 2, token_idx, self.inp_seq_len) + present = layer_past + + if cache_idx is not None and q_len == 1: + key = key[:, :, :cache_idx, :] + value = value[:, :, :cache_idx, :] + if attention_mask is not None: + attention_mask = attention_mask[:, :, :, :cache_idx] else: present = None @@ -142,14 +257,27 @@ def forward( return outputs # a, present, (attentions) -class GaudiGPTJBlock(torch.nn.Module): - def __init__(self, config): - super().__init__() +class GaudiGPTJBlock(GPTJBlock): + """ + Inherits from GPTJBlock: https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/gptj/modeling_gptj.py#291 + """ + + def __init__(self, config: GPTJConfig): + super().__init__(config) inner_dim = config.n_inner if config.n_inner is not None else 4 * config.n_embd self.ln_1 = torch.nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) self.attn = GaudiGPTJAttention(config) self.mlp = GPTJMLP(inner_dim, config) + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + self.attn.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) + + def reorder_kv_cache(self, beam_idx: torch.LongTensor): + return self.attn.reorder_kv_cache(beam_idx) + + def update_sincos_cache(self, seq_len): + self.attn.update_sincos_cache(seq_len) + def forward( self, hidden_states: Optional[torch.FloatTensor], @@ -160,6 +288,8 @@ def forward( use_cache: Optional[bool] = False, output_attentions: Optional[bool] = False, token_idx: Optional[torch.Tensor] = None, + reuse_cache: Optional[bool] = False, + cache_idx: Optional[int] = None, sin: Optional[torch.Tensor] = None, cos: Optional[torch.Tensor] = None, ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]: @@ -180,6 +310,8 @@ def forward( use_cache=use_cache, output_attentions=output_attentions, token_idx=token_idx, + reuse_cache=reuse_cache, + cache_idx=cache_idx, sin=sin, cos=cos, ) @@ -197,201 +329,206 @@ def forward( return outputs # hidden_states, present, (attentions) -def gaudi_gptj_model_forward( - self, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - token_idx: Optional[torch.Tensor] = None, - sin: Optional[torch.Tensor] = None, - cos: Optional[torch.Tensor] = None, -) -> Union[Tuple, BaseModelOutputWithPast]: +class 
GaudiGPTJModel(GPTJModel): """ - Copied from GPTJModel.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gptj/modeling_gptj.py - The only differences are: - - add new args token_idx - - pass sin and cos from upper level as they are identical for each attn block + Copied from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/gptj/modeling_gptj.py#L480 """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) - input_shape = input_ids.size() - input_ids = input_ids.view(-1, input_shape[-1]) - batch_size = input_ids.shape[0] - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - batch_size = inputs_embeds.shape[0] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - device = input_ids.device if input_ids is not None else inputs_embeds.device - - if token_type_ids is not None: - token_type_ids = token_type_ids.view(-1, input_shape[-1]) - - if past_key_values is None: - past_length = 0 - past_key_values = tuple([None] * len(self.h)) - else: - past_length = past_key_values[0][0].size(-2) - - if position_ids is None: - position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) - position_ids = position_ids.unsqueeze(0) - - # Attention mask. - if attention_mask is not None: - if batch_size <= 0: - raise ValueError("batch_size has to be defined and > 0") - attention_mask = attention_mask.view(batch_size, -1) - # We create a 3D attention mask from a 2D tensor mask. - # Sizes are [batch_size, 1, 1, to_seq_length] - # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] - # this attention mask is more simple than the triangular masking of causal attention - # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - attention_mask = attention_mask[:, None, None, :] - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and the dtype's smallest value for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. 
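In the attention forward above, cache_idx trims the pre-allocated key/value buffers (and the attention mask) to the portion that is actually populated during single-token decode, so attention is not computed over the whole max_seq_len buffer. A standalone sketch of that effect, with illustrative names and no Habana-specific code:

import torch
import torch.nn.functional as F


def decode_attention(query, k_cache, v_cache, attn_mask, cache_idx=None):
    # One single-token decode step over a pre-allocated KV buffer. When cache_idx is
    # given, only the slots that are actually in use are attended to, instead of the
    # full max_seq_len buffer (illustrative sketch of the slicing shown above).
    if cache_idx is not None:
        k_cache = k_cache[:, :, :cache_idx, :]
        v_cache = v_cache[:, :, :cache_idx, :]
        attn_mask = attn_mask[:, :, :, :cache_idx]
    scores = torch.matmul(query, k_cache.transpose(-1, -2)) + attn_mask
    return torch.matmul(F.softmax(scores, dim=-1), v_cache)


if __name__ == "__main__":
    bs, heads, max_len, head_dim, used = 1, 2, 16, 4, 5
    k_cache = torch.zeros(bs, heads, max_len, head_dim)
    v_cache = torch.zeros(bs, heads, max_len, head_dim)
    # Pretend only the first `used` slots hold real tokens.
    k_cache[:, :, :used] = torch.randn(bs, heads, used, head_dim)
    v_cache[:, :, :used] = torch.randn(bs, heads, used, head_dim)
    mask = torch.zeros(bs, 1, 1, max_len)
    mask[:, :, :, used:] = torch.finfo(torch.float32).min  # mask out the unused slots
    query = torch.randn(bs, heads, 1, head_dim)
    full = decode_attention(query, k_cache, v_cache, mask)                     # all 16 slots
    trimmed = decode_attention(query, k_cache, v_cache, mask, cache_idx=used)  # 5 slots only
    print(torch.allclose(full, trimmed, atol=1e-6))  # True -- same result with less compute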
- attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility - attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x num_attention_heads x N x N - # head_mask has shape n_layer x batch x num_attention_heads x N x N - head_mask = self.get_head_mask(head_mask, self.config.n_layer) - - if inputs_embeds is None: - inputs_embeds = self.wte(input_ids) - - hidden_states = inputs_embeds - - if token_type_ids is not None: - token_type_embeds = self.wte(token_type_ids) - hidden_states = hidden_states + token_type_embeds - - hidden_states = self.drop(hidden_states) - - output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),) - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - presents = () if use_cache else None - all_self_attentions = () if output_attentions else None - all_hidden_states = () if output_hidden_states else None - - # replace original `_get_embed_positions` method and sin cos calculation in the attn block here to improve perf - rotary_dim = self.config.rotary_dim - embed_dim = self.config.hidden_size - pos_embd_dim = rotary_dim or embed_dim - max_positions = self.config.max_position_embeddings - embed_positions = create_sinusoidal_positions(max_positions, pos_embd_dim).to(torch.bfloat16) - embed_positions = embed_positions.repeat(position_ids.shape[0], 1, 1) - if embed_positions.device != position_ids.device: - embed_positions = embed_positions.to(position_ids.device) - repeated_position_ids = position_ids.unsqueeze(-1).repeat(1, 1, embed_positions.shape[-1]) - sincos = torch.gather(embed_positions, 1, repeated_position_ids) - sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1) - sin = sin.contiguous() - cos = cos.contiguous() - - for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): - # Model parallel - if self.model_parallel: - torch.cuda.set_device(hidden_states.device) - # Ensure layer_past is on same device as hidden_states (might not be correct) - if layer_past is not None: - layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) - # Ensure that attention_mask is always on the same device as hidden_states - if attention_mask is not None: - attention_mask = attention_mask.to(hidden_states.device) - if isinstance(head_mask, torch.Tensor): - head_mask = head_mask.to(hidden_states.device) - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - outputs = self._gradient_checkpointing_func( - block.__call__, - hidden_states, - None, - attention_mask, - position_ids, - head_mask[i], - use_cache, - output_attentions, - None, - sin, - cos, - ) + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + for layer in self.h: + layer.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) + + def reorder_kv_cache(self, beam_idx: torch.LongTensor): + return tuple(layer.reorder_kv_cache(beam_idx) for layer in self.h) + + def update_sincos_cache(self, seq_len): + for layer in self.h: + layer.update_sincos_cache(seq_len) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: 
Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + token_idx: Optional[torch.Tensor] = None, + reuse_cache: Optional[bool] = False, + cache_idx: Optional[int] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + """ + Copied from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/gptj/modeling_gptj.py#L554 + The only differences are: + - add new args token_idx + - add new args reuse_cache + - add new args cache_idx + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size = inputs_embeds.shape[0] else: - outputs = block( - hidden_states=hidden_states, - layer_past=layer_past, - attention_mask=attention_mask, - position_ids=position_ids, - head_mask=head_mask[i], - use_cache=use_cache, - output_attentions=output_attentions, - token_idx=token_idx, - sin=sin, - cos=cos, - ) - - hidden_states = outputs[0] - if use_cache is True: - presents = presents + (outputs[1],) + raise ValueError("You have to specify either input_ids or inputs_embeds") - if output_attentions: - all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + device = input_ids.device if input_ids is not None else inputs_embeds.device - # Model Parallel: If it's the last layer for that device, put things on the next device - if self.model_parallel: - for k, v in self.device_map.items(): - if i == v[-1] and "cuda:" + str(k) != self.last_device: - hidden_states = hidden_states.to("cuda:" + str(k + 1)) + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, input_shape[-1]) - hidden_states = self.ln_f(hidden_states) + if past_key_values is None: + past_length = 0 + past_key_values = tuple([None] * len(self.h)) + else: + if reuse_cache: + past_length = past_key_values[0][0][-2] + else: + past_length = past_key_values[0][0].size(-2) - hidden_states = hidden_states.view(output_shape) - # Add last hidden state - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) + if position_ids is None: + position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0) - if not return_dict: - return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + # Attention mask. 
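With reuse_cache, the past_key_values returned by each attention layer are cache shape tuples (see present = (self.k_cache.get_shape(), self.v_cache.get_shape()) earlier in this diff) rather than tensors, which is why past_length is read as past_key_values[0][0][-2] just above instead of .size(-2). A small sketch of the two cases, with illustrative values:

import torch


def get_past_length(past_key_values, reuse_cache):
    # Sequence length already stored in the cache.
    # - reuse_cache=False: past_key_values holds tensors, so use .size(-2).
    # - reuse_cache=True: past_key_values holds the pre-allocated cache *shapes*,
    #   so index the shape tuple directly (sketch of the logic shown above).
    if reuse_cache:
        return past_key_values[0][0][-2]
    return past_key_values[0][0].size(-2)


if __name__ == "__main__":
    # Regular KV cache: a (key, value) tensor pair per layer.
    k = torch.zeros(1, 16, 7, 64)  # (batch, heads, seq_len, head_dim)
    print(get_past_length(((k, k),), reuse_cache=False))  # 7

    # reuse_cache: each layer only reports the shape of its pre-allocated cache.
    shape = torch.Size([1, 16, 128, 64])
    print(get_past_length(((shape, shape),), reuse_cache=True))  # 128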
+ if attention_mask is not None: + if batch_size <= 0: + raise ValueError("batch_size has to be defined and > 0") + attention_mask = attention_mask.view(batch_size, -1) + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + attention_mask = attention_mask[:, None, None, :] + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and the dtype's smallest value for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x num_attention_heads x N x N + # head_mask has shape n_layer x batch x num_attention_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + + hidden_states = inputs_embeds + + if token_type_ids is not None: + token_type_embeds = self.wte(token_type_ids) + hidden_states = hidden_states + token_type_embeds + + hidden_states = self.drop(hidden_states) + + output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=presents, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - ) + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
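The reorder and reorder_kv_cache helpers added at the attention, block and model levels earlier in this file (and on the causal-LM wrapper further down) let the pre-allocated caches be permuted in place when beams are re-ranked, which is what the beam_idx argument suggests, instead of rebuilding past_key_values tensors. A minimal sketch of that in-place reorder, independent of the model classes:

import torch


def reorder_static_cache(cache, beam_idx):
    # Permute a pre-allocated cache along the batch/beam dimension, in place.
    # Mirrors the index_select + copy_ pattern used by GaudiGPTJAttention.reorder
    # in the diff above (sketch only).
    reordered = cache.index_select(0, beam_idx)
    cache.copy_(reordered)
    return cache.shape


if __name__ == "__main__":
    num_beams, heads, seq, dim = 4, 2, 6, 8
    # Fill beam b of the cache with the value b so the reorder is easy to see.
    k_cache = (
        torch.arange(num_beams, dtype=torch.float32).view(-1, 1, 1, 1).expand(num_beams, heads, seq, dim).contiguous()
    )
    # Say new beams 0 and 1 descend from old beam 0, and new beams 2 and 3 from old beam 1.
    beam_idx = torch.tensor([0, 0, 1, 1])
    reorder_static_cache(k_cache, beam_idx)
    print(k_cache[:, 0, 0, 0])  # tensor([0., 0., 1., 1.])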
+ ) + use_cache = False + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + # replace original `_get_embed_positions` method and sin cos calculation in the attn block here to improve perf + rotary_dim = self.config.rotary_dim + embed_dim = self.config.hidden_size + pos_embd_dim = rotary_dim or embed_dim + max_positions = self.config.max_position_embeddings + embed_positions = create_sinusoidal_positions(max_positions, pos_embd_dim).to(torch.bfloat16) + embed_positions = embed_positions.repeat(position_ids.shape[0], 1, 1) + if embed_positions.device != position_ids.device: + embed_positions = embed_positions.to(position_ids.device) + repeated_position_ids = position_ids.unsqueeze(-1).repeat(1, 1, embed_positions.shape[-1]) + sincos = torch.gather(embed_positions, 1, repeated_position_ids) + sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1) + sin = sin.contiguous() + cos = cos.contiguous() + + htcore.mark_step() + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + htcore.mark_step() + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + outputs = self._gradient_checkpointing_func( + block.__call__, + hidden_states, + None, + attention_mask, + position_ids, + head_mask[i], + use_cache, + output_attentions, + None, + sin, + cos, + ) + else: + outputs = block( + hidden_states=hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask[i], + use_cache=use_cache, + output_attentions=output_attentions, + token_idx=token_idx, + reuse_cache=reuse_cache, + cache_idx=cache_idx, + sin=sin, + cos=cos, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + + hidden_states = self.ln_f(hidden_states) + + hidden_states = hidden_states.view(output_shape) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) class GaudiGPTJForCausalLM(GPTJForCausalLM): @@ -405,10 +542,21 @@ class GaudiGPTJForCausalLM(GPTJForCausalLM): - from step2 when enable KV cache, slice next_token_type_ids from token_type_ids base on the token_idx """ + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + self.transformer.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) + + def reorder_kv_cache(self, beam_idx: torch.LongTensor): + return self.transformer.reorder_kv_cache(beam_idx) + + def update_sincos_cache(self, seq_len): + self.transformer.update_sincos_cache(seq_len) + def prepare_inputs_for_generation( self, input_ids, past_key_values=None, inputs_embeds=None, token_idx=None, **kwargs ): + reuse_cache = kwargs.get("reuse_cache") token_type_ids = kwargs.get("token_type_ids", None) + attention_mask = kwargs.get("attention_mask", None) # Omit tokens covered by past_key_values if past_key_values: if token_idx is not None: @@ -430,8 +578,11 @@ def prepare_inputs_for_generation( token_type_ids = 
torch.index_select(token_type_ids, 1, token_idx - 1) else: token_type_ids = token_type_ids[:, -input_ids.shape[1] :] + elif reuse_cache and token_idx is not None: + # With reuse_cache, KV cache is pre allocated hence for the 1st token we can slice the inputs till token idx for the fwd pass + input_ids = input_ids[:, :token_idx] + attention_mask = attention_mask[:, :token_idx] - attention_mask = kwargs.get("attention_mask", None) position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: @@ -458,6 +609,8 @@ def prepare_inputs_for_generation( "attention_mask": attention_mask, "token_type_ids": token_type_ids, "token_idx": token_idx, + "reuse_cache": reuse_cache, + "cache_idx": kwargs.get("cache_idx"), } ) @@ -478,6 +631,8 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, token_idx: Optional[torch.Tensor] = None, + reuse_cache: Optional[bool] = False, + cache_idx: Optional[int] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -500,14 +655,11 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, token_idx=token_idx, + reuse_cache=reuse_cache, + cache_idx=cache_idx, ) hidden_states = transformer_outputs[0] - # Set device for model parallelism - if self.model_parallel: - torch.cuda.set_device(self.transformer.first_device) - hidden_states = hidden_states.to(self.lm_head.weight.device) - # make sure sampling in fp16 works correctly and # compute loss in fp32 to match with mesh-tf version # https://github.com/EleutherAI/gpt-neo/blob/89ce74164da2fb16179106f54e2269b5da8db333/models/gpt2/gpt2.py#L179 diff --git a/tests/test_text_generation_example.py b/tests/test_text_generation_example.py index 4e116242f5..7761717f3d 100644 --- a/tests/test_text_generation_example.py +++ b/tests/test_text_generation_example.py @@ -80,7 +80,7 @@ ("gpt2-xl", 1, False, 142.11481820425706), # TODO: fix OPT 6.7B # ("facebook/opt-6.7b", 0.0), - ("EleutherAI/gpt-j-6b", 1, False, 50.79545107991805), + ("EleutherAI/gpt-j-6b", 1, True, 156.2893125740893), ("meta-llama/Llama-2-7b-hf", 1, True, 44.39616259946937), ("tiiuae/falcon-7b", 1, True, 44.82870145718665), ("bigcode/starcoder", 1, False, 15.945023767901013),
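Taken together, these changes let GPT-J use the same pre-allocated-cache path as Llama, Falcon, Mistral and Qwen2. A hedged end-to-end sketch of how the new path is expected to be exercised: the supported entry points are the examples (run_generation.py / run_lm_eval.py) with --reuse_cache; the direct-API variant below is pieced together from those scripts as an assumption, and the exact set of required options (static shapes, padding, lazy mode) may differ on a given setup:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from optimum.habana.transformers.generation import GaudiGenerationConfig
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

adapt_transformers_to_gaudi()  # swaps in GaudiGPTJModel, GaudiGPTJAttention, GaudiGPTJBlock, ...

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6b")
tokenizer.pad_token = tokenizer.eos_token  # GPT-J has no pad token; padding is needed for static shapes
model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6b", torch_dtype=torch.bfloat16)
model = model.eval().to("hpu")

inputs = tokenizer("Habana Gaudi is", return_tensors="pt").to("hpu")
generation_config = GaudiGenerationConfig(
    max_new_tokens=32,
    use_cache=True,
    reuse_cache=True,    # pre-allocate the KV cache once and update it in place
    static_shapes=True,  # keep input shapes constant across decode steps (enables token_idx)
    pad_token_id=tokenizer.pad_token_id,
)
outputs = model.generate(**inputs, generation_config=generation_config, lazy_mode=True)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))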