From 05169df1d6ed554bb729da1a0bd382d59192efff Mon Sep 17 00:00:00 2001 From: Abhinav Venigalla Date: Mon, 17 Oct 2022 22:20:29 +0000 Subject: [PATCH 01/18] wip --- llm/hf_configs/noflash-gpt-125m-ctx-1024.json | 38 + llm/llm/gpt.py | 14 +- llm/llm/hf_flash_gpt.py | 1592 ----------------- llm/llm/hf_flash_gpt_2.py | 10 +- 4 files changed, 55 insertions(+), 1599 deletions(-) create mode 100644 llm/hf_configs/noflash-gpt-125m-ctx-1024.json delete mode 100644 llm/llm/hf_flash_gpt.py diff --git a/llm/hf_configs/noflash-gpt-125m-ctx-1024.json b/llm/hf_configs/noflash-gpt-125m-ctx-1024.json new file mode 100644 index 000000000..11e00e424 --- /dev/null +++ b/llm/hf_configs/noflash-gpt-125m-ctx-1024.json @@ -0,0 +1,38 @@ +{ + "activation_function": "gelu", + "architectures": [ + "GPT2FlashLMHeadModel" + ], + "attn_pdrop": 0.1, + "bos_token_id": 50256, + "embd_pdrop": 0.1, + "eos_token_id": 50256, + "initializer_range": 0.02, + "layer_norm_epsilon": 1e-05, + "model_type": "gpt2", + "n_ctx": 1024, + "n_embd": 768, + "n_head": 12, + "n_inner": null, + "n_layer": 12, + "n_positions": 1024, + "reorder_and_upcast_attn": false, + "resid_pdrop": 0.1, + "scale_attn_by_inverse_layer_idx": false, + "scale_attn_weights": true, + "summary_activation": null, + "summary_first_dropout": 0.1, + "summary_proj_to_labels": true, + "summary_type": "cls_index", + "summary_use_proj": true, + "task_specific_params": { + "text-generation": { + "do_sample": true, + "max_length": 50 + } + }, + "torch_dtype": "float32", + "transformers_version": "4.12.3", + "use_cache": false, + "vocab_size": 50257 +} diff --git a/llm/llm/gpt.py b/llm/llm/gpt.py index f7c14744e..7850d1825 100644 --- a/llm/llm/gpt.py +++ b/llm/llm/gpt.py @@ -15,7 +15,7 @@ from composer.models.base import ComposerModel from flash_attn.flash_attention import FlashMHA from transformers.models.gpt2 import GPT2Config - +from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel from .hf_flash_gpt_2 import GPT2FlashLMHeadModel @@ -26,8 +26,14 @@ def __init__(self, cfg, device='meta'): # load GPT2 config from standard HF model config json hf_config = GPT2Config.from_json_file(cfg.hf_config) # build model with config - self.model = GPT2FlashLMHeadModel(hf_config) - self.model.to(device) + model_class = hf_config.architectures[0] + if model_class == 'GPT2LMHeadModel': + self.model = GPT2LMHeadModel(hf_config) + elif model_class == 'GPT2FlashLMHeadModel': + self.model = GPT2FlashLMHeadModel(hf_config) + self.model.to(device) + else: + raise ValueError(f'Not sure how to build model_class={model_class}') self.train_metrics = { 'LanguageCrossEntropy': LanguageCrossEntropy(hf_config.vocab_size), 'Perplexity': Perplexity(), @@ -60,4 +66,4 @@ def get_metrics(self, is_train=False): def update_metric(self, batch, outputs, metric): outputs = outputs.view(-1, outputs.size(-1)) targets = self.get_targets(batch).view(-1) - metric.update(outputs, targets) + metric.update(outputs, targets) \ No newline at end of file diff --git a/llm/llm/hf_flash_gpt.py b/llm/llm/hf_flash_gpt.py deleted file mode 100644 index 79acec75c..000000000 --- a/llm/llm/hf_flash_gpt.py +++ /dev/null @@ -1,1592 +0,0 @@ -# coding=utf-8 -# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. -# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Modified HF GPT2 w/flash attention""" - -from einops import rearrange -from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func - -import math -import os -from dataclasses import dataclass -from typing import Optional, Tuple, Union - -import torch -import torch.utils.checkpoint -from torch import nn -from torch.cuda.amp import autocast -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss - -from transformers.activations import ACT2FN -from transformers.modeling_outputs import ( - BaseModelOutputWithPastAndCrossAttentions, - CausalLMOutputWithCrossAttentions, - SequenceClassifierOutputWithPast, - TokenClassifierOutput, -) -from transformers.modeling_utils import PreTrainedModel, SequenceSummary -from transformers.pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer -from transformers.utils import ( - ModelOutput, - add_code_sample_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, - replace_return_docstrings, -) -from transformers.utils.model_parallel_utils import assert_device_map, get_device_map -from transformers.models.gpt2.configuration_gpt2 import GPT2Config - - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "gpt2" -_CONFIG_FOR_DOC = "GPT2Config" -_TOKENIZER_FOR_DOC = "GPT2Tokenizer" - -GPT2_PRETRAINED_MODEL_ARCHIVE_LIST = [ - "gpt2", - "gpt2-medium", - "gpt2-large", - "gpt2-xl", - "distilgpt2", - # See all GPT-2 models at https://huggingface.co/models?filter=gpt2 -] - - -def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path): - """Load tf checkpoints in a pytorch model""" - try: - import re - - import tensorflow as tf - except ImportError: - logger.error( - "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " - "https://www.tensorflow.org/install/ for installation instructions." - ) - raise - tf_path = os.path.abspath(gpt2_checkpoint_path) - logger.info(f"Converting TensorFlow checkpoint from {tf_path}") - # Load weights from TF model - init_vars = tf.train.list_variables(tf_path) - names = [] - arrays = [] - for name, shape in init_vars: - logger.info(f"Loading TF weight {name} with shape {shape}") - array = tf.train.load_variable(tf_path, name) - names.append(name) - arrays.append(array.squeeze()) - - for name, array in zip(names, arrays): - name = name[6:] # skip "model/" - name = name.split("/") - pointer = model - for m_name in name: - if re.fullmatch(r"[A-Za-z]+\d+", m_name): - scope_names = re.split(r"(\d+)", m_name) - else: - scope_names = [m_name] - if scope_names[0] == "w" or scope_names[0] == "g": - pointer = getattr(pointer, "weight") - elif scope_names[0] == "b": - pointer = getattr(pointer, "bias") - elif scope_names[0] == "wpe" or scope_names[0] == "wte": - pointer = getattr(pointer, scope_names[0]) - pointer = getattr(pointer, "weight") - else: - pointer = getattr(pointer, scope_names[0]) - if len(scope_names) >= 2: - num = int(scope_names[1]) - pointer = pointer[num] - try: - assert ( - pointer.shape == array.shape - ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched" - except AssertionError as e: - e.args += (pointer.shape, array.shape) - raise - logger.info(f"Initialize PyTorch weight {name}") - pointer.data = torch.from_numpy(array) - return model - - -class GPT2FlashAttention(nn.Module): - def __init__(self, config, is_cross_attention=False, layer_idx=None): - super().__init__() - - max_positions = config.max_position_embeddings - self.register_buffer( - "bias", - torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view( - 1, 1, max_positions, max_positions - ), - ) - self.register_buffer("masked_bias", torch.tensor(-1e4)) - - self.embed_dim = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.embed_dim // self.num_heads - self.split_size = self.embed_dim - if self.head_dim * self.num_heads != self.embed_dim: - raise ValueError( - f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" - f" {self.num_heads})." - ) - - self.scale_attn_weights = config.scale_attn_weights - self.is_cross_attention = is_cross_attention - - # Layer-wise attention scaling, reordering, and upcasting - self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx - self.layer_idx = layer_idx - self.reorder_and_upcast_attn = config.reorder_and_upcast_attn - - if self.is_cross_attention: - self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim) - self.q_attn = Conv1D(self.embed_dim, self.embed_dim) - else: - self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim) - self.c_proj = Conv1D(self.embed_dim, self.embed_dim) - - self.attn_dropout = nn.Dropout(config.attn_pdrop) - self.resid_dropout = nn.Dropout(config.resid_pdrop) - - self.pruned_heads = set() - - # FSDP Wrap function - def fsdp_wrap_fn(self, module): - return isinstance(module, GPT2FlashBlock) - - # Activation Checkpointing - def activation_checkpointing_fn(self, module): - return isinstance(module, GPT2FlashBlock) - - def prune_heads(self, heads): - if len(heads) == 0: - return - heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads) - index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)]) - - # Prune conv1d layers - self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1) - self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0) - - # Update hyper params - self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads)) - self.num_heads = self.num_heads - len(heads) - self.pruned_heads = self.pruned_heads.union(heads) - - def _attn(self, query, key, value, attention_mask=None, head_mask=None): - # rearrange to flash attention form - key = rearrange(key, 'b h s d -> b s h d') - value = rearrange(value, 'b h s d -> b s h d') - query = rearrange(query, 'b h s d -> b s h d') - - #assert query.dtype in [torch.float16, torch.bfloat16], f"{query.dtype}" - - # stack - qkv = torch.stack([query,key,value], dim=2) - #qkv = torch.tensor(qkv,dtype=torch.bfloat16) - assert qkv.dtype in [torch.float16, torch.bfloat16] - - # flash attention logic - batch_size = qkv.shape[0] - seqlen = qkv.shape[1] - num_heads = qkv.shape[3] - dk = qkv.shape[4] - dk_per_head = int(dk)/int(num_heads) - qkv = rearrange(qkv, 'b s ... -> (b s) ...') - max_s = seqlen - cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, device=qkv.device) - attn_pdrop = 0.1 - softmax_scale = 1/float(math.sqrt(dk)) - output = flash_attn_unpadded_qkvpacked_func( - qkv, cu_seqlens, max_s, attn_pdrop, - softmax_scale=softmax_scale, causal=True - ) - output = rearrange(output, '(b s) ... -> b s ...', b=batch_size) - output = rearrange(output, 'b s h d -> b h s d') - #output = torch.tensor(output, dtype=torch.float32) - return output, None - - - #attn_weights = torch.matmul(query, key.transpose(-1, -2)) - - #if self.scale_attn_weights: - #attn_weights = attn_weights / torch.tensor( - #value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device - #) - - # Layer-wise attention scaling - #if self.scale_attn_by_inverse_layer_idx: - #attn_weights = attn_weights / float(self.layer_idx + 1) - - #if not self.is_cross_attention: - # if only "normal" attention layer implements causal mask - #query_length, key_length = query.size(-2), key.size(-2) - #causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool) - #mask_value = torch.finfo(attn_weights.dtype).min - # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. - # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` - #mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device) - #attn_weights = torch.where(causal_mask, attn_weights, mask_value) - - #if attention_mask is not None: - # Apply the attention mask - #attn_weights = attn_weights + attention_mask - - #attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise - #attn_weights = attn_weights.type(value.dtype) - #attn_weights = self.attn_dropout(attn_weights) - - # Mask heads if we want to - #if head_mask is not None: - #attn_weights = attn_weights * head_mask - - #attn_output = torch.matmul(attn_weights, value) - - #return attn_output, attn_weights - - def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None): - # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM) - bsz, num_heads, q_seq_len, dk = query.size() - _, _, k_seq_len, _ = key.size() - - # Preallocate attn_weights for `baddbmm` - attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device) - - # Compute Scale Factor - scale_factor = 1.0 - if self.scale_attn_weights: - scale_factor /= float(value.size(-1)) ** 0.5 - - if self.scale_attn_by_inverse_layer_idx: - scale_factor /= float(self.layer_idx + 1) - - # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk)) - with autocast(enabled=False): - q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len) - attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor) - attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len) - - if not self.is_cross_attention: - # if only "normal" attention layer implements causal mask - query_length, key_length = query.size(-2), key.size(-2) - causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool() - mask_value = torch.finfo(attn_weights.dtype).min - # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. - # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` - mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device) - attn_weights = torch.where(causal_mask, attn_weights, mask_value) - - if attention_mask is not None: - # Apply the attention mask - attn_weights = attn_weights + attention_mask - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op if otherwise - if attn_weights.dtype != torch.float32: - raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32") - attn_weights = attn_weights.type(value.dtype) - attn_weights = self.attn_dropout(attn_weights) - - # Mask heads if we want to - if head_mask is not None: - attn_weights = attn_weights * head_mask - - attn_output = torch.matmul(attn_weights, value) - - return attn_output, attn_weights - - def _split_heads(self, tensor, num_heads, attn_head_size): - """ - Splits hidden_size dim into attn_head_size and num_heads - """ - new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) - tensor = tensor.view(new_shape) - return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) - - def _merge_heads(self, tensor, num_heads, attn_head_size): - """ - Merges attn_head_size dim and num_attn_heads dim into hidden_size - """ - tensor = tensor.permute(0, 2, 1, 3).contiguous() - new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,) - return tensor.view(new_shape) - - def forward( - self, - hidden_states: Optional[Tuple[torch.FloatTensor]], - layer_past: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = False, - output_attentions: Optional[bool] = False, - ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: - if encoder_hidden_states is not None: - if not hasattr(self, "q_attn"): - raise ValueError( - "If class is used as cross attention, the weights `q_attn` have to be defined. " - "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`." - ) - - query = self.q_attn(hidden_states) - key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) - attention_mask = encoder_attention_mask - else: - query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) - - query = self._split_heads(query, self.num_heads, self.head_dim) - key = self._split_heads(key, self.num_heads, self.head_dim) - value = self._split_heads(value, self.num_heads, self.head_dim) - - - if layer_past is not None: - past_key, past_value = layer_past - key = torch.cat((past_key, key), dim=-2) - value = torch.cat((past_value, value), dim=-2) - - if use_cache is True: - present = (key, value) - else: - present = None - - if self.reorder_and_upcast_attn: - attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask) - else: - attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) - - attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) - attn_output = self.c_proj(attn_output) - attn_output = self.resid_dropout(attn_output) - - outputs = (attn_output, present) - if output_attentions: - outputs += (attn_weights,) - - return outputs # a, present, (attentions) - - -class GPT2MLP(nn.Module): - def __init__(self, intermediate_size, config): - super().__init__() - embed_dim = config.hidden_size - self.c_fc = Conv1D(intermediate_size, embed_dim) - self.c_proj = Conv1D(embed_dim, intermediate_size) - self.act = ACT2FN[config.activation_function] - self.dropout = nn.Dropout(config.resid_pdrop) - - def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor: - hidden_states = self.c_fc(hidden_states) - hidden_states = self.act(hidden_states) - hidden_states = self.c_proj(hidden_states) - hidden_states = self.dropout(hidden_states) - return hidden_states - - -class GPT2FlashBlock(nn.Module): - def __init__(self, config, layer_idx=None): - super().__init__() - hidden_size = config.hidden_size - inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size - - self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.attn = GPT2FlashAttention(config, layer_idx=layer_idx) - self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - - if config.add_cross_attention: - self.crossattention = GPT2FlashAttention(config, is_cross_attention=True, layer_idx=layer_idx) - self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - - self.mlp = GPT2MLP(inner_dim, config) - - def forward( - self, - hidden_states: Optional[Tuple[torch.FloatTensor]], - layer_past: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = False, - output_attentions: Optional[bool] = False, - ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]: - residual = hidden_states - hidden_states = self.ln_1(hidden_states) - attn_outputs = self.attn( - hidden_states, - layer_past=layer_past, - attention_mask=attention_mask, - head_mask=head_mask, - use_cache=use_cache, - output_attentions=output_attentions, - ) - attn_output = attn_outputs[0] # output_attn: a, present, (attentions) - outputs = attn_outputs[1:] - # residual connection - hidden_states = attn_output + residual - - if encoder_hidden_states is not None: - # add one self-attention block for cross-attention - if not hasattr(self, "crossattention"): - raise ValueError( - f"If `encoder_hidden_states` are passed, {self} has to be instantiated with " - "cross-attention layers by setting `config.add_cross_attention=True`" - ) - residual = hidden_states - hidden_states = self.ln_cross_attn(hidden_states) - cross_attn_outputs = self.crossattention( - hidden_states, - attention_mask=attention_mask, - head_mask=head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - output_attentions=output_attentions, - ) - attn_output = cross_attn_outputs[0] - # residual connection - hidden_states = residual + attn_output - outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights - - residual = hidden_states - hidden_states = self.ln_2(hidden_states) - feed_forward_hidden_states = self.mlp(hidden_states) - # residual connection - hidden_states = residual + feed_forward_hidden_states - - if use_cache: - outputs = (hidden_states,) + outputs - else: - outputs = (hidden_states,) + outputs[1:] - - return outputs # hidden_states, present, (attentions, cross_attentions) - - -class GPT2PreTrainedModel(PreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = GPT2Config - load_tf_weights = load_tf_weights_in_gpt2 - base_model_prefix = "transformer" - is_parallelizable = True - supports_gradient_checkpointing = True - _no_split_modules = ["GPT2Block"] - - def __init__(self, *inputs, **kwargs): - super().__init__(*inputs, **kwargs) - - def param_init_fn(self, module): - self._init_weights(module) - - def _init_weights(self, module): - """Initialize the weights.""" - if isinstance(module, (nn.Linear, Conv1D)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - - # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: - # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale - # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. - # > -- GPT-2 :: https://openai.com/blog/better-language-models/ - # - # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py - for name, p in module.named_parameters(): - if name == "c_proj.weight": - # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block - p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))) - - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, GPT2Model): - module.gradient_checkpointing = value - - -@dataclass -class GPT2DoubleHeadsModelOutput(ModelOutput): - """ - Base class for outputs of models predicting if two sentences are consecutive or not. - - Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Language modeling loss. - mc_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mc_labels` is provided): - Multiple choice classification loss. - logits (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - mc_logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`): - Prediction scores of the multiple choice classification head (scores for each choice before SoftMax). - past_key_values (`Tuple[Tuple[torch.Tensor]]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of length `config.n_layers`, containing tuples of tensors of shape `(batch_size, num_heads, - sequence_length, embed_size_per_head)`). - - Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see - `past_key_values` input) to speed up sequential decoding. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of - shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - GPT2Attentions weights after the attention softmax, used to compute the weighted average in the - self-attention heads. - """ - - loss: Optional[torch.FloatTensor] = None - mc_loss: Optional[torch.FloatTensor] = None - logits: torch.FloatTensor = None - mc_logits: torch.FloatTensor = None - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - hidden_states: Optional[Tuple[torch.FloatTensor]] = None - attentions: Optional[Tuple[torch.FloatTensor]] = None - - -GPT2_START_DOCSTRING = r""" - - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - - Parameters: - config ([`GPT2Config`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - -GPT2_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): - `input_ids_length` = `sequence_length` if `past_key_values` is `None` else - `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input - sequence tokens in the vocabulary. - - If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as - `input_ids`. - - Indices can be obtained using [`GPT2Tokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`): - Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see - `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have - their past given to this model should not be passed as `input_ids` as they have already been computed. - attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - If `past_key_values` is used, `attention_mask` needs to contain the masking strategy that was used for - `past_key_values`. In other words, the `attention_mask` always has to have the length: - `len(past_key_values) + len(input_ids)` - - [What are attention masks?](../glossary#attention-mask) - token_type_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*): - Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, - 1]`: - - - 0 corresponds to a *sentence A* token, - - 1 corresponds to a *sentence B* token. - - [What are token type IDs?](../glossary#token-type-ids) - position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - - [What are position IDs?](../glossary#position-ids) - head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): - Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - - If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see - `past_key_values`). - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" -PARALLELIZE_DOCSTRING = r""" - This is an experimental feature and is a subject to change at a moment's notice. - - Uses a device map to distribute attention modules of the model across several devices. If no device map is given, - it will evenly distribute blocks across all devices. - - Args: - device_map (`Dict[int, list]`, optional, defaults to None): - A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always - automatically mapped to the first device (for esoteric reasons). That means that the first device should - have fewer attention modules mapped to it than other devices. For reference, the gpt2 models have the - following number of attention modules: - - - gpt2: 12 - - gpt2-medium: 24 - - gpt2-large: 36 - - gpt2-xl: 48 - - Example: - - ```python - # Here is an example of a device map on a machine with 4 GPUs using gpt2-xl, which has a total of 48 attention modules: - model = GPT2LMHeadModel.from_pretrained("gpt2-xl") - device_map = { - 0: [0, 1, 2, 3, 4, 5, 6, 7, 8], - 1: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21], - 2: [22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34], - 3: [35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], - } - model.parallelize(device_map) - ``` -""" -DEPARALLELIZE_DOCSTRING = r""" - Moves the model to cpu from a model parallel state. - - Example: - - ```python - # On a 4 GPU machine with gpt2-large: - model = GPT2LMHeadModel.from_pretrained("gpt2-large") - device_map = { - 0: [0, 1, 2, 3, 4, 5, 6, 7], - 1: [8, 9, 10, 11, 12, 13, 14, 15], - 2: [16, 17, 18, 19, 20, 21, 22, 23], - 3: [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35], - } - model.parallelize(device_map) # Splits the model across several devices - model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache() - ``` -""" - - -@add_start_docstrings( - "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.", - GPT2_START_DOCSTRING, -) -class GPT2FlashModel(GPT2PreTrainedModel): - _keys_to_ignore_on_load_missing = ["attn.masked_bias"] - - def __init__(self, config): - super().__init__(config) - - self.embed_dim = config.hidden_size - - self.wte = nn.Embedding(config.vocab_size, self.embed_dim) - self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) - - self.drop = nn.Dropout(config.embd_pdrop) - self.h = nn.ModuleList([GPT2FlashBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)]) - self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) - - # Model parallel - self.model_parallel = False - self.device_map = None - self.gradient_checkpointing = False - - # Initialize weights and apply final processing - self.post_init() - - @add_start_docstrings(PARALLELIZE_DOCSTRING) - def parallelize(self, device_map=None): - # Check validity of device_map - self.device_map = ( - get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map - ) - assert_device_map(self.device_map, len(self.h)) - self.model_parallel = True - self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys())) - self.last_device = "cuda:" + str(max(self.device_map.keys())) - self.wte = self.wte.to(self.first_device) - self.wpe = self.wpe.to(self.first_device) - # Load onto devices - for k, v in self.device_map.items(): - for block in v: - cuda_device = "cuda:" + str(k) - self.h[block] = self.h[block].to(cuda_device) - # ln_f to last - self.ln_f = self.ln_f.to(self.last_device) - - @add_start_docstrings(DEPARALLELIZE_DOCSTRING) - def deparallelize(self): - self.model_parallel = False - self.device_map = None - self.first_device = "cpu" - self.last_device = "cpu" - self.wte = self.wte.to("cpu") - self.wpe = self.wpe.to("cpu") - for index in range(len(self.h)): - self.h[index] = self.h[index].to("cpu") - self.ln_f = self.ln_f.to("cpu") - torch.cuda.empty_cache() - - def get_input_embeddings(self): - return self.wte - - def set_input_embeddings(self, new_embeddings): - self.wte = new_embeddings - - def _prune_heads(self, heads_to_prune): - """ - Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} - """ - for layer, heads in heads_to_prune.items(): - self.h[layer].attn.prune_heads(heads) - - @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) - @add_code_sample_docstrings( - processor_class=_TOKENIZER_FOR_DOC, - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=BaseModelOutputWithPastAndCrossAttentions, - config_class=_CONFIG_FOR_DOC, - ) - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = input_ids.size() - input_ids = input_ids.view(-1, input_shape[-1]) - batch_size = input_ids.shape[0] - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - batch_size = inputs_embeds.shape[0] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - device = input_ids.device if input_ids is not None else inputs_embeds.device - - if token_type_ids is not None: - token_type_ids = token_type_ids.view(-1, input_shape[-1]) - if position_ids is not None: - position_ids = position_ids.view(-1, input_shape[-1]) - - if past_key_values is None: - past_length = 0 - past_key_values = tuple([None] * len(self.h)) - else: - past_length = past_key_values[0][0].size(-2) - if position_ids is None: - position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) - position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) - - # GPT2Attention mask. - if attention_mask is not None: - if batch_size <= 0: - raise ValueError("batch_size has to be defined and > 0") - attention_mask = attention_mask.view(batch_size, -1) - # We create a 3D attention mask from a 2D tensor mask. - # Sizes are [batch_size, 1, 1, to_seq_length] - # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] - # this attention mask is more simple than the triangular masking of causal attention - # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - attention_mask = attention_mask[:, None, None, :] - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and -10000.0 for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility - attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min - - # If a 2D or 3D attention mask is provided for the cross-attention - # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] - if self.config.add_cross_attention and encoder_hidden_states is not None: - encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() - encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) - if encoder_attention_mask is None: - encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) - encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) - else: - encoder_attention_mask = None - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # head_mask has shape n_layer x batch x n_heads x N x N - head_mask = self.get_head_mask(head_mask, self.config.n_layer) - - if inputs_embeds is None: - inputs_embeds = self.wte(input_ids) - position_embeds = self.wpe(position_ids) - hidden_states = inputs_embeds + position_embeds - - if token_type_ids is not None: - token_type_embeds = self.wte(token_type_ids) - hidden_states = hidden_states + token_type_embeds - - hidden_states = self.drop(hidden_states) - - output_shape = input_shape + (hidden_states.size(-1),) - - presents = () if use_cache else None - all_self_attentions = () if output_attentions else None - all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None - all_hidden_states = () if output_hidden_states else None - for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): - - # Model parallel - if self.model_parallel: - torch.cuda.set_device(hidden_states.device) - # Ensure layer_past is on same device as hidden_states (might not be correct) - if layer_past is not None: - layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) - # Ensure that attention_mask is always on the same device as hidden_states - if attention_mask is not None: - attention_mask = attention_mask.to(hidden_states.device) - if isinstance(head_mask, torch.Tensor): - head_mask = head_mask.to(hidden_states.device) - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if self.gradient_checkpointing and self.training: - - if use_cache: - logger.warning( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - None, - attention_mask, - head_mask[i], - encoder_hidden_states, - encoder_attention_mask, - ) - else: - outputs = block( - hidden_states, - layer_past=layer_past, - attention_mask=attention_mask, - head_mask=head_mask[i], - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, - ) - - hidden_states = outputs[0] - if use_cache is True: - presents = presents + (outputs[1],) - - if output_attentions: - all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) - if self.config.add_cross_attention: - all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) - - # Model Parallel: If it's the last layer for that device, put things on the next device - if self.model_parallel: - for k, v in self.device_map.items(): - if i == v[-1] and "cuda:" + str(k) != self.last_device: - hidden_states = hidden_states.to("cuda:" + str(k + 1)) - - hidden_states = self.ln_f(hidden_states) - - hidden_states = hidden_states.view(output_shape) - # Add last hidden state - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple( - v - for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] - if v is not None - ) - - return BaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=presents, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - cross_attentions=all_cross_attentions, - ) - - -@add_start_docstrings( - """ - The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input - embeddings). - """, - GPT2_START_DOCSTRING, -) -class GPT2FlashLMHeadModel(GPT2PreTrainedModel): - _keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"] - - def __init__(self, config): - super().__init__(config) - self.transformer = GPT2FlashModel(config) - self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) - - # Model parallel - self.model_parallel = False - self.device_map = None - - # Initialize weights and apply final processing - self.post_init() - - @add_start_docstrings(PARALLELIZE_DOCSTRING) - def parallelize(self, device_map=None): - self.device_map = ( - get_device_map(len(self.transformer.h), range(torch.cuda.device_count())) - if device_map is None - else device_map - ) - assert_device_map(self.device_map, len(self.transformer.h)) - self.transformer.parallelize(self.device_map) - self.lm_head = self.lm_head.to(self.transformer.first_device) - self.model_parallel = True - - @add_start_docstrings(DEPARALLELIZE_DOCSTRING) - def deparallelize(self): - self.transformer.deparallelize() - self.transformer = self.transformer.to("cpu") - self.lm_head = self.lm_head.to("cpu") - self.model_parallel = False - torch.cuda.empty_cache() - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): - token_type_ids = kwargs.get("token_type_ids", None) - # only last token for inputs_ids if past is defined in kwargs - if past: - input_ids = input_ids[:, -1].unsqueeze(-1) - if token_type_ids is not None: - token_type_ids = token_type_ids[:, -1].unsqueeze(-1) - - attention_mask = kwargs.get("attention_mask", None) - position_ids = kwargs.get("position_ids", None) - - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past: - position_ids = position_ids[:, -1].unsqueeze(-1) - else: - position_ids = None - return { - "input_ids": input_ids, - "past_key_values": past, - "use_cache": kwargs.get("use_cache"), - "position_ids": position_ids, - "attention_mask": attention_mask, - "token_type_ids": token_type_ids, - } - - @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) - @add_code_sample_docstrings( - processor_class=_TOKENIZER_FOR_DOC, - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=CausalLMOutputWithCrossAttentions, - config_class=_CONFIG_FOR_DOC, - ) - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set - `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` - are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - transformer_outputs = self.transformer( - input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_states = transformer_outputs[0] - - # Set device for model parallelism - if self.model_parallel: - torch.cuda.set_device(self.transformer.first_device) - hidden_states = hidden_states.to(self.lm_head.weight.device) - - lm_logits = self.lm_head(hidden_states) - - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = lm_logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) - - if not return_dict: - output = (lm_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return CausalLMOutputWithCrossAttentions( - loss=loss, - logits=lm_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - cross_attentions=transformer_outputs.cross_attentions, - ) - - @staticmethod - def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]: - """ - This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or - [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct - beam_idx at every generation step. - """ - return tuple( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) - for layer_past in past - ) - - -@add_start_docstrings( - """ -The GPT2 Model transformer with a language modeling and a multiple-choice classification head on top e.g. for -RocStories/SWAG tasks. The two heads are two linear layers. The language modeling head has its weights tied to the -input embeddings, the classification head takes as input the input of a specified classification token index in the -input sequence). -""", - GPT2_START_DOCSTRING, -) -class GPT2DoubleHeadsModel(GPT2PreTrainedModel): - _keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"] - - def __init__(self, config): - super().__init__(config) - config.num_labels = 1 - self.transformer = GPT2Model(config) - self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) - self.multiple_choice_head = SequenceSummary(config) - - # Model parallel - self.model_parallel = False - self.device_map = None - - # Initialize weights and apply final processing - self.post_init() - - @add_start_docstrings(PARALLELIZE_DOCSTRING) - def parallelize(self, device_map=None): - self.device_map = ( - get_device_map(len(self.transformer.h), range(torch.cuda.device_count())) - if device_map is None - else device_map - ) - assert_device_map(self.device_map, len(self.transformer.h)) - self.transformer.parallelize(self.device_map) - self.lm_head = self.lm_head.to(self.transformer.first_device) - self.multiple_choice_head = self.multiple_choice_head.to(self.transformer.first_device) - self.model_parallel = True - - @add_start_docstrings(DEPARALLELIZE_DOCSTRING) - def deparallelize(self): - self.transformer.deparallelize() - self.transformer = self.transformer.to("cpu") - self.lm_head = self.lm_head.to("cpu") - self.multiple_choice_head = self.multiple_choice_head.to("cpu") - self.model_parallel = False - torch.cuda.empty_cache() - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): - token_type_ids = kwargs.get("token_type_ids", None) - # only last token for inputs_ids if past is defined in kwargs - if past: - input_ids = input_ids[:, -1].unsqueeze(-1) - if token_type_ids is not None: - token_type_ids = token_type_ids[:, -1].unsqueeze(-1) - - attention_mask = kwargs.get("attention_mask", None) - position_ids = kwargs.get("position_ids", None) - - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past: - position_ids = position_ids[:, -1].unsqueeze(-1) - else: - position_ids = None - - return { - "input_ids": input_ids, - "past_key_values": past, - "use_cache": kwargs.get("use_cache"), - "position_ids": position_ids, - "attention_mask": attention_mask, - "token_type_ids": token_type_ids, - } - - @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=GPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC) - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - mc_token_ids: Optional[torch.LongTensor] = None, - labels: Optional[torch.LongTensor] = None, - mc_labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - **kwargs, - ) -> Union[Tuple, GPT2DoubleHeadsModelOutput]: - r""" - mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, default to index of the last token of the input): - Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) - - 1]`. - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set - `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`. All labels set to - `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size - 1]` - mc_labels (`torch.LongTensor` of shape `(batch_size)`, *optional*): - Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]` - where *num_choices* is the size of the second dimension of the input tensors. (see *input_ids* above) - - Return: - - Example: - - ```python - >>> import torch - >>> from transformers import GPT2Tokenizer, GPT2DoubleHeadsModel - - >>> tokenizer = GPT2Tokenizer.from_pretrained("gpt2") - >>> model = GPT2DoubleHeadsModel.from_pretrained("gpt2") - - >>> # Add a [CLS] to the vocabulary (we should train it also!) - >>> num_added_tokens = tokenizer.add_special_tokens({"cls_token": "[CLS]"}) - >>> # Update the model embeddings with the new vocabulary size - >>> embedding_layer = model.resize_token_embeddings(len(tokenizer)) - - >>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"] - >>> encoded_choices = [tokenizer.encode(s) for s in choices] - >>> cls_token_location = [tokens.index(tokenizer.cls_token_id) for tokens in encoded_choices] - - >>> input_ids = torch.tensor(encoded_choices).unsqueeze(0) # Batch size: 1, number of choices: 2 - >>> mc_token_ids = torch.tensor([cls_token_location]) # Batch size: 1 - - >>> outputs = model(input_ids, mc_token_ids=mc_token_ids) - >>> lm_logits = outputs.logits - >>> mc_logits = outputs.mc_logits - ```""" - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - transformer_outputs = self.transformer( - input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = transformer_outputs[0] - - # Set device for model parallelism - if self.model_parallel: - torch.cuda.set_device(self.transformer.first_device) - hidden_states = hidden_states.to(self.lm_head.weight.device) - - lm_logits = self.lm_head(hidden_states) - mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1) - - mc_loss = None - if mc_labels is not None: - loss_fct = CrossEntropyLoss() - mc_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)) - lm_loss = None - if labels is not None: - shift_logits = lm_logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - loss_fct = CrossEntropyLoss() - lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) - - if not return_dict: - output = (lm_logits, mc_logits) + transformer_outputs[1:] - if mc_loss is not None: - output = (mc_loss,) + output - return ((lm_loss,) + output) if lm_loss is not None else output - - return GPT2DoubleHeadsModelOutput( - loss=lm_loss, - mc_loss=mc_loss, - logits=lm_logits, - mc_logits=mc_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - - @staticmethod - def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]: - """ - This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or - [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct - beam_idx at every generation step. - """ - return tuple( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) - for layer_past in past - ) - - -@add_start_docstrings( - """ - The GPT2 Model transformer with a sequence classification head on top (linear layer). - - [`GPT2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models - (e.g. GPT-1) do. - - Since it does classification on the last token, it requires to know the position of the last token. If a - `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If - no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the - padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in - each row of the batch). - """, - GPT2_START_DOCSTRING, -) -class GPT2ForSequenceClassification(GPT2PreTrainedModel): - _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head.weight"] - - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - self.transformer = GPT2Model(config) - self.score = nn.Linear(config.n_embd, self.num_labels, bias=False) - - # Model parallel - self.model_parallel = False - self.device_map = None - - # Initialize weights and apply final processing - self.post_init() - - @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) - @add_code_sample_docstrings( - processor_class=_TOKENIZER_FOR_DOC, - checkpoint="microsoft/DialogRPT-updown", - output_type=SequenceClassifierOutputWithPast, - config_class=_CONFIG_FOR_DOC, - expected_output="'LABEL_0'", - expected_loss=5.28, - ) - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, SequenceClassifierOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - transformer_outputs = self.transformer( - input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_states = transformer_outputs[0] - logits = self.score(hidden_states) - - if input_ids is not None: - batch_size, sequence_length = input_ids.shape[:2] - else: - batch_size, sequence_length = inputs_embeds.shape[:2] - - assert ( - self.config.pad_token_id is not None or batch_size == 1 - ), "Cannot handle batch sizes > 1 if no padding token is defined." - if self.config.pad_token_id is None: - sequence_lengths = -1 - else: - if input_ids is not None: - sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1 - else: - sequence_lengths = -1 - logger.warning( - f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " - "unexpected if using padding tokens in conjunction with `inputs_embeds.`" - ) - - pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] - - loss = None - if labels is not None: - if self.config.problem_type is None: - if self.num_labels == 1: - self.config.problem_type = "regression" - elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): - self.config.problem_type = "single_label_classification" - else: - self.config.problem_type = "multi_label_classification" - - if self.config.problem_type == "regression": - loss_fct = MSELoss() - if self.num_labels == 1: - loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) - else: - loss = loss_fct(pooled_logits, labels) - elif self.config.problem_type == "single_label_classification": - loss_fct = CrossEntropyLoss() - loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) - elif self.config.problem_type == "multi_label_classification": - loss_fct = BCEWithLogitsLoss() - loss = loss_fct(pooled_logits, labels) - if not return_dict: - output = (pooled_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return SequenceClassifierOutputWithPast( - loss=loss, - logits=pooled_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - - -@add_start_docstrings( - """ - GPT2 Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for - Named-Entity-Recognition (NER) tasks. - """, - GPT2_START_DOCSTRING, -) -class GPT2ForTokenClassification(GPT2PreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - - self.transformer = GPT2Model(config) - if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None: - classifier_dropout = config.classifier_dropout - elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None: - classifier_dropout = config.hidden_dropout - else: - classifier_dropout = 0.1 - self.dropout = nn.Dropout(classifier_dropout) - self.classifier = nn.Linear(config.hidden_size, config.num_labels) - - # Model parallel - self.model_parallel = False - self.device_map = None - - # Initialize weights and apply final processing - self.post_init() - - @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) - # fmt: off - @add_code_sample_docstrings( - processor_class=_TOKENIZER_FOR_DOC, - checkpoint="brad1141/gpt2-finetuned-comp2", - output_type=TokenClassifierOutput, - config_class=_CONFIG_FOR_DOC, - expected_loss=0.25, - expected_output=["Lead", "Lead", "Lead", "Position", "Lead", "Lead", "Lead", "Lead", "Lead", "Lead", "Lead", "Lead"], - ) - # fmt: on - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, TokenClassifierOutput]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - transformer_outputs = self.transformer( - input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = transformer_outputs[0] - hidden_states = self.dropout(hidden_states) - logits = self.classifier(hidden_states) - - loss = None - if labels is not None: - loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - - if not return_dict: - output = (logits,) + transformer_outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TokenClassifierOutput( - loss=loss, - logits=logits, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) diff --git a/llm/llm/hf_flash_gpt_2.py b/llm/llm/hf_flash_gpt_2.py index a61a2d1f8..c67ac6d8b 100644 --- a/llm/llm/hf_flash_gpt_2.py +++ b/llm/llm/hf_flash_gpt_2.py @@ -21,7 +21,7 @@ import torch from einops import rearrange -from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func +from flash_attn.flash_attention import FlashAttention from torch import nn from transformers.models.gpt2.configuration_gpt2 import GPT2Config from transformers.models.gpt2.modeling_gpt2 import ( @@ -32,8 +32,13 @@ class GPT2FlashAttention(GPT2Attention): def __init__(self, config, is_cross_attention=False, layer_idx=None): super().__init__(config=config, is_cross_attention=is_cross_attention, layer_idx=layer_idx) + self.inner_attn = FlashAttention(softmax_scale=None, attention_dropout=config.attn_pdrop) if self.reorder_and_upcast_attn: - raise ValueError('GPT2FlashAttention does not support reorder_and_upcast_attn') + raise ValueError('GPT2FlashAttention does not support reorder_and_upcast_attn.') + if self.scale_attn_by_inverse_layer_idx: + raise ValueError('GPT2FlashAttention does not support scale_attn_by_inverse_layer_idx.') + if not self.scale_attn_weights: + raise ValueError('GPT2FlashAttention only supports scale_attn_weights=True.') def _attn(self, query, key, value, attention_mask=None, head_mask=None): # rearrange to flash attention form @@ -41,7 +46,6 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): value = rearrange(value, 'b h s d -> b s h d') query = rearrange(query, 'b h s d -> b s h d') - #assert query.dtype in [torch.float16, torch.bfloat16], f"{query.dtype}" # stack qkv = torch.stack([query,key,value], dim=2) From 472c576a1e93a3f7601853805ef8b7ec23cd05d6 Mon Sep 17 00:00:00 2001 From: Abhinav Venigalla Date: Mon, 17 Oct 2022 22:23:17 +0000 Subject: [PATCH 02/18] wip --- llm/yamls/debugging/noflash.yaml | 120 ++++++++++++++++++++++++++++++ llm/yamls/debugging/yesflash.yaml | 120 ++++++++++++++++++++++++++++++ 2 files changed, 240 insertions(+) create mode 100644 llm/yamls/debugging/noflash.yaml create mode 100644 llm/yamls/debugging/yesflash.yaml diff --git a/llm/yamls/debugging/noflash.yaml b/llm/yamls/debugging/noflash.yaml new file mode 100644 index 000000000..f4a36a30a --- /dev/null +++ b/llm/yamls/debugging/noflash.yaml @@ -0,0 +1,120 @@ +#data_remote: &data_remote ./my-copy-c4 +data_remote: &data_remote s3://crfm-pubmed/pubmed-randomized +#data_local: &data_local ./my-copy-c4 +data_local: &data_local /tmp/mds-cache/pubmed-randomized +max_seq_len: &max_seq_len 2048 +tokenizer_name: &tokenizer_name gpt2 + +# Model +model: + hf_config: hf_configs/noflash-gpt-125m-ctx-1024.json + +# Dataloaders +train_loader: + dataset: + name: streaming_pubmed + remote: *data_remote + local: *data_local + split: train + tokenizer_name: *tokenizer_name + max_seq_len: *max_seq_len + group_method: concat + shuffle: true + drop_last: true + num_workers: 8 + pin_memory: true + prefetch_factor: 2 + persistent_workers: true + timeout: 0 + +eval_loader: + dataset: + name: streaming_pubmed + remote: *data_remote + local: *data_local + split: val + tokenizer_name: *tokenizer_name + max_seq_len: *max_seq_len + group_method: truncate + shuffle: false + drop_last: false + num_workers: 8 + pin_memory: true + prefetch_factor: 2 + persistent_workers: true + timeout: 0 + +# Optimization +scheduler: + name: cosine_with_warmup + t_warmup: 100ba + alpha_f: 0.1 + +optimizer: + name: decoupled_adamw + lr: 6.0e-4 + betas: + - 0.9 + - 0.95 + eps: 1.0e-08 + weight_decay: 0.0 + +max_duration: 5000ba +eval_interval: 5000ba +global_train_batch_size: 256 +grad_clip_norm: 1.0 + +# System +seed: 17 +device_train_microbatch_size: 1 # Hitting memory limits with 16 when resuming a run :( +precision: bf16 + +# FSDP +# fsdp_config: +# sharding_strategy: FULL_SHARD +# min_params: 1e8 +# mixed_precision: default +# activation_checkpointing: false +# activation_cpu_offload: false +# verbose: true + +# Logging +callbacks: + speed_monitor: + window_size: 10 + lr_monitor: {} + # memory_monitor: {} + +loggers: + progress_bar: {} + # wandb: + # entity: stanford-mercury + # project: mosaic-gpt2 + # log_artifacts: true + + +# # # Checkpointing +# checkpoint_save_path: './{run_name}/checkpoints' +# checkpoint_save_interval: 4995ba +# num_checkpoints_to_keep: 1 # Important, this cleans up checkpoints saved to DISK + +# # WandB specific naming +# save_artifact_name: "{run_name}.pt" +# save_latest_artifact_name: "{run_name}.latest" + +# # Load from a WandB Artifact +# load_path: mosaicml-fsdp-flash-attention-demo-1.pt:ep0-ba30 +# load_object_store: +# wandb: +# entity: stanford-mercury +# project: mosaic-gpt2 + +# # Load from local filesystem +# load_path: ./gpt-125m/checkpoints/latest-rank{rank}.pt + +# # Load from remote object store +# load_path: gpt-125m/checkpoints/latest-rank{rank} +# load_object_store: +# s3: +# bucket: my-bucket +# prefix: my-folder diff --git a/llm/yamls/debugging/yesflash.yaml b/llm/yamls/debugging/yesflash.yaml new file mode 100644 index 000000000..f96ec0b3b --- /dev/null +++ b/llm/yamls/debugging/yesflash.yaml @@ -0,0 +1,120 @@ +#data_remote: &data_remote ./my-copy-c4 +data_remote: &data_remote s3://crfm-pubmed/pubmed-randomized +#data_local: &data_local ./my-copy-c4 +data_local: &data_local /tmp/mds-cache/pubmed-randomized +max_seq_len: &max_seq_len 2048 +tokenizer_name: &tokenizer_name gpt2 + +# Model +model: + hf_config: hf_configs/gpt-125m-ctx-1024.json + +# Dataloaders +train_loader: + dataset: + name: streaming_pubmed + remote: *data_remote + local: *data_local + split: train + tokenizer_name: *tokenizer_name + max_seq_len: *max_seq_len + group_method: concat + shuffle: true + drop_last: true + num_workers: 8 + pin_memory: true + prefetch_factor: 2 + persistent_workers: true + timeout: 0 + +eval_loader: + dataset: + name: streaming_pubmed + remote: *data_remote + local: *data_local + split: val + tokenizer_name: *tokenizer_name + max_seq_len: *max_seq_len + group_method: truncate + shuffle: false + drop_last: false + num_workers: 8 + pin_memory: true + prefetch_factor: 2 + persistent_workers: true + timeout: 0 + +# Optimization +scheduler: + name: cosine_with_warmup + t_warmup: 100ba + alpha_f: 0.1 + +optimizer: + name: decoupled_adamw + lr: 6.0e-4 + betas: + - 0.9 + - 0.95 + eps: 1.0e-08 + weight_decay: 0.0 + +max_duration: 5000ba +eval_interval: 5000ba +global_train_batch_size: 256 +grad_clip_norm: 1.0 + +# System +seed: 17 +device_train_microbatch_size: 1 # Hitting memory limits with 16 when resuming a run :( +precision: bf16 + +# FSDP +fsdp_config: + sharding_strategy: FULL_SHARD + min_params: 1e8 + mixed_precision: default + activation_checkpointing: false + activation_cpu_offload: false + verbose: true + +# Logging +callbacks: + speed_monitor: + window_size: 10 + lr_monitor: {} + # memory_monitor: {} + +loggers: + progress_bar: {} + # wandb: + # entity: stanford-mercury + # project: mosaic-gpt2 + # log_artifacts: true + + +# # # Checkpointing +# checkpoint_save_path: './{run_name}/checkpoints' +# checkpoint_save_interval: 4995ba +# num_checkpoints_to_keep: 1 # Important, this cleans up checkpoints saved to DISK + +# # WandB specific naming +# save_artifact_name: "{run_name}.pt" +# save_latest_artifact_name: "{run_name}.latest" + +# # Load from a WandB Artifact +# load_path: mosaicml-fsdp-flash-attention-demo-1.pt:ep0-ba30 +# load_object_store: +# wandb: +# entity: stanford-mercury +# project: mosaic-gpt2 + +# # Load from local filesystem +# load_path: ./gpt-125m/checkpoints/latest-rank{rank}.pt + +# # Load from remote object store +# load_path: gpt-125m/checkpoints/latest-rank{rank} +# load_object_store: +# s3: +# bucket: my-bucket +# prefix: my-folder From 40c55a402a72e6d23f10289876e132d13e5d24b5 Mon Sep 17 00:00:00 2001 From: Abhinav Venigalla Date: Tue, 18 Oct 2022 00:39:52 +0000 Subject: [PATCH 03/18] wip --- llm/convert_c4.py | 119 ---------- llm/hf_configs/noflash-gpt-125m-ctx-1024.json | 2 +- llm/llm/data.py | 208 ----------------- llm/llm/gpt.py | 21 +- llm/llm/gpt_old.py | 216 ------------------ llm/llm/hf_flash_gpt_2.py | 45 +--- llm/main.py | 50 ++-- llm/requirements.txt | 2 +- llm/yamls/debugging/noflash.yaml | 14 +- llm/yamls/debugging/yesflash.yaml | 16 +- 10 files changed, 59 insertions(+), 634 deletions(-) delete mode 100644 llm/convert_c4.py delete mode 100644 llm/llm/data.py delete mode 100644 llm/llm/gpt_old.py diff --git a/llm/convert_c4.py b/llm/convert_c4.py deleted file mode 100644 index 5b02e5160..000000000 --- a/llm/convert_c4.py +++ /dev/null @@ -1,119 +0,0 @@ -# Copyright 2022 MosaicML Composer authors -# SPDX-License-Identifier: Apache-2.0 - -"""C4 streaming dataset conversion scripts.""" - -import os -import random -from argparse import ArgumentParser, Namespace -from glob import glob -from itertools import islice -from typing import Any, Dict, Iterable, List, Tuple - -import datasets -import torch -from composer.datasets.streaming import StreamingDatasetWriter -from datasets import Dataset -from torch.utils.data import DataLoader, IterableDataset, get_worker_info - - -def parse_args() -> Namespace: - """Parse commandline arguments.""" - args = ArgumentParser() - args.add_argument('--out_root', type=str, required=True) - args.add_argument('--shard_size_limit', type=int, default=1 << 28) - args.add_argument('--tqdm', type=int, default=1) - args.add_argument('--splits', nargs='+', default=['train', 'val']) - - return args.parse_args() - - -def get(split: str) -> IterableDataset: - """Collect the samples for this dataset split. - - Args: - split (str): Split name. - - Returns: - An IterableDataset. - """ - - class ShardedC4(IterableDataset): - - def __init__(self): - self.dataset = datasets.load_dataset(path='c4', name='en', split=split, streaming=True) - - def num_shards(self): - return len(self.dataset._ex_iterable.kwargs['filepaths']) - - def __iter__(self): - worker_info = get_worker_info() - if worker_info: - num_workers = worker_info.num_workers - worker_id = worker_info.id - shards = self.dataset._ex_iterable.kwargs['filepaths'] - assert len(shards) % num_workers == 0 - self.dataset._ex_iterable.kwargs['filepaths'] = shards[worker_id::num_workers] - return iter(self.dataset) - - return ShardedC4() - - -def each(dataset: IterableDataset) -> Iterable[Dict[str, bytes]]: - """Generator over each dataset sample. - - Args: - samples (Dataset): A HF Dataset locally downloaded. - - Yields: - Sample dicts. - """ - num_workers = min(64, dataset.num_shards()) - batch_size = 512 - # If using multiple workers, configure each worker to prefetch as many samples as it can, up to the aggregate device batch size - # If not using workers, the torch DataLoader expects the default value for prefetch_factor, which non-intuitively must be 2. - prefetch_factor = max(1, 2 * batch_size // num_workers) if num_workers > 0 else 2 - - loader = DataLoader( - dataset=dataset, - sampler=None, - batch_size=batch_size, - num_workers=num_workers, - prefetch_factor=prefetch_factor, - ) - for batch in loader: - keys = list(batch.keys()) - current_bs = len(batch[keys[0]]) - for idx in range(current_bs): - yield {key: batch_values[idx].encode('utf-8') for key, batch_values in batch.items()} - - -def main(args: Namespace) -> None: - """Main: create C4 streaming dataset. - - Args: - args (Namespace): Commandline arguments. - """ - fields = ['text', 'timestamp', 'url'] - - for (split, split_new_name, expected_num_samples) in [ - ('train', 'train', 364868892), - ('validation', 'val', 364608), - ]: - # Only generate the splits requested - if split_new_name not in args.splits: - continue - - # Get dataset - dataset = get(split=split) - - # Write samples - with StreamingDatasetWriter(dirname=os.path.join(args.out_root, split_new_name), - fields=fields, - shard_size_limit=args.shard_size_limit, - compression=None) as out: - out.write_samples(samples=each(dataset), use_tqdm=bool(args.tqdm), total=expected_num_samples) - - -if __name__ == '__main__': - main(parse_args()) diff --git a/llm/hf_configs/noflash-gpt-125m-ctx-1024.json b/llm/hf_configs/noflash-gpt-125m-ctx-1024.json index 11e00e424..604207073 100644 --- a/llm/hf_configs/noflash-gpt-125m-ctx-1024.json +++ b/llm/hf_configs/noflash-gpt-125m-ctx-1024.json @@ -1,7 +1,7 @@ { "activation_function": "gelu", "architectures": [ - "GPT2FlashLMHeadModel" + "GPT2LMHeadModel" ], "attn_pdrop": 0.1, "bos_token_id": 50256, diff --git a/llm/llm/data.py b/llm/llm/data.py deleted file mode 100644 index ed2f2aabd..000000000 --- a/llm/llm/data.py +++ /dev/null @@ -1,208 +0,0 @@ -# Copyright 2022 MosaicML Composer authors -# SPDX-License-Identifier: Apache-2.0 - -""" -Build a StreamingC4 dataset and dataloader for training. -""" - -import os -import sys -from itertools import islice -from typing import Any, Dict, Iterator, Mapping, Optional - -import transformers -from composer.datasets.streaming import StreamingDataset -from torch.utils.data import DataLoader - - -class StreamingC4(StreamingDataset): - """ - Implementation of the C4 (Colossal Cleaned Common Crawl) dataset using StreamingDataset V1. - - Args: - remote (str): Remote directory (S3 or local filesystem) where dataset is stored. - local (str): Local filesystem directory where dataset is cached during operation. - split (str): The dataset split to use, either 'train' or 'val'. - shuffle (bool): Whether to shuffle the samples in this dataset. - tokenizer_name (str): The name of the HuggingFace tokenizer to use to tokenize samples. - max_seq_len (int): The max sequence length of each token sample. - group_method (str): How to group text samples into token samples. Supports 'truncate' or 'concat'. - max_retries (int): Number of download re-attempts before giving up. Default: 2. - timeout (float): How long to wait for shard to download before raising an exception. Default: 120 sec. - batch_size (Optional[int]): Hint batch_size that will be used on each device's DataLoader. Default: ``None``. - """ - - def __init__(self, - remote: str, - local: str, - split: str, - shuffle: bool, - tokenizer_name: str, - max_seq_len: int, - group_method: str = 'truncate', - max_retries: int = 2, - timeout: float = 120, - batch_size: Optional[int] = None): - # Validation - if split not in ['train', 'val']: - raise ValueError(f"split='{split}' must be one of ['train', 'val'].") - if group_method not in ['truncate', 'concat']: - raise ValueError(f"group_method='{group_method}' must be one of ['truncate', 'concat'].") - - # Build StreamingDataset - decoders = { - 'text': self._decode, - 'timestamp': self._decode, - 'url': self._decode, - } - super().__init__(remote=os.path.join(remote, split), - local=os.path.join(local, split), - shuffle=shuffle, - decoders=decoders, - max_retries=max_retries, - timeout=timeout, - batch_size=batch_size) - self.tokenizer_name = tokenizer_name - self.max_seq_len = max_seq_len - self.group_method = group_method - - # Build tokenizer - self.tokenizer = transformers.AutoTokenizer.from_pretrained(self.tokenizer_name) - if self.tokenizer.pad_token is None: - # Some tokenizers (e.g. GPT2 tokenizer) have no padding token which causes bugs - self.tokenizer.pad_token = self.tokenizer.eos_token - # suppress warnings when using group_method='concat' and no truncation - self.tokenizer.model_max_length = int(1e30) - - # How to decode binary data from .mds files to python strings - def _decode(self, data: bytes) -> str: - return data.decode('utf-8') - - # How to tokenize a text sample to a token sample - def _tokenize(self, text_sample): - if self.group_method == 'truncate': - truncation = True - padding = 'max_length' - max_length = self.max_seq_len - elif self.group_method == 'concat': - truncation = False - padding = False - max_length = None - else: - raise ValueError(f"Got unknown group_method='{self.group_method}'.") - return self.tokenizer(text_sample['text'], truncation=truncation, padding=padding, max_length=max_length) - - # How to process a sample - def __getitem__(self, idx: int) -> Dict[str, Any]: - text_sample = super().__getitem__(idx) - token_sample = self._tokenize(text_sample) - return token_sample - - # Define iterable over samples - # Usually this can be left alone and inherited directly from super() class StreamingDataset, but concatenating samples is custom behavior. - # If group_method=='truncate', we simply return the token sample. - # If group_method=='concat', then we keep fetching token samples until we fill up max_seq_len. - def __iter__(self) -> Iterator[Any]: - if self.group_method == 'truncate': - iterator = super().__iter__() - yield from iterator - - elif self.group_method == 'concat': - buffer = {} - while True: - iterator = super().__iter__() - for sample in iterator: - - for k, v in sample.items(): - buffer[k] = buffer.get(k, []) + v + [self.tokenizer.eos_token_id] - if len(buffer['input_ids']) >= self.max_seq_len: - concat_sample = {} - for k, v in buffer.items(): - concat_sample[k] = v[:self.max_seq_len] - buffer[k] = v[self.max_seq_len:] - yield concat_sample - else: - raise ValueError(f"Got unknown group_method='{self.group_method}'.") - - # Define length - # Usually this can be left alone and inherited directly from super() class StreamingDataset, but concatenating samples is custom behavior. - # If group_method=='truncate', we simply return the # samples. - # If group_method=='concat', we repeat forever, and we don't have a defined length. - def __len__(self) -> int: - if self.group_method == 'truncate': - return super().__len__() - elif self.group_method == 'concat': - return None - else: - raise ValueError(f"Got unknown group_method='{self.group_method}'.") - - -def build_dataloader(cfg: Mapping[str, Any], device_batch_size: int): - - if cfg.dataset.name == 'streaming_c4': - dataset = StreamingC4(split=cfg.dataset.split, - remote=cfg.dataset.remote, - local=cfg.dataset.local, - shuffle=cfg.dataset.shuffle, - tokenizer_name=cfg.dataset.tokenizer_name, - max_seq_len=cfg.dataset.max_seq_len, - group_method=cfg.dataset.group_method, - batch_size=device_batch_size) - else: - raise ValueError(f'Not sure how to build dataset={cfg.dataset.name}') - - collate_fn = transformers.DataCollatorForLanguageModeling( - tokenizer=dataset.tokenizer, mlm=False) - - return DataLoader( - dataset, - collate_fn=collate_fn, - batch_size=device_batch_size, - drop_last=cfg.drop_last, - num_workers=cfg.num_workers, - pin_memory=cfg.pin_memory, - prefetch_factor=cfg.prefetch_factor, - persistent_workers=cfg.persistent_workers, - timeout=cfg.timeout, - ) - -# Helpful to test if your dataloader is working locally -# Run `python data.py [remote] [local, optional]` and verify that batches are printed out -if __name__ == '__main__': - remote = sys.argv[1] - if len(sys.argv) > 2: - local = sys.argv[2] - else: - local = remote - print (f'Reading val split from {remote} -> {local}') - - batch_size = 2 - dataset = StreamingC4(split='val', - remote=remote, - local=local, - shuffle=False, - tokenizer_name='gpt2', - max_seq_len=32, - group_method='concat', - batch_size=batch_size) - - collate_fn = transformers.DataCollatorForLanguageModeling( - tokenizer=dataset.tokenizer, mlm=False) - - loader = DataLoader( - dataset, - collate_fn=collate_fn, - batch_size=batch_size, - drop_last=False, - num_workers=4, - ) - - for batch_ix, batch in enumerate(islice(loader, 5)): - print('\n') - print ('#'*20, f'Batch {batch_ix}', '#'*20) - for k, v in batch.items(): - print (k, v.shape, v.dtype) - for sample_ix, token_sample in enumerate(batch['input_ids']): - print ('-'*20, f' Sample {sample_ix} ', '-'*20) - print (dataset.tokenizer.decode(token_sample)) - diff --git a/llm/llm/gpt.py b/llm/llm/gpt.py index 7850d1825..0ea07ea83 100644 --- a/llm/llm/gpt.py +++ b/llm/llm/gpt.py @@ -19,9 +19,23 @@ from .hf_flash_gpt_2 import GPT2FlashLMHeadModel +def prepare_hf_gpt2_model_for_fsdp(model): + # Special Case! When using the LMHeadModel, the weights of the self.lm_head and self.transformer.wte are tied. + # This tying occurs inside the `self.post_init()` function call above. + # This is a hurdle for FSDP because they need to be in the same FSDP block + # These lines ensures that both modules stay together in the top-most block + model.transformer._fsdp_wrap = False + model.transformer.wte._fsdp_wrap = False + model.lm_head._fsdp_wrap = False + + # FSDP Wrap and Activation Checkpoint every GPT2Block + for block in model.transformer.h: + block._fsdp_wrap = True + block._activation_checkpointing = True + class ComposerGPT(ComposerModel): - def __init__(self, cfg, device='meta'): + def __init__(self, cfg): super().__init__() # load GPT2 config from standard HF model config json hf_config = GPT2Config.from_json_file(cfg.hf_config) @@ -31,9 +45,12 @@ def __init__(self, cfg, device='meta'): self.model = GPT2LMHeadModel(hf_config) elif model_class == 'GPT2FlashLMHeadModel': self.model = GPT2FlashLMHeadModel(hf_config) - self.model.to(device) else: raise ValueError(f'Not sure how to build model_class={model_class}') + + # Tag layers to make the model ready for FSDP + prepare_hf_gpt2_model_for_fsdp(self.model) + self.train_metrics = { 'LanguageCrossEntropy': LanguageCrossEntropy(hf_config.vocab_size), 'Perplexity': Perplexity(), diff --git a/llm/llm/gpt_old.py b/llm/llm/gpt_old.py deleted file mode 100644 index af5c37638..000000000 --- a/llm/llm/gpt_old.py +++ /dev/null @@ -1,216 +0,0 @@ -# Copyright 2022 MosaicML Composer authors -# SPDX-License-Identifier: Apache-2.0 - -""" -A simple, flexible implementation of a GPT model. -Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py -""" - -import math -from typing import Any, Mapping - -import torch -import torch.nn as nn -import torch.nn.functional as F -from composer.metrics.nlp import LanguageCrossEntropy, Perplexity -from composer.models.base import ComposerModel -from flash_attn.flash_attention import FlashMHA - - -class TorchCausalAttention(nn.Module): - def __init__(self, cfg: Mapping[str, Any], device: str = None): - super().__init__() - self.mha = nn.MultiheadAttention( - embed_dim=cfg.d_model, - num_heads=cfg.n_heads, - dropout=cfg.attn_pdrop, - bias=True, - batch_first=True, - device=device, - ) - self.register_buffer( - "mask", torch.tril(torch.ones(cfg.max_seq_len, cfg.max_seq_len))) - self.mha.out_proj._is_residual = True - - def forward(self, x, key_padding_mask): - return self.mha(x, x, x, attn_mask=self.mask, need_weights=False) - - -class FlashCausalAttention(nn.Module): - def __init__(self, cfg: Mapping[str, Any], device: str = None): - super().__init__() - self.mha = FlashMHA( - embed_dim=cfg.d_model, - num_heads=cfg.n_heads, - attention_dropout=cfg.attn_pdrop, - bias=True, - batch_first=True, - causal=True, - device=device, - ) - self.mha.out_proj._is_residual = True - - def forward(self, x, key_padding_mask): - return self.mha(x, - key_padding_mask=key_padding_mask, - need_weights=False) - - -class GPTMLP(nn.Module): - def __init__(self, cfg: Mapping[str, Any], device: str = None): - super().__init__() - self.mlp_up = nn.Linear(cfg.d_model, - cfg.mlp_ratio * cfg.d_model, - device=device) - self.mlp_act = nn.GELU(approximate='none') - self.mlp_down = nn.Linear(cfg.mlp_ratio * cfg.d_model, - cfg.d_model, - device=device) - self.mlp_down._is_residual = True - - def forward(self, x): - return self.mlp_down(self.mlp_act(self.mlp_up(x))) - - -class GPTBlock(nn.Module): - def __init__(self, cfg: Mapping[str, Any], device: str = None): - super().__init__() - self.ln_1 = nn.LayerNorm(cfg.d_model, device=device) - if cfg.attn_impl == 'torch': - self.causal_attn = TorchCausalAttention(cfg, device) - elif cfg.attn_impl == 'flash': - self.causal_attn = FlashCausalAttention(cfg, device) - else: - raise ValueError(f'Unknown attn_impl={cfg.attn_impl}') - self.ln_2 = nn.LayerNorm(cfg.d_model, device=device) - self.mlp = GPTMLP(cfg, device=device) - self.resid_attn_dropout = nn.Dropout(cfg.resid_pdrop) - self.resid_mlp_dropout = nn.Dropout(cfg.resid_pdrop) - - def forward(self, - x: torch.Tensor, - key_padding_mask: torch.ByteTensor = None) -> torch.Tensor: - a = self.ln_1(x) - b, _ = self.causal_attn(a, key_padding_mask) - x = x + self.resid_attn_dropout(b) - m = self.ln_2(x) - n = self.mlp(m) - x = x + self.resid_mlp_dropout(n) - return x - - -class GPT(nn.Module): - def __init__(self, cfg: Mapping[str, Any], device: str = 'meta'): - super().__init__() - self.cfg = cfg - self.transformer = nn.ModuleDict( - dict( - wte=nn.Embedding(cfg.vocab_size, cfg.d_model, device=device), - wpe=nn.Embedding(cfg.max_seq_len, cfg.d_model, device=device), - emb_drop=nn.Dropout(cfg.emb_pdrop), - blocks=nn.ModuleList([ - GPTBlock(cfg, device=device) for _ in range(cfg.n_layers) - ]), - ln_f=nn.LayerNorm(cfg.d_model, device=device), - )) - self.lm_head = nn.Linear(cfg.d_model, - cfg.vocab_size, - bias=False, - device=device) - - if device != 'meta': - self.apply(self.param_init_fn) - - def forward(self, - input_ids: torch.LongTensor, - key_padding_mask: torch.ByteTensor = None): - _, S = input_ids.size() - assert ( - S <= self.cfg.max_seq_len - ), f"Cannot forward input with seq_len={S}, this model only supports seq_len<={self.cfg.max_seq_len}" - pos = torch.arange(0, S, dtype=torch.long, - device=input_ids.device).unsqueeze(0) - - tok_emb = self.transformer.wte(input_ids) - pos_emb = self.transformer.wpe(pos) - x = self.transformer.emb_drop(tok_emb + pos_emb) - for block in self.transformer.blocks: - x = block(x, key_padding_mask) - x = self.transformer.ln_f(x) - logits = self.lm_head(x) - return logits - - # Param Initialization, needed for device='meta' fast initialization - def param_init_fn(self, module): - # Linear - if isinstance(module, nn.Linear): - torch.nn.init.normal_(module.weight, - mean=0.0, - std=self.cfg.init_std) - if module.bias is not None: - torch.nn.init.zeros_(module.bias) - - if getattr(module, '_is_residual', False): - module.weight.data.normal_( - mean=0.0, - std=(self.cfg.init_std / math.sqrt(2 * self.cfg.n_layers))) - - # Embedding - if isinstance(module, nn.Embedding): - torch.nn.init.normal_(module.weight, - mean=0.0, - std=self.cfg.init_std) - - # LayerNorm - if isinstance(module, nn.LayerNorm): - torch.nn.init.zeros_(module.bias) - torch.nn.init.ones_(module.weight) - - # FSDP Wrap function - def fsdp_wrap_fn(self, module): - return isinstance(module, GPTBlock) - - # Activation Checkpointing - def activation_checkpointing_fn(self, module): - return isinstance(module, GPTBlock) - - -class ComposerGPT(ComposerModel): - - def __init__(self, cfg, device='meta'): - super().__init__() - self.model = GPT(cfg, device=device) - self.train_metrics = { - 'LanguageCrossEntropy': LanguageCrossEntropy(cfg.vocab_size), - 'Perplexity': Perplexity(), - } - self.eval_metrics = { - 'LanguageCrossEntropy': LanguageCrossEntropy(cfg.vocab_size), - 'Perplexity': Perplexity(), - } - - def get_targets(self, batch): - targets = torch.roll(batch["labels"], shifts=-1) - targets[:, -1] = -100 - return targets - - def forward(self, batch): - return self.model(batch['input_ids'], - key_padding_mask=batch['attention_mask'].bool()) - - def eval_forward(self, batch, outputs=None): - return outputs if outputs is not None else self.forward(batch) - - def loss(self, outputs, batch): - targets = self.get_targets(batch) - return F.cross_entropy(outputs.view(-1, outputs.size(-1)), - targets.view(-1), - ignore_index=-100) - - def get_metrics(self, is_train=False): - return self.train_metrics if is_train else self.eval_metrics - - def update_metric(self, batch, outputs, metric): - outputs = outputs.view(-1, outputs.size(-1)) - targets = self.get_targets(batch).view(-1) - metric.update(outputs, targets) diff --git a/llm/llm/hf_flash_gpt_2.py b/llm/llm/hf_flash_gpt_2.py index c67ac6d8b..4d4209aa0 100644 --- a/llm/llm/hf_flash_gpt_2.py +++ b/llm/llm/hf_flash_gpt_2.py @@ -41,35 +41,21 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None): raise ValueError('GPT2FlashAttention only supports scale_attn_weights=True.') def _attn(self, query, key, value, attention_mask=None, head_mask=None): + if head_mask is not None: + raise ValueError('GPT2FlashAttention._attn does not support "head_mask"') # rearrange to flash attention form key = rearrange(key, 'b h s d -> b s h d') value = rearrange(value, 'b h s d -> b s h d') query = rearrange(query, 'b h s d -> b s h d') - # stack qkv = torch.stack([query,key,value], dim=2) - #qkv = torch.tensor(qkv,dtype=torch.bfloat16) assert qkv.dtype in [torch.float16, torch.bfloat16] - # flash attention logic - batch_size = qkv.shape[0] - seqlen = qkv.shape[1] - num_heads = qkv.shape[3] - dk = qkv.shape[4] - dk_per_head = int(dk)/int(num_heads) - qkv = rearrange(qkv, 'b s ... -> (b s) ...') - max_s = seqlen - cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, device=qkv.device) - attn_pdrop = 0.1 - softmax_scale = 1/float(math.sqrt(dk)) - output = flash_attn_unpadded_qkvpacked_func( - qkv, cu_seqlens, max_s, attn_pdrop, - softmax_scale=softmax_scale, causal=True - ) - output = rearrange(output, '(b s) ... -> b s ...', b=batch_size) + output, attn_weights = self.inner_attn(qkv, key_padding_mask=attention_mask, + need_weights=False, causal=True) + output = rearrange(output, 'b s h d -> b h s d') - #output = torch.tensor(output, dtype=torch.float32) return output, None @@ -125,24 +111,3 @@ def __init__(self, config): # Initialize weights and apply final processing self.post_init() - - # Special Case! When using the LMHeadModel, the weights of the self.lm_head and self.transformer.wte are tied. - # This tying occurs inside the `self.post_init()` function call above. - # This is a hurdle for FSDP because they need to be in the same FSDP block - # These lines ensures that both modules stay together in the top-most block - self.transformer._fsdp_wrap = False - self.transformer.wte._fsdp_wrap = False - self.lm_head._fsdp_wrap = False - - # Meta tensor param init fn - def param_init_fn(self, module): - if isinstance(module, GPT2LMHeadModel): - module.post_init() - - # FSDP Wrap function - def fsdp_wrap_fn(self, module): - return isinstance(module, GPT2Block) - - # Activation Checkpointing - def activation_checkpointing_fn(self, module): - return isinstance(module, GPT2Block) diff --git a/llm/main.py b/llm/main.py index 0bca2fccc..e00fb2fab 100644 --- a/llm/main.py +++ b/llm/main.py @@ -6,7 +6,7 @@ from composer import Trainer from composer.callbacks import LRMonitor, MemoryMonitor, SpeedMonitor -from composer.loggers import ObjectStoreLogger, ProgressBarLogger, WandBLogger +from composer.loggers import ProgressBarLogger, WandBLogger from composer.optim import DecoupledAdamW from composer.optim.scheduler import (ConstantWithWarmupScheduler, CosineAnnealingWithWarmupScheduler) @@ -26,21 +26,9 @@ def build_logger(name, kwargs): ) elif name == 'wandb': return WandBLogger(**kwargs) - elif name == 's3': - object_store_logger = ObjectStoreLogger( - object_store_cls=S3ObjectStore, - object_store_kwargs=kwargs, - ) - return object_store_logger else: raise ValueError(f'Not sure how to build logger: {name}') -def build_object_store(name, kwargs): - if name == 's3': - return S3ObjectStore(**kwargs) - else: - raise ValueError(f'Not sure how to build object store: {name}') - def build_callback(name, kwargs): if name == 'lr_monitor': return LRMonitor() @@ -99,8 +87,7 @@ def main(cfg): # Build Model # For fast initialization, use `meta` device print('Initializing model...') - device = 'meta' if fsdp_config else 'cuda' - model = ComposerGPT(cfg=cfg.model, device=device) + model = ComposerGPT(cfg=cfg.model) n_params = sum(p.numel() for p in model.parameters()) print(f'{n_params=:.2e}') @@ -131,15 +118,15 @@ def main(cfg): # Callbacks callbacks = [build_callback(name, callback_cfg) for name, callback_cfg in cfg.callbacks.items()] - # (Optional) Load object store - load_object_store = cfg.get('load_object_store', None) - if load_object_store is not None: - name = list(load_object_store.keys())[0] - kwargs = load_object_store[name] - if name in ['s3']: - load_object_store = build_object_store(name, kwargs) - elif name in ['wandb']: - load_object_store = build_logger(name, kwargs) + # # (Optional) Load object store + # load_object_store = cfg.get('load_object_store', None) + # if load_object_store is not None: + # name = list(load_object_store.keys())[0] + # kwargs = load_object_store[name] + # if name in ['s3']: + # load_object_store = build_object_store(name, kwargs) + # elif name in ['wandb']: + # load_object_store = build_logger(name, kwargs) # Build the Trainer trainer = Trainer( @@ -158,14 +145,13 @@ def main(cfg): grad_clip_norm=cfg.grad_clip_norm, grad_accum=device_train_grad_accum, fsdp_config=fsdp_config, - checkpoint_save_path=cfg.get('checkpoint_save_path', None), - checkpoint_save_interval=cfg.get('checkpoint_save_interval', '1000ba'), - num_checkpoints_to_keep=cfg.get('num_checkpoints_to_keep', -1), - save_artifact_name=cfg.get('save_artifact_name', '{run_name}/checkpoints/ep{epoch}-ba{batch}-rank{rank}.pt'), - save_latest_artifact_name=cfg.get('save_latest_artifact_name', '{run_name}/checkpoints/latest-rank{rank}'), - load_path=cfg.get('load_path', None), - load_object_store=load_object_store, - load_weights_only=cfg.get('load_weights_only', False), + # checkpoint_save_path=cfg.get('checkpoint_save_path', None), + # checkpoint_save_interval=cfg.get('checkpoint_save_interval', '1000ba'), + # num_checkpoints_to_keep=cfg.get('num_checkpoints_to_keep', -1), + # load_path=cfg.get('load_path', None), + # load_object_store=load_object_store, + # load_weights_only=cfg.get('load_weights_only', False), + eval_subset_num_batches=100, ) print("Logging config...") diff --git a/llm/requirements.txt b/llm/requirements.txt index a0833e46b..d9d73e972 100644 --- a/llm/requirements.txt +++ b/llm/requirements.txt @@ -1,4 +1,4 @@ -mosaicml[streaming] @ git+https://github.com/mosaicml/composer@fsdp-alpha +mosaicml[streaming] @ git+https://github.com/mosaicml/composer@feb856fd7ac05659cff3d1aa894155310ceb322a flash_attn @ git+https://github.com/HazyResearch/flash-attention.git@main transformers==4.21.3 datasets==2.4.0 diff --git a/llm/yamls/debugging/noflash.yaml b/llm/yamls/debugging/noflash.yaml index f4a36a30a..0b8b2f5ac 100644 --- a/llm/yamls/debugging/noflash.yaml +++ b/llm/yamls/debugging/noflash.yaml @@ -2,7 +2,7 @@ data_remote: &data_remote s3://crfm-pubmed/pubmed-randomized #data_local: &data_local ./my-copy-c4 data_local: &data_local /tmp/mds-cache/pubmed-randomized -max_seq_len: &max_seq_len 2048 +max_seq_len: &max_seq_len 1024 tokenizer_name: &tokenizer_name gpt2 # Model @@ -35,7 +35,7 @@ eval_loader: split: val tokenizer_name: *tokenizer_name max_seq_len: *max_seq_len - group_method: truncate + group_method: concat shuffle: false drop_last: false num_workers: 8 @@ -60,20 +60,20 @@ optimizer: weight_decay: 0.0 max_duration: 5000ba -eval_interval: 5000ba +eval_interval: 20ba global_train_batch_size: 256 grad_clip_norm: 1.0 # System seed: 17 -device_train_microbatch_size: 1 # Hitting memory limits with 16 when resuming a run :( +device_train_microbatch_size: 16 # Hitting memory limits with 16 when resuming a run :( precision: bf16 # FSDP # fsdp_config: # sharding_strategy: FULL_SHARD # min_params: 1e8 -# mixed_precision: default +# mixed_precision: full # activation_checkpointing: false # activation_cpu_offload: false # verbose: true @@ -87,9 +87,9 @@ callbacks: loggers: progress_bar: {} - # wandb: + wandb: # entity: stanford-mercury - # project: mosaic-gpt2 + project: composer-gpt2-debug # log_artifacts: true diff --git a/llm/yamls/debugging/yesflash.yaml b/llm/yamls/debugging/yesflash.yaml index f96ec0b3b..031c4b769 100644 --- a/llm/yamls/debugging/yesflash.yaml +++ b/llm/yamls/debugging/yesflash.yaml @@ -2,7 +2,7 @@ data_remote: &data_remote s3://crfm-pubmed/pubmed-randomized #data_local: &data_local ./my-copy-c4 data_local: &data_local /tmp/mds-cache/pubmed-randomized -max_seq_len: &max_seq_len 2048 +max_seq_len: &max_seq_len 1024 tokenizer_name: &tokenizer_name gpt2 # Model @@ -35,7 +35,7 @@ eval_loader: split: val tokenizer_name: *tokenizer_name max_seq_len: *max_seq_len - group_method: truncate + group_method: concat shuffle: false drop_last: false num_workers: 8 @@ -60,20 +60,20 @@ optimizer: weight_decay: 0.0 max_duration: 5000ba -eval_interval: 5000ba +eval_interval: 20ba global_train_batch_size: 256 grad_clip_norm: 1.0 # System seed: 17 -device_train_microbatch_size: 1 # Hitting memory limits with 16 when resuming a run :( +device_train_microbatch_size: 16 # Hitting memory limits with 16 when resuming a run :( precision: bf16 # FSDP fsdp_config: - sharding_strategy: FULL_SHARD + sharding_strategy: SHARD_GRAD_OP min_params: 1e8 - mixed_precision: default + mixed_precision: full activation_checkpointing: false activation_cpu_offload: false verbose: true @@ -87,9 +87,9 @@ callbacks: loggers: progress_bar: {} - # wandb: + wandb: # entity: stanford-mercury - # project: mosaic-gpt2 + project: composer-gpt2-debug # log_artifacts: true From 0326acdcf4ace63bd44ee2fef68ea34948f7f8df Mon Sep 17 00:00:00 2001 From: Abhinav Venigalla Date: Tue, 18 Oct 2022 01:10:20 +0000 Subject: [PATCH 04/18] upgrade trainer args --- llm/main.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/llm/main.py b/llm/main.py index e00fb2fab..493834e0a 100644 --- a/llm/main.py +++ b/llm/main.py @@ -145,9 +145,11 @@ def main(cfg): grad_clip_norm=cfg.grad_clip_norm, grad_accum=device_train_grad_accum, fsdp_config=fsdp_config, - # checkpoint_save_path=cfg.get('checkpoint_save_path', None), - # checkpoint_save_interval=cfg.get('checkpoint_save_interval', '1000ba'), - # num_checkpoints_to_keep=cfg.get('num_checkpoints_to_keep', -1), + save_folder=cfg.get('save_folder', None), + save_filename=cfg.get('save_filename', None), + save_latest_filename=cfg.get('save_latest_filename', None), + save_interval=cfg.get('checkpoint_save_interval', '1000ba'), + save_num_checkpoints_to_keep=cfg.get('save_num_checkpoints_to_keep', -1), # load_path=cfg.get('load_path', None), # load_object_store=load_object_store, # load_weights_only=cfg.get('load_weights_only', False), From 5b693b9a618c1791ad68fbd805933786c846da0d Mon Sep 17 00:00:00 2001 From: Abhinav Venigalla Date: Tue, 18 Oct 2022 02:30:56 +0000 Subject: [PATCH 05/18] wip --- llm/main.py | 15 +--- llm/yamls/debugging/noflash.yaml | 120 ------------------------------ llm/yamls/debugging/yesflash.yaml | 120 ------------------------------ 3 files changed, 2 insertions(+), 253 deletions(-) delete mode 100644 llm/yamls/debugging/noflash.yaml delete mode 100644 llm/yamls/debugging/yesflash.yaml diff --git a/llm/main.py b/llm/main.py index 493834e0a..dad64efc4 100644 --- a/llm/main.py +++ b/llm/main.py @@ -118,16 +118,6 @@ def main(cfg): # Callbacks callbacks = [build_callback(name, callback_cfg) for name, callback_cfg in cfg.callbacks.items()] - # # (Optional) Load object store - # load_object_store = cfg.get('load_object_store', None) - # if load_object_store is not None: - # name = list(load_object_store.keys())[0] - # kwargs = load_object_store[name] - # if name in ['s3']: - # load_object_store = build_object_store(name, kwargs) - # elif name in ['wandb']: - # load_object_store = build_logger(name, kwargs) - # Build the Trainer trainer = Trainer( run_name=cfg.get('run_name', os.environ['COMPOSER_RUN_NAME']), @@ -148,12 +138,11 @@ def main(cfg): save_folder=cfg.get('save_folder', None), save_filename=cfg.get('save_filename', None), save_latest_filename=cfg.get('save_latest_filename', None), - save_interval=cfg.get('checkpoint_save_interval', '1000ba'), + save_interval=cfg.get('save_interval', '1000ba'), save_num_checkpoints_to_keep=cfg.get('save_num_checkpoints_to_keep', -1), # load_path=cfg.get('load_path', None), - # load_object_store=load_object_store, # load_weights_only=cfg.get('load_weights_only', False), - eval_subset_num_batches=100, + eval_subset_num_batches=cfg.get('eval_subset_num_batches', 5000), ) print("Logging config...") diff --git a/llm/yamls/debugging/noflash.yaml b/llm/yamls/debugging/noflash.yaml deleted file mode 100644 index 0b8b2f5ac..000000000 --- a/llm/yamls/debugging/noflash.yaml +++ /dev/null @@ -1,120 +0,0 @@ -#data_remote: &data_remote ./my-copy-c4 -data_remote: &data_remote s3://crfm-pubmed/pubmed-randomized -#data_local: &data_local ./my-copy-c4 -data_local: &data_local /tmp/mds-cache/pubmed-randomized -max_seq_len: &max_seq_len 1024 -tokenizer_name: &tokenizer_name gpt2 - -# Model -model: - hf_config: hf_configs/noflash-gpt-125m-ctx-1024.json - -# Dataloaders -train_loader: - dataset: - name: streaming_pubmed - remote: *data_remote - local: *data_local - split: train - tokenizer_name: *tokenizer_name - max_seq_len: *max_seq_len - group_method: concat - shuffle: true - drop_last: true - num_workers: 8 - pin_memory: true - prefetch_factor: 2 - persistent_workers: true - timeout: 0 - -eval_loader: - dataset: - name: streaming_pubmed - remote: *data_remote - local: *data_local - split: val - tokenizer_name: *tokenizer_name - max_seq_len: *max_seq_len - group_method: concat - shuffle: false - drop_last: false - num_workers: 8 - pin_memory: true - prefetch_factor: 2 - persistent_workers: true - timeout: 0 - -# Optimization -scheduler: - name: cosine_with_warmup - t_warmup: 100ba - alpha_f: 0.1 - -optimizer: - name: decoupled_adamw - lr: 6.0e-4 - betas: - - 0.9 - - 0.95 - eps: 1.0e-08 - weight_decay: 0.0 - -max_duration: 5000ba -eval_interval: 20ba -global_train_batch_size: 256 -grad_clip_norm: 1.0 - -# System -seed: 17 -device_train_microbatch_size: 16 # Hitting memory limits with 16 when resuming a run :( -precision: bf16 - -# FSDP -# fsdp_config: -# sharding_strategy: FULL_SHARD -# min_params: 1e8 -# mixed_precision: full -# activation_checkpointing: false -# activation_cpu_offload: false -# verbose: true - -# Logging -callbacks: - speed_monitor: - window_size: 10 - lr_monitor: {} - # memory_monitor: {} - -loggers: - progress_bar: {} - wandb: - # entity: stanford-mercury - project: composer-gpt2-debug - # log_artifacts: true - - -# # # Checkpointing -# checkpoint_save_path: './{run_name}/checkpoints' -# checkpoint_save_interval: 4995ba -# num_checkpoints_to_keep: 1 # Important, this cleans up checkpoints saved to DISK - -# # WandB specific naming -# save_artifact_name: "{run_name}.pt" -# save_latest_artifact_name: "{run_name}.latest" - -# # Load from a WandB Artifact -# load_path: mosaicml-fsdp-flash-attention-demo-1.pt:ep0-ba30 -# load_object_store: -# wandb: -# entity: stanford-mercury -# project: mosaic-gpt2 - -# # Load from local filesystem -# load_path: ./gpt-125m/checkpoints/latest-rank{rank}.pt - -# # Load from remote object store -# load_path: gpt-125m/checkpoints/latest-rank{rank} -# load_object_store: -# s3: -# bucket: my-bucket -# prefix: my-folder diff --git a/llm/yamls/debugging/yesflash.yaml b/llm/yamls/debugging/yesflash.yaml deleted file mode 100644 index 031c4b769..000000000 --- a/llm/yamls/debugging/yesflash.yaml +++ /dev/null @@ -1,120 +0,0 @@ -#data_remote: &data_remote ./my-copy-c4 -data_remote: &data_remote s3://crfm-pubmed/pubmed-randomized -#data_local: &data_local ./my-copy-c4 -data_local: &data_local /tmp/mds-cache/pubmed-randomized -max_seq_len: &max_seq_len 1024 -tokenizer_name: &tokenizer_name gpt2 - -# Model -model: - hf_config: hf_configs/gpt-125m-ctx-1024.json - -# Dataloaders -train_loader: - dataset: - name: streaming_pubmed - remote: *data_remote - local: *data_local - split: train - tokenizer_name: *tokenizer_name - max_seq_len: *max_seq_len - group_method: concat - shuffle: true - drop_last: true - num_workers: 8 - pin_memory: true - prefetch_factor: 2 - persistent_workers: true - timeout: 0 - -eval_loader: - dataset: - name: streaming_pubmed - remote: *data_remote - local: *data_local - split: val - tokenizer_name: *tokenizer_name - max_seq_len: *max_seq_len - group_method: concat - shuffle: false - drop_last: false - num_workers: 8 - pin_memory: true - prefetch_factor: 2 - persistent_workers: true - timeout: 0 - -# Optimization -scheduler: - name: cosine_with_warmup - t_warmup: 100ba - alpha_f: 0.1 - -optimizer: - name: decoupled_adamw - lr: 6.0e-4 - betas: - - 0.9 - - 0.95 - eps: 1.0e-08 - weight_decay: 0.0 - -max_duration: 5000ba -eval_interval: 20ba -global_train_batch_size: 256 -grad_clip_norm: 1.0 - -# System -seed: 17 -device_train_microbatch_size: 16 # Hitting memory limits with 16 when resuming a run :( -precision: bf16 - -# FSDP -fsdp_config: - sharding_strategy: SHARD_GRAD_OP - min_params: 1e8 - mixed_precision: full - activation_checkpointing: false - activation_cpu_offload: false - verbose: true - -# Logging -callbacks: - speed_monitor: - window_size: 10 - lr_monitor: {} - # memory_monitor: {} - -loggers: - progress_bar: {} - wandb: - # entity: stanford-mercury - project: composer-gpt2-debug - # log_artifacts: true - - -# # # Checkpointing -# checkpoint_save_path: './{run_name}/checkpoints' -# checkpoint_save_interval: 4995ba -# num_checkpoints_to_keep: 1 # Important, this cleans up checkpoints saved to DISK - -# # WandB specific naming -# save_artifact_name: "{run_name}.pt" -# save_latest_artifact_name: "{run_name}.latest" - -# # Load from a WandB Artifact -# load_path: mosaicml-fsdp-flash-attention-demo-1.pt:ep0-ba30 -# load_object_store: -# wandb: -# entity: stanford-mercury -# project: mosaic-gpt2 - -# # Load from local filesystem -# load_path: ./gpt-125m/checkpoints/latest-rank{rank}.pt - -# # Load from remote object store -# load_path: gpt-125m/checkpoints/latest-rank{rank} -# load_object_store: -# s3: -# bucket: my-bucket -# prefix: my-folder From ad1c028a56a1d782d702bcec9414b5a3ad2384e1 Mon Sep 17 00:00:00 2001 From: Abhinav Venigalla Date: Wed, 19 Oct 2022 14:25:33 -0700 Subject: [PATCH 06/18] no eos token, concat directly --- llm/llm/data_pubmed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llm/llm/data_pubmed.py b/llm/llm/data_pubmed.py index 4d7ade35b..dcb6a3200 100644 --- a/llm/llm/data_pubmed.py +++ b/llm/llm/data_pubmed.py @@ -85,7 +85,7 @@ def __iter__(self) -> Iterator[Any]: iterator = super().__iter__() for sample in iterator: for k, v in sample.items(): - buffer[k] = buffer.get(k, []) + v + [self.tokenizer.eos_token_id] + buffer[k] = buffer.get(k, []) + v while len(buffer['input_ids']) >= self.max_seq_len: concat_sample = {} for k, v in buffer.items(): From e57cb6c59adb33894494c503b6106da164bf90ec Mon Sep 17 00:00:00 2001 From: Abhinav Venigalla Date: Tue, 1 Nov 2022 06:23:54 +0000 Subject: [PATCH 07/18] upgrade requirements.txt --- llm/requirements.txt | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/llm/requirements.txt b/llm/requirements.txt index d9d73e972..0b26193eb 100644 --- a/llm/requirements.txt +++ b/llm/requirements.txt @@ -1,4 +1,7 @@ -mosaicml[streaming] @ git+https://github.com/mosaicml/composer@feb856fd7ac05659cff3d1aa894155310ceb322a +torchvision<0.14 +torchtext<0.14 +torch<1.13 +mosaicml[streaming]==0.11.0 flash_attn @ git+https://github.com/HazyResearch/flash-attention.git@main transformers==4.21.3 datasets==2.4.0 From 13d742a76c1f87ba2c8cfc90c326f28bfd5f6d6c Mon Sep 17 00:00:00 2001 From: Abhinav Venigalla Date: Tue, 1 Nov 2022 08:33:26 +0000 Subject: [PATCH 08/18] add tests --- .../tests/gpt-125m-ctx-1024-no-dropout.json | 38 +++++++++++++ llm/llm/gpt.py | 7 ++- .../{hf_flash_gpt_2.py => hf_flash_gpt2.py} | 31 +++++----- llm/test_model.py | 57 +++++++++++++++++++ 4 files changed, 117 insertions(+), 16 deletions(-) create mode 100644 llm/hf_configs/tests/gpt-125m-ctx-1024-no-dropout.json rename llm/llm/{hf_flash_gpt_2.py => hf_flash_gpt2.py} (80%) create mode 100644 llm/test_model.py diff --git a/llm/hf_configs/tests/gpt-125m-ctx-1024-no-dropout.json b/llm/hf_configs/tests/gpt-125m-ctx-1024-no-dropout.json new file mode 100644 index 000000000..6ca6a66a0 --- /dev/null +++ b/llm/hf_configs/tests/gpt-125m-ctx-1024-no-dropout.json @@ -0,0 +1,38 @@ +{ + "activation_function": "gelu_new", + "architectures": [ + "GPT2LMHeadModel" + ], + "attn_pdrop": 0.0, + "bos_token_id": 50256, + "embd_pdrop": 0.0, + "eos_token_id": 50256, + "initializer_range": 0.02, + "layer_norm_epsilon": 1e-05, + "model_type": "gpt2", + "n_ctx": 1024, + "n_embd": 768, + "n_head": 12, + "n_inner": null, + "n_layer": 12, + "n_positions": 1024, + "reorder_and_upcast_attn": false, + "resid_pdrop": 0.0, + "scale_attn_by_inverse_layer_idx": true, + "scale_attn_weights": true, + "summary_activation": null, + "summary_first_dropout": 0.0, + "summary_proj_to_labels": true, + "summary_type": "cls_index", + "summary_use_proj": true, + "task_specific_params": { + "text-generation": { + "do_sample": true, + "max_length": 50 + } + }, + "torch_dtype": "float32", + "transformers_version": "4.21.3", + "use_cache": false, + "vocab_size": 50257 +} diff --git a/llm/llm/gpt.py b/llm/llm/gpt.py index 0ea07ea83..74f75d7cd 100644 --- a/llm/llm/gpt.py +++ b/llm/llm/gpt.py @@ -16,7 +16,8 @@ from flash_attn.flash_attention import FlashMHA from transformers.models.gpt2 import GPT2Config from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel -from .hf_flash_gpt_2 import GPT2FlashLMHeadModel + +from .hf_flash_gpt2 import GPT2FlashLMHeadModel def prepare_hf_gpt2_model_for_fsdp(model): @@ -66,7 +67,7 @@ def get_targets(self, batch): return targets def forward(self, batch): - return self.model(input_ids=batch['input_ids']).logits + return self.model(input_ids=batch['input_ids'], attention_mask=batch['attention_mask']).logits def eval_forward(self, batch, outputs=None): return outputs if outputs is not None else self.forward(batch) @@ -83,4 +84,4 @@ def get_metrics(self, is_train=False): def update_metric(self, batch, outputs, metric): outputs = outputs.view(-1, outputs.size(-1)) targets = self.get_targets(batch).view(-1) - metric.update(outputs, targets) \ No newline at end of file + metric.update(outputs, targets) diff --git a/llm/llm/hf_flash_gpt_2.py b/llm/llm/hf_flash_gpt2.py similarity index 80% rename from llm/llm/hf_flash_gpt_2.py rename to llm/llm/hf_flash_gpt2.py index f09ae098d..87ca91f41 100644 --- a/llm/llm/hf_flash_gpt_2.py +++ b/llm/llm/hf_flash_gpt2.py @@ -15,13 +15,12 @@ # limitations under the License. """Modified HF GPT2 w/flash attention""" -import math import os from typing import Optional, Tuple, Union import torch from einops import rearrange -from flash_attn.flash_attention import FlashAttention +from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func from torch import nn from transformers.models.gpt2.configuration_gpt2 import GPT2Config from transformers.models.gpt2.modeling_gpt2 import ( @@ -32,17 +31,9 @@ class GPT2FlashAttention(GPT2Attention): def __init__(self, config, is_cross_attention=False, layer_idx=None): super().__init__(config=config, is_cross_attention=is_cross_attention, layer_idx=layer_idx) - self.inner_attn = FlashAttention(softmax_scale=None, attention_dropout=config.attn_pdrop) - if self.reorder_and_upcast_attn: - raise ValueError('GPT2FlashAttention does not support reorder_and_upcast_attn.') - if self.scale_attn_by_inverse_layer_idx: - raise ValueError('GPT2FlashAttention does not support scale_attn_by_inverse_layer_idx.') - if not self.scale_attn_weights: - raise ValueError('GPT2FlashAttention only supports scale_attn_weights=True.') + self.attn_pdrop = config.attn_pdrop def _attn(self, query, key, value, attention_mask=None, head_mask=None): - if head_mask is not None: - raise ValueError('GPT2FlashAttention._attn does not support "head_mask"') # rearrange to flash attention form key = rearrange(key, 'b h s d -> b s h d') value = rearrange(value, 'b h s d -> b s h d') @@ -52,9 +43,23 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): qkv = torch.stack([query,key,value], dim=2) assert qkv.dtype in [torch.float16, torch.bfloat16] - output, attn_weights = self.inner_attn(qkv, key_padding_mask=attention_mask, - need_weights=False, causal=True) + # flash attention logic + batch_size = qkv.shape[0] + seqlen = qkv.shape[1] + dk = qkv.shape[4] + qkv = rearrange(qkv, 'b s ... -> (b s) ...') + max_s = seqlen + cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, device=qkv.device) + attn_pdrop = self.attn_pdrop if self.training else 0.0 + softmax_scale = (1.0 / (dk ** 0.5)) if self.scale_attn_weights else 1.0 + softmax_scale = (softmax_scale / float(self.layer_idx + 1)) if self.scale_attn_by_inverse_layer_idx else softmax_scale + output = flash_attn_unpadded_qkvpacked_func( + qkv, cu_seqlens, max_s, attn_pdrop, + softmax_scale=softmax_scale, causal=True + ) + output = rearrange(output, '(b s) ... -> b s ...', b=batch_size) output = rearrange(output, 'b s h d -> b h s d') + return output, None diff --git a/llm/test_model.py b/llm/test_model.py new file mode 100644 index 000000000..00a40d144 --- /dev/null +++ b/llm/test_model.py @@ -0,0 +1,57 @@ +import torch +from composer.utils import reproducibility +from transformers import DataCollatorForLanguageModeling +from transformers.models.gpt2 import GPT2Config +from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel +from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer + +from llm.hf_flash_gpt2 import GPT2FlashLMHeadModel + + +def test_fwd_bkw(config_path, autocast_device, autocast_dtype): + reproducibility.seed_all(42) + + # Build both models + shared_config = GPT2Config.from_json_file(config_path) + non_flash_model = GPT2LMHeadModel(shared_config) + flash_model = GPT2FlashLMHeadModel(shared_config) + + # Initialize with same parameters + non_flash_state_dict = non_flash_model.state_dict() + flash_model.load_state_dict(non_flash_state_dict) + + # Fake inputs + tokenizer = GPT2Tokenizer.from_pretrained('gpt2') + tokenizer.add_special_tokens({'pad_token': '[PAD]'}) + fake_sample = tokenizer('Here is a fake sample of length 8') + collate_fn = DataCollatorForLanguageModeling(tokenizer, mlm=False) + fake_batch = collate_fn([fake_sample]) + + + # Move to device + non_flash_model = non_flash_model.to(autocast_device) + flash_model = flash_model.to(autocast_device) + fake_batch = { + k: v.to(autocast_device) + for k, v in fake_batch.items() + } + print (fake_batch) + + # Compare outputs + with torch.autocast(device_type=autocast_device, dtype=autocast_dtype): + non_flash_outputs = non_flash_model(**fake_batch).logits + flash_outputs = flash_model(**fake_batch).logits + + print ('#'*20) + print ('OUTPUTS') + print (non_flash_outputs) + print (flash_outputs) + print (torch.allclose(flash_outputs, non_flash_outputs, atol=5e-02)) + + + +config_path = './hf_configs/tests/gpt-125m-ctx-1024-no-dropout.json' + +autocast_device = 'cuda' +autocast_dtype = torch.bfloat16 +test_fwd_bkw(config_path, autocast_device, autocast_dtype) From fd9c049fc7ead498710bdeae405fe15347a9c9a4 Mon Sep 17 00:00:00 2001 From: Abhinav Venigalla Date: Tue, 1 Nov 2022 09:12:47 +0000 Subject: [PATCH 09/18] updates --- llm/hf_configs/final/gpt-125m-biotok.json | 38 ++++++ .../{ => old}/gpt-125m-ctx-1024.json | 0 llm/hf_configs/{ => old}/gpt-1b-ctx-1024.json | 0 .../{ => old}/gpt-350m-ctx-1024.json | 0 llm/hf_configs/{ => old}/gpt-3b-ctx-1024.json | 0 .../{ => old}/gpt-760m-ctx-1024.json | 0 llm/hf_configs/{ => old}/gpt-7b-ctx-1024.json | 0 .../{ => old}/mistral_gpt2_small.json | 0 .../{ => old}/noflash-gpt-125m-ctx-1024.json | 0 .../{ => old}/pubmed_gpt2_medium.json | 0 .../{ => old}/pubmed_mistral_gpt2_small.json | 0 llm/llm/gpt.py | 8 +- llm/main.py | 125 +++++++++--------- llm/yamls/final/gpt-125m-biotok.yaml | 105 +++++++++++++++ llm/yamls/{ => old}/gpt-125m-demo.yaml | 0 llm/yamls/{ => old}/gpt-125m.yaml | 0 llm/yamls/{ => old}/gpt-13b.yaml | 0 llm/yamls/{ => old}/gpt-1b.yaml | 0 llm/yamls/{ => old}/gpt-30b.yaml | 0 llm/yamls/{ => old}/gpt-350m.yaml | 0 llm/yamls/{ => old}/gpt-3b.yaml | 0 llm/yamls/{ => old}/gpt-70b.yaml | 0 llm/yamls/{ => old}/gpt-760m.yaml | 0 llm/yamls/{ => old}/gpt-7b.yaml | 0 .../{ => old}/gpt-mistral-125m-demo.yaml | 0 llm/yamls/{ => old}/gpt-mistral-125m.yaml | 0 llm/yamls/{ => old}/pubmed-gpt-125m.yaml | 0 llm/yamls/{ => old}/pubmed-gpt-350m.yaml | 0 llm/yamls/{ => old}/pubmed-gpt-3b.yaml | 0 .../{ => old}/pubmed-mistral-gpt-125m.yaml | 0 30 files changed, 206 insertions(+), 70 deletions(-) create mode 100644 llm/hf_configs/final/gpt-125m-biotok.json rename llm/hf_configs/{ => old}/gpt-125m-ctx-1024.json (100%) rename llm/hf_configs/{ => old}/gpt-1b-ctx-1024.json (100%) rename llm/hf_configs/{ => old}/gpt-350m-ctx-1024.json (100%) rename llm/hf_configs/{ => old}/gpt-3b-ctx-1024.json (100%) rename llm/hf_configs/{ => old}/gpt-760m-ctx-1024.json (100%) rename llm/hf_configs/{ => old}/gpt-7b-ctx-1024.json (100%) rename llm/hf_configs/{ => old}/mistral_gpt2_small.json (100%) rename llm/hf_configs/{ => old}/noflash-gpt-125m-ctx-1024.json (100%) rename llm/hf_configs/{ => old}/pubmed_gpt2_medium.json (100%) rename llm/hf_configs/{ => old}/pubmed_mistral_gpt2_small.json (100%) create mode 100644 llm/yamls/final/gpt-125m-biotok.yaml rename llm/yamls/{ => old}/gpt-125m-demo.yaml (100%) rename llm/yamls/{ => old}/gpt-125m.yaml (100%) rename llm/yamls/{ => old}/gpt-13b.yaml (100%) rename llm/yamls/{ => old}/gpt-1b.yaml (100%) rename llm/yamls/{ => old}/gpt-30b.yaml (100%) rename llm/yamls/{ => old}/gpt-350m.yaml (100%) rename llm/yamls/{ => old}/gpt-3b.yaml (100%) rename llm/yamls/{ => old}/gpt-70b.yaml (100%) rename llm/yamls/{ => old}/gpt-760m.yaml (100%) rename llm/yamls/{ => old}/gpt-7b.yaml (100%) rename llm/yamls/{ => old}/gpt-mistral-125m-demo.yaml (100%) rename llm/yamls/{ => old}/gpt-mistral-125m.yaml (100%) rename llm/yamls/{ => old}/pubmed-gpt-125m.yaml (100%) rename llm/yamls/{ => old}/pubmed-gpt-350m.yaml (100%) rename llm/yamls/{ => old}/pubmed-gpt-3b.yaml (100%) rename llm/yamls/{ => old}/pubmed-mistral-gpt-125m.yaml (100%) diff --git a/llm/hf_configs/final/gpt-125m-biotok.json b/llm/hf_configs/final/gpt-125m-biotok.json new file mode 100644 index 000000000..39a2e388a --- /dev/null +++ b/llm/hf_configs/final/gpt-125m-biotok.json @@ -0,0 +1,38 @@ +{ + "activation_function": "gelu_new", + "architectures": [ + "GPT2LMHeadModel" + ], + "attn_pdrop": 0.1, + "bos_token_id": 28895, + "embd_pdrop": 0.1, + "eos_token_id": 28895, + "initializer_range": 0.02, + "layer_norm_epsilon": 1e-05, + "model_type": "gpt2", + "n_ctx": 1024, + "n_embd": 768, + "n_head": 12, + "n_inner": null, + "n_layer": 12, + "n_positions": 1024, + "reorder_and_upcast_attn": false, + "resid_pdrop": 0.1, + "scale_attn_by_inverse_layer_idx": true, + "scale_attn_weights": true, + "summary_activation": null, + "summary_first_dropout": 0.1, + "summary_proj_to_labels": true, + "summary_type": "cls_index", + "summary_use_proj": true, + "task_specific_params": { + "text-generation": { + "do_sample": true, + "max_length": 50 + } + }, + "torch_dtype": "float32", + "transformers_version": "4.21.3", + "use_cache": false, + "vocab_size": 28896 +} diff --git a/llm/hf_configs/gpt-125m-ctx-1024.json b/llm/hf_configs/old/gpt-125m-ctx-1024.json similarity index 100% rename from llm/hf_configs/gpt-125m-ctx-1024.json rename to llm/hf_configs/old/gpt-125m-ctx-1024.json diff --git a/llm/hf_configs/gpt-1b-ctx-1024.json b/llm/hf_configs/old/gpt-1b-ctx-1024.json similarity index 100% rename from llm/hf_configs/gpt-1b-ctx-1024.json rename to llm/hf_configs/old/gpt-1b-ctx-1024.json diff --git a/llm/hf_configs/gpt-350m-ctx-1024.json b/llm/hf_configs/old/gpt-350m-ctx-1024.json similarity index 100% rename from llm/hf_configs/gpt-350m-ctx-1024.json rename to llm/hf_configs/old/gpt-350m-ctx-1024.json diff --git a/llm/hf_configs/gpt-3b-ctx-1024.json b/llm/hf_configs/old/gpt-3b-ctx-1024.json similarity index 100% rename from llm/hf_configs/gpt-3b-ctx-1024.json rename to llm/hf_configs/old/gpt-3b-ctx-1024.json diff --git a/llm/hf_configs/gpt-760m-ctx-1024.json b/llm/hf_configs/old/gpt-760m-ctx-1024.json similarity index 100% rename from llm/hf_configs/gpt-760m-ctx-1024.json rename to llm/hf_configs/old/gpt-760m-ctx-1024.json diff --git a/llm/hf_configs/gpt-7b-ctx-1024.json b/llm/hf_configs/old/gpt-7b-ctx-1024.json similarity index 100% rename from llm/hf_configs/gpt-7b-ctx-1024.json rename to llm/hf_configs/old/gpt-7b-ctx-1024.json diff --git a/llm/hf_configs/mistral_gpt2_small.json b/llm/hf_configs/old/mistral_gpt2_small.json similarity index 100% rename from llm/hf_configs/mistral_gpt2_small.json rename to llm/hf_configs/old/mistral_gpt2_small.json diff --git a/llm/hf_configs/noflash-gpt-125m-ctx-1024.json b/llm/hf_configs/old/noflash-gpt-125m-ctx-1024.json similarity index 100% rename from llm/hf_configs/noflash-gpt-125m-ctx-1024.json rename to llm/hf_configs/old/noflash-gpt-125m-ctx-1024.json diff --git a/llm/hf_configs/pubmed_gpt2_medium.json b/llm/hf_configs/old/pubmed_gpt2_medium.json similarity index 100% rename from llm/hf_configs/pubmed_gpt2_medium.json rename to llm/hf_configs/old/pubmed_gpt2_medium.json diff --git a/llm/hf_configs/pubmed_mistral_gpt2_small.json b/llm/hf_configs/old/pubmed_mistral_gpt2_small.json similarity index 100% rename from llm/hf_configs/pubmed_mistral_gpt2_small.json rename to llm/hf_configs/old/pubmed_mistral_gpt2_small.json diff --git a/llm/llm/gpt.py b/llm/llm/gpt.py index 74f75d7cd..02db22a6a 100644 --- a/llm/llm/gpt.py +++ b/llm/llm/gpt.py @@ -41,13 +41,11 @@ def __init__(self, cfg): # load GPT2 config from standard HF model config json hf_config = GPT2Config.from_json_file(cfg.hf_config) # build model with config - model_class = hf_config.architectures[0] - if model_class == 'GPT2LMHeadModel': - self.model = GPT2LMHeadModel(hf_config) - elif model_class == 'GPT2FlashLMHeadModel': + flash_attn = cfg.get('flash_attn', False) + if flash_attn: self.model = GPT2FlashLMHeadModel(hf_config) else: - raise ValueError(f'Not sure how to build model_class={model_class}') + self.model = GPT2LMHeadModel(hf_config) # Tag layers to make the model ready for FSDP prepare_hf_gpt2_model_for_fsdp(self.model) diff --git a/llm/main.py b/llm/main.py index ac7d6484a..4b1a1a845 100644 --- a/llm/main.py +++ b/llm/main.py @@ -3,29 +3,23 @@ import os import sys +import warnings from composer import Trainer from composer.callbacks import LRMonitor, MemoryMonitor, SpeedMonitor -from composer.loggers import ProgressBarLogger, WandBLogger +from composer.loggers import WandBLogger from composer.optim import DecoupledAdamW -from torch.optim import AdamW from composer.optim.scheduler import (ConstantWithWarmupScheduler, - CosineAnnealingWithWarmupScheduler, LinearWithWarmupScheduler) -from composer.utils import S3ObjectStore, dist, reproducibility + CosineAnnealingWithWarmupScheduler) +from composer.utils import dist, reproducibility from omegaconf import OmegaConf as om -import wandb from llm.data_pubmed import build_dataloader from llm.gpt import ComposerGPT def build_logger(name, kwargs): - if name == 'progress_bar': - return ProgressBarLogger( - progress_bar=kwargs.get('progress_bar', True), - log_to_console=kwargs.get('log_to_console', True), - ) - elif name == 'wandb': + if name == 'wandb': return WandBLogger(**kwargs) else: raise ValueError(f'Not sure how to build logger: {name}') @@ -40,6 +34,18 @@ def build_callback(name, kwargs): else: raise ValueError(f'Not sure how to build callback: {name}') +def build_optimizer(cfg, model): + if cfg.name == 'decoupled_adamw': + return DecoupledAdamW( + model.parameters(), + lr=cfg.lr, + betas=cfg.betas, + eps=cfg.eps, + weight_decay=cfg.weight_decay + ) + else: + raise ValueError(f'Not sure how to build optimizer: {cfg.name}') + def build_scheduler(cfg): if cfg.name == 'constant_with_warmup': @@ -49,14 +55,11 @@ def build_scheduler(cfg): return CosineAnnealingWithWarmupScheduler( t_warmup=cfg.t_warmup, alpha_f=cfg.alpha_f) - elif cfg.name == 'linear_with_warmup': - return LinearWithWarmupScheduler( - t_warmup=cfg.t_warmup) else: raise ValueError(f'Not sure how to build scheduler: {cfg.name}') # Coming soon: this conversion math will be done inside Composer Trainer rather than entrypoint -def get_batch_size_info(cfg): +def update_batch_size_info(cfg): global_train_batch_size = cfg.global_train_batch_size device_train_batch_size = global_train_batch_size // dist.get_world_size() device_train_microbatch_size = cfg.device_train_microbatch_size @@ -76,65 +79,65 @@ def get_batch_size_info(cfg): raise ValueError( f'Not sure how to parse {device_train_microbatch_size=}') - return device_train_batch_size, device_train_grad_accum, device_eval_batch_size, device_eval_microbatch_size + cfg.n_gpus = dist.get_world_size() + cfg.device_train_batch_size = device_train_batch_size + cfg.device_train_grad_accum = device_train_grad_accum + cfg.device_eval_batch_size = device_eval_batch_size + cfg.device_eval_microbatch_size = device_eval_microbatch_size + return cfg +def log_config(cfg): + print(om.to_yaml(cfg)) + if 'wandb' in cfg.loggers: + try: + import wandb + except ImportError as e: + raise e + if wandb.run: + wandb.config.update(om.to_container(cfg, resolve=True)) def main(cfg): - print("Training using config: ") - print(om.to_yaml(cfg)) reproducibility.seed_all(cfg.seed) + # Run Name + cfg.run_name = cfg.get('run_name', os.environ.get('COMPOSER_RUN_NAME', 'llm')) + + # Get batch size info + cfg = update_batch_size_info(cfg) + # Read FSDP Config as a dict fsdp_config = cfg.get('fsdp_config', None) fsdp_config = om.to_container(fsdp_config, resolve=True) if fsdp_config else None # Build Model - # For fast initialization, use `meta` device + # For fast initialization of MosaicGPT, use cfg.model.device='meta' print('Initializing model...') + warnings.filterwarnings(action='ignore', message='Torchmetrics v0.9 introduced a new argument class property') model = ComposerGPT(cfg=cfg.model) - n_params = sum(p.numel() for p in model.parameters()) - print(f'{n_params=:.2e}') - - # Get batch size info - device_train_batch_size, device_train_grad_accum, device_eval_batch_size, device_eval_microbatch_size = get_batch_size_info(cfg) + cfg.n_params = sum(p.numel() for p in model.parameters()) + print(f'{cfg.n_params=:.2e}') # Dataloaders print("Building train loader...") - train_loader = build_dataloader(cfg.train_loader, device_train_batch_size) + train_loader = build_dataloader(cfg.train_loader, cfg.device_train_batch_size) print("Building eval loader...") - eval_loader = build_dataloader(cfg.eval_loader, device_eval_batch_size) + eval_loader = build_dataloader(cfg.eval_loader, cfg.device_eval_batch_size) # Optimizer - if cfg.optimizer.name == 'adamw': - optimizer = AdamW( - model.parameters(), - lr=cfg.optimizer.lr, - betas=cfg.optimizer.betas, - eps=cfg.optimizer.eps, - weight_decay=cfg.optimizer.weight_decay) - elif cfg.optimizer.name == 'decoupled_adamw': - optimizer = DecoupledAdamW( - model.parameters(), - lr=cfg.optimizer.lr, - betas=cfg.optimizer.betas, - eps=cfg.optimizer.eps, - weight_decay=cfg.optimizer.weight_decay) - else: - raise ValueError(f'Requested unsupported optimizer: {cfg.optimizer.name}') - + optimizer = build_optimizer(cfg.optimizer, model) # Scheduler scheduler = build_scheduler(cfg.scheduler) # Loggers - loggers = [build_logger(name, logger_cfg) for name, logger_cfg in cfg.loggers.items()] + loggers = [build_logger(name, logger_cfg) for name, logger_cfg in cfg.get('loggers', {}).items()] # Callbacks - callbacks = [build_callback(name, callback_cfg) for name, callback_cfg in cfg.callbacks.items()] + callbacks = [build_callback(name, callback_cfg) for name, callback_cfg in cfg.get('callbacks', {}).items()] # Build the Trainer trainer = Trainer( - run_name=cfg.get('run_name', os.environ['COMPOSER_RUN_NAME']), + run_name=cfg.run_name, seed=cfg.seed, model=model, train_dataloader=train_loader, @@ -143,40 +146,32 @@ def main(cfg): schedulers=scheduler, max_duration=cfg.max_duration, eval_interval=cfg.eval_interval, + progress_bar=cfg.progress_bar, + log_to_console=cfg.log_to_console, loggers=loggers, callbacks=callbacks, precision=cfg.precision, grad_clip_norm=cfg.grad_clip_norm, - grad_accum=device_train_grad_accum, + grad_accum=cfg.device_train_grad_accum, fsdp_config=fsdp_config, save_folder=cfg.get('save_folder', None), - save_filename=cfg.get('save_filename', None), - save_latest_filename=cfg.get('save_latest_filename', None), save_interval=cfg.get('save_interval', '1000ba'), save_num_checkpoints_to_keep=cfg.get('save_num_checkpoints_to_keep', -1), - # load_path=cfg.get('load_path', None), - # load_weights_only=cfg.get('load_weights_only', False), - eval_subset_num_batches=cfg.get('eval_subset_num_batches', 5000), + load_path=cfg.get('load_path', None), + load_weights_only=cfg.get('load_weights_only', False), ) print("Logging config...") - config_dict = om.to_container(cfg, resolve=True) - config_dict.update({ - 'n_gpus': dist.get_world_size(), - 'n_params': n_params, - 'device_train_batch_size': device_train_batch_size, - 'device_eval_batch_size': device_eval_batch_size, - 'device_eval_microbatch_size': device_eval_microbatch_size, - }) - if wandb.run is not None: - wandb.config.update(config_dict) + log_config(cfg) print("Starting training...") trainer.fit() if __name__ == '__main__': - conf_path = sys.argv[1] - with open(conf_path) as f: - cfg = om.load(f) + yaml_path, args_list = sys.argv[1], sys.argv[2:] + with open(yaml_path) as f: + yaml_cfg = om.load(f) + cli_cfg = om.from_cli(args_list) + cfg = om.merge(yaml_cfg, cli_cfg) main(cfg) diff --git a/llm/yamls/final/gpt-125m-biotok.yaml b/llm/yamls/final/gpt-125m-biotok.yaml new file mode 100644 index 000000000..dedd1a4a1 --- /dev/null +++ b/llm/yamls/final/gpt-125m-biotok.yaml @@ -0,0 +1,105 @@ +data_remote: &data_remote s3://crfm-pubmed/pubmed-randomized +data_local: &data_local /tmp/mds-cache/pubmed-randomized +max_seq_len: &max_seq_len 1024 +tokenizer_name: &tokenizer_name stanford-crfm/pubmed_gpt_tokenizer + +# Run Name +run_name: gpt-125m + +# Model +model: + hf_config: hf_configs/final/gpt-125m-biotok.json + flash_attn: true + +# Dataloaders +train_loader: + dataset: + name: streaming_pubmed + remote: *data_remote + local: *data_local + split: train + tokenizer_name: *tokenizer_name + max_seq_len: *max_seq_len + group_method: concat + shuffle: true + drop_last: true + num_workers: 8 + pin_memory: true + prefetch_factor: 2 + persistent_workers: true + timeout: 0 + +eval_loader: + dataset: + name: streaming_pubmed + remote: *data_remote + local: *data_local + split: val + tokenizer_name: *tokenizer_name + max_seq_len: *max_seq_len + group_method: concat + shuffle: false + drop_last: false + num_workers: 8 + pin_memory: true + prefetch_factor: 2 + persistent_workers: true + timeout: 0 + +# Optimization +scheduler: + name: cosine_with_warmup + t_warmup: 100ba + alpha_f: 0.1 + +optimizer: + name: decoupled_adamw + lr: 6.0e-4 + betas: + - 0.9 + - 0.95 + eps: 1.0e-08 + weight_decay: 6.0e-5 + +max_duration: 100000ba +eval_interval: 5000ba +global_train_batch_size: 512 +grad_clip_norm: 1.0 + +# System +seed: 17 +device_train_microbatch_size: auto +# device_train_microbatch_size: 16 +precision: bf16 + +# FSDP +fsdp_config: + sharding_strategy: FULL_SHARD + min_params: 1e9 + mixed_precision: FULL + activation_checkpointing: false + activation_cpu_offload: false + verbose: true + +# Logging +progress_bar: false +log_to_console: true + +callbacks: + speed_monitor: + window_size: 10 + lr_monitor: {} + memory_monitor: {} + +loggers: + wandb: {} + +# Checkpoint to local filesystem or remote object store +save_interval: 5000ba +save_num_checkpoints_to_keep: 1 # Important, this cleans up checkpoints saved to DISK +save_folder: ./{run_name}/checkpoints +# save_folder: s3://crfm-pubmed/checkpoints/{run_name}/checkpoints + +# Load from local filesystem or remote object store +# load_path: ./gpt-125m/checkpoints/latest-rank{rank}.pt +# load_path: s3://my-bucket/my-folder/gpt-125m/checkpoints/latest-rank{rank}.pt diff --git a/llm/yamls/gpt-125m-demo.yaml b/llm/yamls/old/gpt-125m-demo.yaml similarity index 100% rename from llm/yamls/gpt-125m-demo.yaml rename to llm/yamls/old/gpt-125m-demo.yaml diff --git a/llm/yamls/gpt-125m.yaml b/llm/yamls/old/gpt-125m.yaml similarity index 100% rename from llm/yamls/gpt-125m.yaml rename to llm/yamls/old/gpt-125m.yaml diff --git a/llm/yamls/gpt-13b.yaml b/llm/yamls/old/gpt-13b.yaml similarity index 100% rename from llm/yamls/gpt-13b.yaml rename to llm/yamls/old/gpt-13b.yaml diff --git a/llm/yamls/gpt-1b.yaml b/llm/yamls/old/gpt-1b.yaml similarity index 100% rename from llm/yamls/gpt-1b.yaml rename to llm/yamls/old/gpt-1b.yaml diff --git a/llm/yamls/gpt-30b.yaml b/llm/yamls/old/gpt-30b.yaml similarity index 100% rename from llm/yamls/gpt-30b.yaml rename to llm/yamls/old/gpt-30b.yaml diff --git a/llm/yamls/gpt-350m.yaml b/llm/yamls/old/gpt-350m.yaml similarity index 100% rename from llm/yamls/gpt-350m.yaml rename to llm/yamls/old/gpt-350m.yaml diff --git a/llm/yamls/gpt-3b.yaml b/llm/yamls/old/gpt-3b.yaml similarity index 100% rename from llm/yamls/gpt-3b.yaml rename to llm/yamls/old/gpt-3b.yaml diff --git a/llm/yamls/gpt-70b.yaml b/llm/yamls/old/gpt-70b.yaml similarity index 100% rename from llm/yamls/gpt-70b.yaml rename to llm/yamls/old/gpt-70b.yaml diff --git a/llm/yamls/gpt-760m.yaml b/llm/yamls/old/gpt-760m.yaml similarity index 100% rename from llm/yamls/gpt-760m.yaml rename to llm/yamls/old/gpt-760m.yaml diff --git a/llm/yamls/gpt-7b.yaml b/llm/yamls/old/gpt-7b.yaml similarity index 100% rename from llm/yamls/gpt-7b.yaml rename to llm/yamls/old/gpt-7b.yaml diff --git a/llm/yamls/gpt-mistral-125m-demo.yaml b/llm/yamls/old/gpt-mistral-125m-demo.yaml similarity index 100% rename from llm/yamls/gpt-mistral-125m-demo.yaml rename to llm/yamls/old/gpt-mistral-125m-demo.yaml diff --git a/llm/yamls/gpt-mistral-125m.yaml b/llm/yamls/old/gpt-mistral-125m.yaml similarity index 100% rename from llm/yamls/gpt-mistral-125m.yaml rename to llm/yamls/old/gpt-mistral-125m.yaml diff --git a/llm/yamls/pubmed-gpt-125m.yaml b/llm/yamls/old/pubmed-gpt-125m.yaml similarity index 100% rename from llm/yamls/pubmed-gpt-125m.yaml rename to llm/yamls/old/pubmed-gpt-125m.yaml diff --git a/llm/yamls/pubmed-gpt-350m.yaml b/llm/yamls/old/pubmed-gpt-350m.yaml similarity index 100% rename from llm/yamls/pubmed-gpt-350m.yaml rename to llm/yamls/old/pubmed-gpt-350m.yaml diff --git a/llm/yamls/pubmed-gpt-3b.yaml b/llm/yamls/old/pubmed-gpt-3b.yaml similarity index 100% rename from llm/yamls/pubmed-gpt-3b.yaml rename to llm/yamls/old/pubmed-gpt-3b.yaml diff --git a/llm/yamls/pubmed-mistral-gpt-125m.yaml b/llm/yamls/old/pubmed-mistral-gpt-125m.yaml similarity index 100% rename from llm/yamls/pubmed-mistral-gpt-125m.yaml rename to llm/yamls/old/pubmed-mistral-gpt-125m.yaml From b44ab07df41eefeef47d2194043c92e78bc350ec Mon Sep 17 00:00:00 2001 From: Abhinav Venigalla Date: Tue, 1 Nov 2022 09:42:46 +0000 Subject: [PATCH 10/18] add 1b final --- llm/hf_configs/final/gpt-1b-biotok.json | 38 +++++++++ llm/yamls/final/gpt-1b-biotok.yaml | 105 ++++++++++++++++++++++++ 2 files changed, 143 insertions(+) create mode 100644 llm/hf_configs/final/gpt-1b-biotok.json create mode 100644 llm/yamls/final/gpt-1b-biotok.yaml diff --git a/llm/hf_configs/final/gpt-1b-biotok.json b/llm/hf_configs/final/gpt-1b-biotok.json new file mode 100644 index 000000000..4c72f1e48 --- /dev/null +++ b/llm/hf_configs/final/gpt-1b-biotok.json @@ -0,0 +1,38 @@ +{ + "activation_function": "gelu_new", + "architectures": [ + "GPT2LMHeadModel" + ], + "attn_pdrop": 0.1, + "bos_token_id": 28895, + "embd_pdrop": 0.1, + "eos_token_id": 28895, + "initializer_range": 0.02, + "layer_norm_epsilon": 1e-05, + "model_type": "gpt2", + "n_ctx": 1024, + "n_embd": 2048, + "n_head": 16, + "n_inner": null, + "n_layer": 24, + "n_positions": 1024, + "reorder_and_upcast_attn": false, + "resid_pdrop": 0.1, + "scale_attn_by_inverse_layer_idx": true, + "scale_attn_weights": true, + "summary_activation": null, + "summary_first_dropout": 0.1, + "summary_proj_to_labels": true, + "summary_type": "cls_index", + "summary_use_proj": true, + "task_specific_params": { + "text-generation": { + "do_sample": true, + "max_length": 50 + } + }, + "torch_dtype": "float32", + "transformers_version": "4.21.3", + "use_cache": false, + "vocab_size": 28896 +} diff --git a/llm/yamls/final/gpt-1b-biotok.yaml b/llm/yamls/final/gpt-1b-biotok.yaml new file mode 100644 index 000000000..8c1b47db1 --- /dev/null +++ b/llm/yamls/final/gpt-1b-biotok.yaml @@ -0,0 +1,105 @@ +data_remote: &data_remote s3://crfm-pubmed/pubmed-randomized +data_local: &data_local /tmp/mds-cache/pubmed-randomized +max_seq_len: &max_seq_len 1024 +tokenizer_name: &tokenizer_name stanford-crfm/pubmed_gpt_tokenizer + +# Run Name +run_name: gpt-1b + +# Model +model: + hf_config: hf_configs/final/gpt-1b-biotok.json + flash_attn: true + +# Dataloaders +train_loader: + dataset: + name: streaming_pubmed + remote: *data_remote + local: *data_local + split: train + tokenizer_name: *tokenizer_name + max_seq_len: *max_seq_len + group_method: concat + shuffle: true + drop_last: true + num_workers: 8 + pin_memory: true + prefetch_factor: 2 + persistent_workers: true + timeout: 0 + +eval_loader: + dataset: + name: streaming_pubmed + remote: *data_remote + local: *data_local + split: val + tokenizer_name: *tokenizer_name + max_seq_len: *max_seq_len + group_method: concat + shuffle: false + drop_last: false + num_workers: 8 + pin_memory: true + prefetch_factor: 2 + persistent_workers: true + timeout: 0 + +# Optimization +scheduler: + name: cosine_with_warmup + t_warmup: 100ba + alpha_f: 0.1 + +optimizer: + name: decoupled_adamw + lr: 2.0e-4 + betas: + - 0.9 + - 0.95 + eps: 1.0e-08 + weight_decay: 2.0e-5 + +max_duration: 100000ba +eval_interval: 5000ba +global_train_batch_size: 512 +grad_clip_norm: 1.0 + +# System +seed: 17 +device_train_microbatch_size: auto +# device_train_microbatch_size: 16 +precision: bf16 + +# FSDP +fsdp_config: + sharding_strategy: FULL_SHARD + min_params: 1e9 + mixed_precision: FULL + activation_checkpointing: false + activation_cpu_offload: false + verbose: true + +# Logging +progress_bar: false +log_to_console: true + +callbacks: + speed_monitor: + window_size: 10 + lr_monitor: {} + memory_monitor: {} + +loggers: + wandb: {} + +# Checkpoint to local filesystem or remote object store +save_interval: 5000ba +save_num_checkpoints_to_keep: 1 # Important, this cleans up checkpoints saved to DISK +save_folder: ./{run_name}/checkpoints +# save_folder: s3://crfm-pubmed/checkpoints/{run_name}/checkpoints + +# Load from local filesystem or remote object store +# load_path: ./gpt-125m/checkpoints/latest-rank{rank}.pt +# load_path: s3://my-bucket/my-folder/gpt-125m/checkpoints/latest-rank{rank}.pt From d408f7695eba63e1c8278e2d584e53ac7dfa7e34 Mon Sep 17 00:00:00 2001 From: Abhinav Venigalla Date: Tue, 1 Nov 2022 15:50:53 +0000 Subject: [PATCH 11/18] fix eval subset --- llm/main.py | 1 + llm/yamls/final/gpt-125m-biotok.yaml | 3 ++- llm/yamls/final/gpt-1b-biotok.yaml | 3 ++- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/llm/main.py b/llm/main.py index 4b1a1a845..a94fbb2f2 100644 --- a/llm/main.py +++ b/llm/main.py @@ -146,6 +146,7 @@ def main(cfg): schedulers=scheduler, max_duration=cfg.max_duration, eval_interval=cfg.eval_interval, + eval_subset_num_batches=cfg.eval_subset_num_batches, progress_bar=cfg.progress_bar, log_to_console=cfg.log_to_console, loggers=loggers, diff --git a/llm/yamls/final/gpt-125m-biotok.yaml b/llm/yamls/final/gpt-125m-biotok.yaml index dedd1a4a1..b3cee52fb 100644 --- a/llm/yamls/final/gpt-125m-biotok.yaml +++ b/llm/yamls/final/gpt-125m-biotok.yaml @@ -63,6 +63,7 @@ optimizer: max_duration: 100000ba eval_interval: 5000ba +eval_subset_num_batches: 5000ba global_train_batch_size: 512 grad_clip_norm: 1.0 @@ -82,7 +83,7 @@ fsdp_config: verbose: true # Logging -progress_bar: false +progress_bar: true log_to_console: true callbacks: diff --git a/llm/yamls/final/gpt-1b-biotok.yaml b/llm/yamls/final/gpt-1b-biotok.yaml index 8c1b47db1..fed3ab9da 100644 --- a/llm/yamls/final/gpt-1b-biotok.yaml +++ b/llm/yamls/final/gpt-1b-biotok.yaml @@ -63,6 +63,7 @@ optimizer: max_duration: 100000ba eval_interval: 5000ba +eval_subset_num_batches: 5000ba global_train_batch_size: 512 grad_clip_norm: 1.0 @@ -82,7 +83,7 @@ fsdp_config: verbose: true # Logging -progress_bar: false +progress_bar: true log_to_console: true callbacks: From a54dee5cbff283750c02990c99214a95ebdec79f Mon Sep 17 00:00:00 2001 From: Abhinav Venigalla Date: Tue, 1 Nov 2022 15:52:51 +0000 Subject: [PATCH 12/18] typo --- llm/yamls/final/gpt-125m-biotok.yaml | 8 ++++---- llm/yamls/final/gpt-1b-biotok.yaml | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/llm/yamls/final/gpt-125m-biotok.yaml b/llm/yamls/final/gpt-125m-biotok.yaml index b3cee52fb..b010c98ed 100644 --- a/llm/yamls/final/gpt-125m-biotok.yaml +++ b/llm/yamls/final/gpt-125m-biotok.yaml @@ -63,7 +63,7 @@ optimizer: max_duration: 100000ba eval_interval: 5000ba -eval_subset_num_batches: 5000ba +eval_subset_num_batches: 5000 global_train_batch_size: 512 grad_clip_norm: 1.0 @@ -96,9 +96,9 @@ loggers: wandb: {} # Checkpoint to local filesystem or remote object store -save_interval: 5000ba -save_num_checkpoints_to_keep: 1 # Important, this cleans up checkpoints saved to DISK -save_folder: ./{run_name}/checkpoints +# save_interval: 5000ba +# save_num_checkpoints_to_keep: 1 # Important, this cleans up checkpoints saved to DISK +# save_folder: ./{run_name}/checkpoints # save_folder: s3://crfm-pubmed/checkpoints/{run_name}/checkpoints # Load from local filesystem or remote object store diff --git a/llm/yamls/final/gpt-1b-biotok.yaml b/llm/yamls/final/gpt-1b-biotok.yaml index fed3ab9da..5078be46d 100644 --- a/llm/yamls/final/gpt-1b-biotok.yaml +++ b/llm/yamls/final/gpt-1b-biotok.yaml @@ -63,7 +63,7 @@ optimizer: max_duration: 100000ba eval_interval: 5000ba -eval_subset_num_batches: 5000ba +eval_subset_num_batches: 5000 global_train_batch_size: 512 grad_clip_norm: 1.0 From a3f9fb6c65f4bed40d9310a7fc9c1d134d2055ac Mon Sep 17 00:00:00 2001 From: Abhinav Venigalla Date: Tue, 1 Nov 2022 15:53:22 +0000 Subject: [PATCH 13/18] another typo --- llm/yamls/final/gpt-125m-biotok.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/llm/yamls/final/gpt-125m-biotok.yaml b/llm/yamls/final/gpt-125m-biotok.yaml index b010c98ed..e70a71f78 100644 --- a/llm/yamls/final/gpt-125m-biotok.yaml +++ b/llm/yamls/final/gpt-125m-biotok.yaml @@ -96,9 +96,9 @@ loggers: wandb: {} # Checkpoint to local filesystem or remote object store -# save_interval: 5000ba -# save_num_checkpoints_to_keep: 1 # Important, this cleans up checkpoints saved to DISK -# save_folder: ./{run_name}/checkpoints +save_interval: 5000ba +save_num_checkpoints_to_keep: 1 # Important, this cleans up checkpoints saved to DISK +save_folder: ./{run_name}/checkpoints # save_folder: s3://crfm-pubmed/checkpoints/{run_name}/checkpoints # Load from local filesystem or remote object store From 21dc5b9605e841e2723c80b1ef76d026b5cb96d7 Mon Sep 17 00:00:00 2001 From: Abhinav Venigalla Date: Wed, 2 Nov 2022 05:22:59 +0000 Subject: [PATCH 14/18] fix wandb resumption --- llm/main.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/llm/main.py b/llm/main.py index a94fbb2f2..408816fc2 100644 --- a/llm/main.py +++ b/llm/main.py @@ -4,6 +4,7 @@ import os import sys import warnings +from urllib.parse import urlparse from composer import Trainer from composer.callbacks import LRMonitor, MemoryMonitor, SpeedMonitor @@ -86,9 +87,19 @@ def update_batch_size_info(cfg): cfg.device_eval_microbatch_size = device_eval_microbatch_size return cfg +def get_load_params(cfg): + load_path = cfg.get('load_path', None) + if load_path and load_path.startswith('wandb'): + url = urlparse(load_path) + entity, project = url.netloc.split(':') + load_object_store = WandBLogger(entity=entity, project=project) + return load_path, load_object_store + else: + return load_path, None + def log_config(cfg): print(om.to_yaml(cfg)) - if 'wandb' in cfg.loggers: + if 'wandb' in cfg.get('loggers', {}): try: import wandb except ImportError as e: @@ -135,6 +146,9 @@ def main(cfg): # Callbacks callbacks = [build_callback(name, callback_cfg) for name, callback_cfg in cfg.get('callbacks', {}).items()] + # Load object store + load_path, load_object_store = get_load_params(cfg) + # Build the Trainer trainer = Trainer( run_name=cfg.run_name, @@ -158,7 +172,8 @@ def main(cfg): save_folder=cfg.get('save_folder', None), save_interval=cfg.get('save_interval', '1000ba'), save_num_checkpoints_to_keep=cfg.get('save_num_checkpoints_to_keep', -1), - load_path=cfg.get('load_path', None), + load_path=load_path, + load_object_store=load_object_store, load_weights_only=cfg.get('load_weights_only', False), ) From 0842306c77b803be20a5b1f86ad4a4d65990a025 Mon Sep 17 00:00:00 2001 From: Abhinav Venigalla Date: Wed, 2 Nov 2022 19:29:39 +0000 Subject: [PATCH 15/18] fix act ckpt --- llm/llm/gpt.py | 8 ++++---- llm/yamls/final/gpt-1b-biotok.yaml | 9 +++++---- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/llm/llm/gpt.py b/llm/llm/gpt.py index 02db22a6a..ee99542a3 100644 --- a/llm/llm/gpt.py +++ b/llm/llm/gpt.py @@ -15,7 +15,7 @@ from composer.models.base import ComposerModel from flash_attn.flash_attention import FlashMHA from transformers.models.gpt2 import GPT2Config -from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel +from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2LMHeadModel from .hf_flash_gpt2 import GPT2FlashLMHeadModel @@ -29,10 +29,10 @@ def prepare_hf_gpt2_model_for_fsdp(model): model.transformer.wte._fsdp_wrap = False model.lm_head._fsdp_wrap = False + # FSDP Wrap and Activation Checkpoint every GPT2Block - for block in model.transformer.h: - block._fsdp_wrap = True - block._activation_checkpointing = True + model.fsdp_wrap_fn = lambda module: isinstance(module, GPT2Block) + model.activation_checkpointing_fn = lambda module: isinstance(module, GPT2Block) class ComposerGPT(ComposerModel): diff --git a/llm/yamls/final/gpt-1b-biotok.yaml b/llm/yamls/final/gpt-1b-biotok.yaml index 5078be46d..67f033239 100644 --- a/llm/yamls/final/gpt-1b-biotok.yaml +++ b/llm/yamls/final/gpt-1b-biotok.yaml @@ -78,7 +78,7 @@ fsdp_config: sharding_strategy: FULL_SHARD min_params: 1e9 mixed_precision: FULL - activation_checkpointing: false + activation_checkpointing: true activation_cpu_offload: false verbose: true @@ -96,11 +96,12 @@ loggers: wandb: {} # Checkpoint to local filesystem or remote object store -save_interval: 5000ba -save_num_checkpoints_to_keep: 1 # Important, this cleans up checkpoints saved to DISK -save_folder: ./{run_name}/checkpoints +# save_interval: 5000ba +# save_num_checkpoints_to_keep: 1 # Important, this cleans up checkpoints saved to DISK +# save_folder: ./{run_name}/checkpoints # save_folder: s3://crfm-pubmed/checkpoints/{run_name}/checkpoints # Load from local filesystem or remote object store # load_path: ./gpt-125m/checkpoints/latest-rank{rank}.pt # load_path: s3://my-bucket/my-folder/gpt-125m/checkpoints/latest-rank{rank}.pt +# load_path: ./artifacts/1b-biotok-flash-fsdp-gpus-64-NbbmPS.checkpoints.ep0-ba5000-rank0.pt:v0/ep0-ba5000-rank0.pt \ No newline at end of file From a137933f65724996388325e2aa26037a28487f05 Mon Sep 17 00:00:00 2001 From: Abhinav Venigalla Date: Wed, 2 Nov 2022 20:59:55 +0000 Subject: [PATCH 16/18] cleanup, add 5b --- llm/hf_configs/final/gpt-3b-biotok.json | 38 +++++++++ llm/hf_configs/final/gpt-5b-biotok.json | 38 +++++++++ llm/yamls/final/gpt-1b-biotok.yaml | 10 +-- llm/yamls/final/gpt-3b-biotok-300k.yaml | 107 ++++++++++++++++++++++++ llm/yamls/final/gpt-5b-biotok-200k.yaml | 107 ++++++++++++++++++++++++ 5 files changed, 295 insertions(+), 5 deletions(-) create mode 100644 llm/hf_configs/final/gpt-3b-biotok.json create mode 100644 llm/hf_configs/final/gpt-5b-biotok.json create mode 100644 llm/yamls/final/gpt-3b-biotok-300k.yaml create mode 100644 llm/yamls/final/gpt-5b-biotok-200k.yaml diff --git a/llm/hf_configs/final/gpt-3b-biotok.json b/llm/hf_configs/final/gpt-3b-biotok.json new file mode 100644 index 000000000..f2d6bc213 --- /dev/null +++ b/llm/hf_configs/final/gpt-3b-biotok.json @@ -0,0 +1,38 @@ +{ + "activation_function": "gelu_new", + "architectures": [ + "GPT2LMHeadModel" + ], + "attn_pdrop": 0.1, + "bos_token_id": 28895, + "embd_pdrop": 0.1, + "eos_token_id": 28895, + "initializer_range": 0.02, + "layer_norm_epsilon": 1e-05, + "model_type": "gpt2", + "n_ctx": 1024, + "n_embd": 2560, + "n_head": 20, + "n_inner": null, + "n_layer": 32, + "n_positions": 1024, + "reorder_and_upcast_attn": false, + "resid_pdrop": 0.1, + "scale_attn_by_inverse_layer_idx": true, + "scale_attn_weights": true, + "summary_activation": null, + "summary_first_dropout": 0.1, + "summary_proj_to_labels": true, + "summary_type": "cls_index", + "summary_use_proj": true, + "task_specific_params": { + "text-generation": { + "do_sample": true, + "max_length": 50 + } + }, + "torch_dtype": "float32", + "transformers_version": "4.21.3", + "use_cache": false, + "vocab_size": 28896 +} diff --git a/llm/hf_configs/final/gpt-5b-biotok.json b/llm/hf_configs/final/gpt-5b-biotok.json new file mode 100644 index 000000000..a1589121e --- /dev/null +++ b/llm/hf_configs/final/gpt-5b-biotok.json @@ -0,0 +1,38 @@ +{ + "activation_function": "gelu_new", + "architectures": [ + "GPT2LMHeadModel" + ], + "attn_pdrop": 0.1, + "bos_token_id": 28895, + "embd_pdrop": 0.1, + "eos_token_id": 28895, + "initializer_range": 0.02, + "layer_norm_epsilon": 1e-05, + "model_type": "gpt2", + "n_ctx": 1024, + "n_embd": 4096, + "n_head": 32, + "n_inner": null, + "n_layer": 24, + "n_positions": 1024, + "reorder_and_upcast_attn": false, + "resid_pdrop": 0.1, + "scale_attn_by_inverse_layer_idx": true, + "scale_attn_weights": true, + "summary_activation": null, + "summary_first_dropout": 0.1, + "summary_proj_to_labels": true, + "summary_type": "cls_index", + "summary_use_proj": true, + "task_specific_params": { + "text-generation": { + "do_sample": true, + "max_length": 50 + } + }, + "torch_dtype": "float32", + "transformers_version": "4.21.3", + "use_cache": false, + "vocab_size": 28896 +} diff --git a/llm/yamls/final/gpt-1b-biotok.yaml b/llm/yamls/final/gpt-1b-biotok.yaml index 67f033239..d7f54f33f 100644 --- a/llm/yamls/final/gpt-1b-biotok.yaml +++ b/llm/yamls/final/gpt-1b-biotok.yaml @@ -69,8 +69,8 @@ grad_clip_norm: 1.0 # System seed: 17 -device_train_microbatch_size: auto -# device_train_microbatch_size: 16 +# device_train_microbatch_size: auto +device_train_microbatch_size: 16 precision: bf16 # FSDP @@ -96,9 +96,9 @@ loggers: wandb: {} # Checkpoint to local filesystem or remote object store -# save_interval: 5000ba -# save_num_checkpoints_to_keep: 1 # Important, this cleans up checkpoints saved to DISK -# save_folder: ./{run_name}/checkpoints +save_interval: 5000ba +save_num_checkpoints_to_keep: 1 # Important, this cleans up checkpoints saved to DISK +save_folder: ./{run_name}/checkpoints # save_folder: s3://crfm-pubmed/checkpoints/{run_name}/checkpoints # Load from local filesystem or remote object store diff --git a/llm/yamls/final/gpt-3b-biotok-300k.yaml b/llm/yamls/final/gpt-3b-biotok-300k.yaml new file mode 100644 index 000000000..d583c3aa5 --- /dev/null +++ b/llm/yamls/final/gpt-3b-biotok-300k.yaml @@ -0,0 +1,107 @@ +data_remote: &data_remote s3://crfm-pubmed/pubmed-randomized +data_local: &data_local /tmp/mds-cache/pubmed-randomized +max_seq_len: &max_seq_len 1024 +tokenizer_name: &tokenizer_name stanford-crfm/pubmed_gpt_tokenizer + +# Run Name +run_name: gpt-3b + +# Model +model: + hf_config: hf_configs/final/gpt-3b-biotok.json + flash_attn: true + +# Dataloaders +train_loader: + dataset: + name: streaming_pubmed + remote: *data_remote + local: *data_local + split: train + tokenizer_name: *tokenizer_name + max_seq_len: *max_seq_len + group_method: concat + shuffle: true + drop_last: true + num_workers: 8 + pin_memory: true + prefetch_factor: 2 + persistent_workers: true + timeout: 0 + +eval_loader: + dataset: + name: streaming_pubmed + remote: *data_remote + local: *data_local + split: val + tokenizer_name: *tokenizer_name + max_seq_len: *max_seq_len + group_method: concat + shuffle: false + drop_last: false + num_workers: 8 + pin_memory: true + prefetch_factor: 2 + persistent_workers: true + timeout: 0 + +# Optimization +scheduler: + name: cosine_with_warmup + t_warmup: 100ba + alpha_f: 0.1 + +optimizer: + name: decoupled_adamw + lr: 1.6e-4 + betas: + - 0.9 + - 0.95 + eps: 1.0e-08 + weight_decay: 1.6e-5 + +max_duration: 300000ba +eval_interval: 5000ba +eval_subset_num_batches: 1000 +global_train_batch_size: 1024 +grad_clip_norm: 1.0 + +# System +seed: 17 +# device_train_microbatch_size: auto +device_train_microbatch_size: 8 +precision: bf16 + +# FSDP +fsdp_config: + sharding_strategy: FULL_SHARD + min_params: 1e9 + mixed_precision: FULL + activation_checkpointing: true + activation_cpu_offload: false + verbose: true + +# Logging +progress_bar: true +log_to_console: true + +callbacks: + speed_monitor: + window_size: 10 + lr_monitor: {} + memory_monitor: {} + +loggers: + wandb: {} + +# Checkpoint to local filesystem or remote object store +save_interval: 5000ba +save_num_checkpoints_to_keep: 1 # Important, this cleans up checkpoints saved to DISK +save_folder: ./{run_name}/checkpoints +# save_folder: s3://crfm-pubmed/checkpoints/{run_name}/checkpoints + +# Load from local filesystem or remote object store +# load_path: ./gpt-125m/checkpoints/latest-rank{rank}.pt +# load_path: s3://my-bucket/my-folder/gpt-125m/checkpoints/latest-rank{rank}.pt +# load_path: ./artifacts/1b-biotok-flash-fsdp-gpus-64-NbbmPS.checkpoints.ep0-ba5000-rank0.pt:v0/ep0-ba5000-rank0.pt \ No newline at end of file diff --git a/llm/yamls/final/gpt-5b-biotok-200k.yaml b/llm/yamls/final/gpt-5b-biotok-200k.yaml new file mode 100644 index 000000000..cdf292c4e --- /dev/null +++ b/llm/yamls/final/gpt-5b-biotok-200k.yaml @@ -0,0 +1,107 @@ +data_remote: &data_remote s3://crfm-pubmed/pubmed-randomized +data_local: &data_local /tmp/mds-cache/pubmed-randomized +max_seq_len: &max_seq_len 1024 +tokenizer_name: &tokenizer_name stanford-crfm/pubmed_gpt_tokenizer + +# Run Name +run_name: gpt-5b + +# Model +model: + hf_config: hf_configs/final/gpt-5b-biotok.json + flash_attn: true + +# Dataloaders +train_loader: + dataset: + name: streaming_pubmed + remote: *data_remote + local: *data_local + split: train + tokenizer_name: *tokenizer_name + max_seq_len: *max_seq_len + group_method: concat + shuffle: true + drop_last: true + num_workers: 8 + pin_memory: true + prefetch_factor: 2 + persistent_workers: true + timeout: 0 + +eval_loader: + dataset: + name: streaming_pubmed + remote: *data_remote + local: *data_local + split: val + tokenizer_name: *tokenizer_name + max_seq_len: *max_seq_len + group_method: concat + shuffle: false + drop_last: false + num_workers: 8 + pin_memory: true + prefetch_factor: 2 + persistent_workers: true + timeout: 0 + +# Optimization +scheduler: + name: cosine_with_warmup + t_warmup: 100ba + alpha_f: 0.1 + +optimizer: + name: decoupled_adamw + lr: 1.4e-4 + betas: + - 0.9 + - 0.95 + eps: 1.0e-08 + weight_decay: 1.4e-5 + +max_duration: 200000ba +eval_interval: 5000ba +eval_subset_num_batches: 1000 +global_train_batch_size: 1024 +grad_clip_norm: 1.0 + +# System +seed: 17 +# device_train_microbatch_size: auto +device_train_microbatch_size: 8 +precision: bf16 + +# FSDP +fsdp_config: + sharding_strategy: FULL_SHARD + min_params: 1e9 + mixed_precision: FULL + activation_checkpointing: true + activation_cpu_offload: false + verbose: true + +# Logging +progress_bar: true +log_to_console: true + +callbacks: + speed_monitor: + window_size: 10 + lr_monitor: {} + memory_monitor: {} + +loggers: + wandb: {} + +# Checkpoint to local filesystem or remote object store +save_interval: 5000ba +save_num_checkpoints_to_keep: 1 # Important, this cleans up checkpoints saved to DISK +save_folder: ./{run_name}/checkpoints +# save_folder: s3://crfm-pubmed/checkpoints/{run_name}/checkpoints + +# Load from local filesystem or remote object store +# load_path: ./gpt-125m/checkpoints/latest-rank{rank}.pt +# load_path: s3://my-bucket/my-folder/gpt-125m/checkpoints/latest-rank{rank}.pt +# load_path: ./artifacts/1b-biotok-flash-fsdp-gpus-64-NbbmPS.checkpoints.ep0-ba5000-rank0.pt:v0/ep0-ba5000-rank0.pt \ No newline at end of file From 7df31d07cf7a7827b755272ff3e3383d7d6f7ca6 Mon Sep 17 00:00:00 2001 From: Abhinav Venigalla Date: Thu, 10 Nov 2022 12:54:21 -0800 Subject: [PATCH 17/18] add 7b option --- llm/hf_configs/final/gpt-7b-biotok.json | 38 +++++++++ llm/yamls/final/gpt-7b-biotk-300k.yaml | 107 ++++++++++++++++++++++++ 2 files changed, 145 insertions(+) create mode 100644 llm/hf_configs/final/gpt-7b-biotok.json create mode 100644 llm/yamls/final/gpt-7b-biotk-300k.yaml diff --git a/llm/hf_configs/final/gpt-7b-biotok.json b/llm/hf_configs/final/gpt-7b-biotok.json new file mode 100644 index 000000000..b7ec40b91 --- /dev/null +++ b/llm/hf_configs/final/gpt-7b-biotok.json @@ -0,0 +1,38 @@ +{ + "activation_function": "gelu_new", + "architectures": [ + "GPT2LMHeadModel" + ], + "attn_pdrop": 0.1, + "bos_token_id": 28895, + "embd_pdrop": 0.1, + "eos_token_id": 28895, + "initializer_range": 0.02, + "layer_norm_epsilon": 1e-05, + "model_type": "gpt2", + "n_ctx": 1024, + "n_embd": 4096, + "n_head": 32, + "n_inner": null, + "n_layer": 32, + "n_positions": 1024, + "reorder_and_upcast_attn": false, + "resid_pdrop": 0.1, + "scale_attn_by_inverse_layer_idx": true, + "scale_attn_weights": true, + "summary_activation": null, + "summary_first_dropout": 0.1, + "summary_proj_to_labels": true, + "summary_type": "cls_index", + "summary_use_proj": true, + "task_specific_params": { + "text-generation": { + "do_sample": true, + "max_length": 50 + } + }, + "torch_dtype": "float32", + "transformers_version": "4.21.3", + "use_cache": false, + "vocab_size": 28896 +} diff --git a/llm/yamls/final/gpt-7b-biotk-300k.yaml b/llm/yamls/final/gpt-7b-biotk-300k.yaml new file mode 100644 index 000000000..6a884fe13 --- /dev/null +++ b/llm/yamls/final/gpt-7b-biotk-300k.yaml @@ -0,0 +1,107 @@ +data_remote: &data_remote s3://crfm-pubmed/pubmed-randomized +data_local: &data_local /tmp/mds-cache/pubmed-randomized +max_seq_len: &max_seq_len 1024 +tokenizer_name: &tokenizer_name stanford-crfm/pubmed_gpt_tokenizer + +# Run Name +run_name: gpt-7b + +# Model +model: + hf_config: hf_configs/final/gpt-3b-biotok.json + flash_attn: true + +# Dataloaders +train_loader: + dataset: + name: streaming_pubmed + remote: *data_remote + local: *data_local + split: train + tokenizer_name: *tokenizer_name + max_seq_len: *max_seq_len + group_method: concat + shuffle: true + drop_last: true + num_workers: 8 + pin_memory: true + prefetch_factor: 2 + persistent_workers: true + timeout: 0 + +eval_loader: + dataset: + name: streaming_pubmed + remote: *data_remote + local: *data_local + split: val + tokenizer_name: *tokenizer_name + max_seq_len: *max_seq_len + group_method: concat + shuffle: false + drop_last: false + num_workers: 8 + pin_memory: true + prefetch_factor: 2 + persistent_workers: true + timeout: 0 + +# Optimization +scheduler: + name: cosine_with_warmup + t_warmup: 100ba + alpha_f: 0.1 + +optimizer: + name: decoupled_adamw + lr: 1.2e-4 + betas: + - 0.9 + - 0.95 + eps: 1.0e-08 + weight_decay: 1.2e-5 + +max_duration: 300000ba +eval_interval: 5000ba +eval_subset_num_batches: 1000 +global_train_batch_size: 2048 +grad_clip_norm: 1.0 + +# System +seed: 17 +# device_train_microbatch_size: auto +device_train_microbatch_size: 4 +precision: bf16 + +# FSDP +fsdp_config: + sharding_strategy: FULL_SHARD + min_params: 1e9 + mixed_precision: FULL + activation_checkpointing: true + activation_cpu_offload: false + verbose: true + +# Logging +progress_bar: true +log_to_console: true + +callbacks: + speed_monitor: + window_size: 10 + lr_monitor: {} + memory_monitor: {} + +loggers: + wandb: {} + +# Checkpoint to local filesystem or remote object store +save_interval: 5000ba +save_num_checkpoints_to_keep: 1 # Important, this cleans up checkpoints saved to DISK +save_folder: ./{run_name}/checkpoints +# save_folder: s3://crfm-pubmed/checkpoints/{run_name}/checkpoints + +# Load from local filesystem or remote object store +# load_path: ./gpt-125m/checkpoints/latest-rank{rank}.pt +# load_path: s3://my-bucket/my-folder/gpt-125m/checkpoints/latest-rank{rank}.pt +# load_path: ./artifacts/1b-biotok-flash-fsdp-gpus-64-NbbmPS.checkpoints.ep0-ba5000-rank0.pt:v0/ep0-ba5000-rank0.pt \ No newline at end of file From 4507fd79988ba20733bc62e4efc7cb9e7047e5b8 Mon Sep 17 00:00:00 2001 From: Abhinav Venigalla Date: Thu, 10 Nov 2022 13:01:16 -0800 Subject: [PATCH 18/18] bugfixes --- .../final/{gpt-7b-biotk-300k.yaml => gpt-7b-biotok-150k.yaml} | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) rename llm/yamls/final/{gpt-7b-biotk-300k.yaml => gpt-7b-biotok-150k.yaml} (96%) diff --git a/llm/yamls/final/gpt-7b-biotk-300k.yaml b/llm/yamls/final/gpt-7b-biotok-150k.yaml similarity index 96% rename from llm/yamls/final/gpt-7b-biotk-300k.yaml rename to llm/yamls/final/gpt-7b-biotok-150k.yaml index 6a884fe13..dd4a89f89 100644 --- a/llm/yamls/final/gpt-7b-biotk-300k.yaml +++ b/llm/yamls/final/gpt-7b-biotok-150k.yaml @@ -8,7 +8,7 @@ run_name: gpt-7b # Model model: - hf_config: hf_configs/final/gpt-3b-biotok.json + hf_config: hf_configs/final/gpt-7b-biotok.json flash_attn: true # Dataloaders @@ -61,7 +61,7 @@ optimizer: eps: 1.0e-08 weight_decay: 1.2e-5 -max_duration: 300000ba +max_duration: 150000ba # 300B tokens eval_interval: 5000ba eval_subset_num_batches: 1000 global_train_batch_size: 2048