From 3967beeb65392a9dd9ae66751ea1ce791bea7a6d Mon Sep 17 00:00:00 2001
From: Alexander Hagele
Date: Mon, 5 Aug 2024 11:50:53 +0200
Subject: [PATCH] correct block costs and flops

---
 src/nanotron/models/gpt3_moe.py | 122 ++++++++++++++++++++++++++++++--
 1 file changed, 115 insertions(+), 7 deletions(-)

diff --git a/src/nanotron/models/gpt3_moe.py b/src/nanotron/models/gpt3_moe.py
index 6fdb4669..2dc997bc 100644
--- a/src/nanotron/models/gpt3_moe.py
+++ b/src/nanotron/models/gpt3_moe.py
@@ -207,9 +207,6 @@ def forward(
         return fp32_sharded_logits, hidden_encoder_states["aux_losses"]
 
 
-# TODO: maybe reimplement:
-# - get_block_compute_costs
-# - get_flops_per_sec
 class GPT3MoEForTraining(GPT3ForTraining):
     def __init__(
         self,
@@ -258,17 +255,128 @@ def forward(
             loss[key] = value
         return loss
 
-    # TODO: adapt with MoE costs
     def get_block_compute_costs(self):
         """Computes the compute cost of each block in the model so that we can do a better job of load balancing."""
         model_config = self.config
         d_ff = model_config.n_inner if model_config.intermediate_size is not None else 4 * model_config.hidden_size
         d_qkv = model_config.hidden_size // model_config.num_attention_heads
+        # active experts + routing
+        mlp_cost = 2 * d_ff * model_config.hidden_size * model_config.num_experts_per_tok \
+            + model_config.hidden_size * model_config.moe_num_experts
+        att_cost = 4 * model_config.num_attention_heads * d_qkv * model_config.hidden_size
         block_compute_costs = {
             # CausalSelfAttention (qkv proj + attn out) + MLP
-            GPTBlock: 4 * model_config.num_attention_heads * d_qkv * model_config.hidden_size
-            + 2 * d_ff * model_config.hidden_size,
+            GPTBlock: att_cost + mlp_cost,
             # This is the last lm_head
             TensorParallelColumnLinear: model_config.vocab_size * model_config.hidden_size,
         }
-        return block_compute_costs
\ No newline at end of file
+        return block_compute_costs
+
+    def get_flops_per_sec(self, iteration_time_in_sec, sequence_length, global_batch_size):
+        """Get flops per second for a given model (reported as TFLOPs per second per device)."""
+        world_size = self.parallel_context.world_pg.size()
+        model_flops, hardware_flops = get_flops(
+            num_layers=self.config.num_hidden_layers,
+            hidden_size=self.config.hidden_size,
+            num_heads=self.config.num_attention_heads,
+            vocab_size=self.config.vocab_size,
+            ffn_hidden_size=self.config.n_inner if self.config.n_inner is not None else 4 * self.config.hidden_size,
+            seq_len=sequence_length,
+            batch_size=global_batch_size,
+            kv_channels=None,
+            glu_activation=False,
+            num_experts=self.config.moe_num_experts,
+            num_experts_per_tok=self.config.num_experts_per_tok,
+        )
+        model_flops_per_s = model_flops / (iteration_time_in_sec * world_size * 1e12)
+        hardware_flops_per_s = hardware_flops / (iteration_time_in_sec * world_size * 1e12)
+        return model_flops_per_s, hardware_flops_per_s
+
+
+def get_flops(
+    num_layers,
+    hidden_size,
+    num_heads,
+    vocab_size,
+    seq_len,
+    kv_channels=None,
+    ffn_hidden_size=None,
+    batch_size=1,
+    glu_activation=False,
+    num_experts=1,
+    num_experts_per_tok=1,
+):
+    """Counts flops in a decoder-only model
+    Args:
+        num_layers: number of decoder layers
+        hidden_size: hidden size of the model
+        num_heads: number of heads in the model
+        kv_channels: hidden size of the key and value heads
+        ffn_hidden_size: hidden size of the FFN
+        vocab_size: size of the vocabulary
+        seq_len: sequence length of the decoder
+        batch_size: batch size
+        glu_activation: Whether to use GLU activation in FFN. Check T5 v1.1 for more info.
+        num_experts: number of experts in the MoE layer
+        num_experts_per_tok: number of experts per token in the MoE layer
+    Returns:
+        model_flops: flops in the model (should be independent of the hardware and model implementation)
+        hardware_flops: flops in the hardware (actual flops performed on the hardware). Check 6.3 in https://arxiv.org/pdf/2205.05198.pdf
+    """
+
+    if kv_channels is None:
+        assert hidden_size % num_heads == 0
+        kv_channels = hidden_size // num_heads
+    if ffn_hidden_size is None:
+        ffn_hidden_size = 4 * hidden_size
+
+    # In the following we mark the reduced dimension with parentheses
+    # decoder
+    # self attention (MQA)
+    ## q projection
+    decoder_q_proj_flops_fwd = 2 * num_layers * batch_size * seq_len * (hidden_size) * num_heads * kv_channels
+    ## kv projection, shared across heads
+    decoder_kv_proj_flops_fwd = 2 * num_layers * batch_size * seq_len * (hidden_size) * 2 * kv_channels
+    ## qk logits
+    decoder_qk_logits_flops_fwd = 2 * num_layers * batch_size * num_heads * seq_len * (kv_channels) * seq_len
+    ### SWA (sliding window attention / local attention)
+    # window_size = 4096
+    # decoder_qk_logits_flops_fwd = 2 * num_layers * batch_size * num_heads * seq_len * (kv_channels) * window_size
+    ## v logits
+    decoder_v_logits_flops_fwd = 2 * num_layers * batch_size * num_heads * seq_len * (seq_len) * kv_channels
+    # decoder_v_logits_flops_fwd = 2 * num_layers * batch_size * num_heads * seq_len * (window_size) * kv_channels
+    ## attn out
+    decoder_attn_out_flops_fwd = 2 * num_layers * batch_size * num_heads * seq_len * (kv_channels) * hidden_size
+    # FF
+    ## 1st layer
+    decoder_ffn_1_flops_fwd = 2 * num_layers * batch_size * seq_len * (hidden_size) * ffn_hidden_size
+    if glu_activation:
+        # 3 matmuls instead of 2 in FFN
+        # ref. https://arxiv.org/pdf/2002.05202.pdf
+        # Used for example in T5 v1.1
+        decoder_ffn_1_flops_fwd = 4 * num_layers * batch_size * seq_len * (hidden_size) * ffn_hidden_size
+    ## 2nd layer
+    decoder_ffn_2_flops_fwd = 2 * num_layers * batch_size * seq_len * (ffn_hidden_size) * hidden_size
+    # MoE router
+    decoder_ffn_router_flops_fwd = 2 * num_layers * batch_size * seq_len * (hidden_size) * num_experts
+
+    decoder_flops_fwd = (
+        decoder_q_proj_flops_fwd
+        + decoder_kv_proj_flops_fwd
+        + decoder_qk_logits_flops_fwd
+        + decoder_v_logits_flops_fwd
+        + decoder_attn_out_flops_fwd
+        + decoder_ffn_1_flops_fwd * num_experts_per_tok
+        + decoder_ffn_2_flops_fwd * num_experts_per_tok
+        + decoder_ffn_router_flops_fwd
+    )
+
+    # lm head
+    lm_head_flops_fwd = 2 * batch_size * seq_len * (hidden_size) * vocab_size
+
+    # the bwd pass requires double the flops of the fwd pass for matmuls, since gradients are computed with respect to both the input and the weight tensors
+    model_flops = 3 * (decoder_flops_fwd + lm_head_flops_fwd)  # 1 for fwd + 2 for bwd
+
+    hardware_flops = model_flops  # TODO @nouamanetazi: This is a placeholder for now
+    return model_flops, hardware_flops
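
Sanity check for the new accounting (a reviewer-side sketch, not part of the patch): with MQA attention and top-k routing as modeled above, model_flops should decompose exactly into 6 * tokens * active_params plus the weight-free attention-score term 12 * num_layers * tokens * hidden_size * seq_len, where active_params counts the q/kv/out projections, the top-k expert FFNs, the router, and the LM head. The snippet below assumes the patch is applied and nanotron (with this module's dependencies) is importable; all config numbers are made up for illustration.

from nanotron.models.gpt3_moe import get_flops

L, h, heads, V, s, B = 12, 768, 12, 50257, 2048, 8  # layers, hidden, heads, vocab, seq len, batch
f, E, top_k = 4 * h, 8, 2                           # ffn hidden (default 4h), total experts, experts per token
kv = h // heads                                     # kv_channels default

model_flops, _ = get_flops(
    num_layers=L, hidden_size=h, num_heads=heads, vocab_size=V, seq_len=s,
    batch_size=B, num_experts=E, num_experts_per_tok=top_k,
)

tokens = B * s
# Weights touched per token: q proj (h*h), shared kv proj (2*h*kv), attn out (h*h),
# top-k expert FFNs (2*h*f each), router (h*E), plus the LM head (h*V).
active_params = L * (h * h + 2 * h * kv + h * h + 2 * h * f * top_k + h * E) + h * V
expected = 6 * tokens * active_params + 12 * L * tokens * h * s
assert model_flops == expected, (model_flops, expected)

Note the asymmetry the patch introduces: expert FFN flops scale with num_experts_per_tok (only routed experts run), while router flops scale with the total number of experts, mirroring the mlp_cost term in get_block_compute_costs.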