init of linear layer in starcoder
haeggee committed Aug 6, 2024
1 parent 91acdc0 commit df3befc
Showing 3 changed files with 72 additions and 43 deletions.
33 changes: 29 additions & 4 deletions src/nanotron/config/models_config.py
@@ -177,6 +177,7 @@ def as_starcoder2(self) -> Starcoder2Config:
def n_inner(self):
return self.intermediate_size


@dataclass
class GPT3MoEConfig:
"""Configuration for a GPT3 __MoE__ model"""
@@ -208,27 +209,51 @@ class GPT3MoEConfig:
moe_z_loss_weight: float = 0.001
moe_glu: bool = False


def as_gpt3(self) -> GPT3Config:
config = dict(**vars(self))

# Moe
# Moe
del config["is_moe"]
del config["moe_num_experts"]
del config["num_experts_per_tok"]
del config["moe_loss_weight"]
del config["moe_z_loss_weight"]
del config["moe_glu"]

if "_is_using_mup" in config:
del config["_is_using_mup"]
return GPT3Config(**config)

def as_starcoder2(self) -> Starcoder2Config:
# same as gpt3 conversion above
config = dict(**vars(self))
del config["sinusoidal_position_embedding"]
del config["use_spda"]
del config["position_embedding_offset"]
del config["act_pdrop"]
del config["scale_embedding"]

# Moe
del config["is_moe"]
del config["moe_num_experts"]
del config["num_experts_per_tok"]
del config["moe_loss_weight"]
del config["moe_z_loss_weight"]
del config["moe_glu"]

if "_is_using_mup" in config:
del config["_is_using_mup"]
return GPT3Config(
**config
return Starcoder2Config(
grouped_query=True, num_kv_heads=self.num_attention_heads, use_rotary_embeddings=False, **config
)

@property
def n_inner(self):
return self.intermediate_size

@property
def hidden_act(self):
return self.activation_function


NanotronConfigs = LlamaConfig | Starcoder2Config | GPT3Config | GPT3MoEConfig
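
Note (not part of the commit): both as_gpt3() and as_starcoder2() above follow the same dict-copy-and-delete pattern. A minimal, self-contained sketch of that pattern, using illustrative stand-in dataclasses rather than nanotron's real config classes:

from dataclasses import dataclass

# Copy the source dataclass's fields into a dict, drop the keys the target
# config does not define, then rebuild the target dataclass from what is left.
@dataclass
class SourceConfig:          # stand-in for GPT3MoEConfig
    hidden_size: int = 1024
    is_moe: bool = True
    moe_num_experts: int = 8

@dataclass
class TargetConfig:          # stand-in for Starcoder2Config / GPT3Config
    hidden_size: int = 1024

def convert(src: SourceConfig) -> TargetConfig:
    fields = dict(**vars(src))
    for key in ("is_moe", "moe_num_experts"):  # fields the target lacks
        fields.pop(key, None)
    return TargetConfig(**fields)

print(convert(SourceConfig(hidden_size=2048)))  # TargetConfig(hidden_size=2048)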
75 changes: 36 additions & 39 deletions src/nanotron/models/gpt3_moe.py
@@ -1,33 +1,30 @@
"""PyTorch GPT-3 MoE model."""

import math
from contextlib import contextmanager
from typing import Dict, Optional, Union

import torch
from torch import nn
from torch.nn import functional as F

from nanotron import distributed as dist
from nanotron.config import GPT3MoEConfig, ParallelismArgs, GPT3Config
from nanotron.generation.generate_store import AttachableStore
from nanotron.config import GPT3Config, GPT3MoEConfig, ParallelismArgs
from nanotron.models import gpt3
from nanotron.models.gpt3 import CausalSelfAttention, GPT3ForTraining, GPT3Model, dropout_add_fused_train
from nanotron.models.gpt3 import GPTBlock as GPT3Block
from nanotron.models.moe import (
dMoE,
)
from nanotron.models.gpt3 import CausalSelfAttention, GPTModel, PositionEmbedding, dropout_add_fused_train, GPT3ForTraining
from nanotron.models.gpt3 import GPTBlock as GPT3Block
from nanotron.nn.layer_norm import TritonLayerNorm
from nanotron.parallel import ParallelContext
from nanotron.parallel.pipeline_parallel.block import PipelineBlock
from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer
from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode
from nanotron.parallel.tensor_parallel.nn import TensorParallelColumnLinear, TensorParallelEmbedding
from nanotron.parallel.tensor_parallel.nn import TensorParallelColumnLinear
from nanotron.random import RandomStates, branch_random_state


@contextmanager
def replace_decoder(gpt3config: GPT3MoEConfig):
def replace_moe_decoder(gpt3config: GPT3MoEConfig):
orig = gpt3.PipelineBlock
try:

@@ -37,7 +34,7 @@ def create_pp_block(module_builder, module_kwargs, **kwargs):
# Let's return a PipelineBlock with a GPT3Block instead.
# This also requires to replace starcoders2's config with gpt3's config.
module_kwargs["config"] = gpt3config
return orig(module_builder=GPTBlock, module_kwargs=module_kwargs, **kwargs)
return orig(module_builder=GPT3MoEBlock, module_kwargs=module_kwargs, **kwargs)
# Else, they are setting up other modules, which we also want unchanged.
return orig(module_builder=module_builder, module_kwargs=module_kwargs, **kwargs)

@@ -48,24 +45,25 @@ def create_pp_block(module_builder, module_kwargs, **kwargs):


@contextmanager
def replace_gpt3model(gpt3moeconfig: GPT3MoEConfig):
orig = gpt3.GPTModel
def replace_gpt3_moe_model(gpt3moeconfig: GPT3MoEConfig):
orig = gpt3.GPT3Model
try:

def create_gptmodel(
def create_moe_model(
config: GPT3Config,
parallel_context: ParallelContext,
parallel_config: Optional[ParallelismArgs],
random_states: RandomStates,
):
return GPT3MoEModel(gpt3moeconfig, parallel_context, parallel_config, random_states)

gpt3.GPTModel = create_gptmodel
gpt3.GPT3Model = create_moe_model
yield
finally:
gpt3.GPTModel = orig
gpt3.GPT3Model = orig

class GPTBlock(nn.Module):

class GPT3MoEBlock(nn.Module):
def __init__(
self,
config: GPT3MoEConfig,
@@ -75,25 +73,24 @@ def __init__(
random_states: RandomStates,
layer_idx: int,
):
super(GPTBlock, self).__init__()
super(GPT3MoEBlock, self).__init__()
self.ln_1 = TritonLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.attn = CausalSelfAttention(
config=config, parallel_config=parallel_config, tp_pg=tp_pg, layer_idx=layer_idx
)
self.attn_dropout = config.attn_pdrop

self.ln_2 = TritonLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)

self.ff = dMoE(
config=config,
parallel_config=parallel_config,
parallel_context=parallel_context,
config=config,
parallel_config=parallel_config,
parallel_context=parallel_context,
)
self.ff_dropout = config.resid_pdrop
self.random_states = random_states
self.tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE


def forward(
self,
hidden_states: torch.Tensor | TensorPointer,
@@ -135,29 +132,26 @@ def forward(
# No need for random state context manager
hidden_states = hidden_states + residual

return {
"hidden_states": hidden_states,
"sequence_mask": output["sequence_mask"],
"aux_losses": aux_losses
}
return {"hidden_states": hidden_states, "sequence_mask": output["sequence_mask"], "aux_losses": aux_losses}


class GPT3MoEModel(GPTModel):
class GPT3MoEModel(GPT3Model):
def __init__(
self,
config: GPT3MoEConfig,
parallel_context: ParallelContext,
parallel_config: Optional[ParallelismArgs],
random_states: RandomStates,
):
with replace_decoder(config):
with replace_moe_decoder(config):
super().__init__(config.as_gpt3(), parallel_context, parallel_config, random_states)

# need to adapt the decoder list because we pass the aux_losses around
self.decoder = nn.ModuleList(
[
PipelineBlock(
p2p=self.p2p,
module_builder=GPTBlock,
module_builder=GPT3MoEBlock,
module_kwargs={
"config": config,
"parallel_config": parallel_config,
@@ -172,6 +166,7 @@ def __init__(
for layer_idx in range(config.num_hidden_layers)
]
)

def forward(
self,
input_ids: torch.Tensor | TensorPointer, # [batch_size, seq_length]
@@ -204,7 +199,7 @@ def forward(

fp32_sharded_logits = self.cast_to_fp32(x=sharded_logits)["output"]

return fp32_sharded_logits, hidden_encoder_states["aux_losses"]
return {"sharded_logits": fp32_sharded_logits, "aux_losses": hidden_encoder_states["aux_losses"]}


class GPT3MoEForTraining(GPT3ForTraining):
@@ -215,7 +210,7 @@ def __init__(
parallel_config: Optional[ParallelismArgs],
random_states: RandomStates,
):
with replace_gpt3model(config):
with replace_gpt3_moe_model(config):
super().__init__(config.as_gpt3(), parallel_context, parallel_config, random_states)
self.config = config

@@ -249,29 +244,31 @@ def forward(
label_ids=label_ids,
label_mask=label_mask,
)
if isinstance(output['aux_losses'], dict):

if isinstance(output["aux_losses"], dict):
for key, value in output["aux_losses"].items():
loss[key] = value
return loss

def get_block_compute_costs(self):
"""Computes the compute cost of each block in the model so that we can do a better job of load balancing."""
model_config = self.config
d_ff = model_config.n_inner if model_config.intermediate_size is not None else 4 * model_config.hidden_size
d_qkv = model_config.hidden_size // model_config.num_attention_heads
# active experts + routing
mlp_cost = 2 * d_ff * model_config.hidden_size * model_config.num_experts_per_tok \
mlp_cost = (
2 * d_ff * model_config.hidden_size * model_config.num_experts_per_tok
+ model_config.hidden_size * model_config.moe_num_experts
)
att_cost = 4 * model_config.num_attention_heads * d_qkv * model_config.hidden_size
block_compute_costs = {
# CausalSelfAttention (qkv proj + attn out) + MLP
GPTBlock: att_cost + mlp_cost,
GPT3MoEBlock: att_cost + mlp_cost,
# This is the last lm_head
TensorParallelColumnLinear: model_config.vocab_size * model_config.hidden_size,
}
return block_compute_costs

def get_flops_per_sec(self, iteration_time_in_sec, sequence_length, global_batch_size):
"""Get flops per second for a given model"""
world_size = self.parallel_context.world_pg.size()
@@ -291,7 +288,7 @@ def get_flops_per_sec(self, iteration_time_in_sec, sequence_length, global_batch
model_flops_per_s = model_flops / (iteration_time_in_sec * world_size * 1e12)
hardware_flops_per_s = hardware_flops / (iteration_time_in_sec * world_size * 1e12)
return model_flops_per_s, hardware_flops_per_s


def get_flops(
num_layers,
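Note (not part of the commit): replace_moe_decoder() and replace_gpt3_moe_model() above both rely on temporarily monkey-patching a name in the gpt3 module so that the parent GPT3 constructors build the MoE variants, with the original restored in a finally block. A minimal, self-contained sketch of that pattern, using an illustrative stand-in module rather than nanotron.models.gpt3:

import types
from contextlib import contextmanager

# Stand-in module; build_block stands in for gpt3.PipelineBlock / gpt3.GPT3Model.
gpt3_stub = types.ModuleType("gpt3_stub")
gpt3_stub.build_block = lambda: "dense block"

@contextmanager
def replace_block(replacement):
    orig = gpt3_stub.build_block
    try:
        gpt3_stub.build_block = replacement  # code run inside the `with` block
        yield                                # now sees the replacement
    finally:
        gpt3_stub.build_block = orig         # always restored, even on error

with replace_block(lambda: "moe block"):
    print(gpt3_stub.build_block())  # "moe block"
print(gpt3_stub.build_block())      # "dense block"
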
7 changes: 7 additions & 0 deletions src/nanotron/models/starcoder2.py
@@ -1517,6 +1517,13 @@ def init_model_randomly(self, config):
module.bias.zero_()
else:
raise ValueError(f"Who the fuck is {param_name}?")
elif isinstance(module, nn.Linear):
if "weight" == param_name:
nn.init.normal_(module.weight, mean=0.0, std=std)
elif "bias" == param_name:
module.bias.zero_()
else:
raise ValueError(f"Who the fuck is {param_name}?")
elif isinstance(module, TensorParallelRowLinear):
if "weight" == param_name:
nn.init.normal_(module.weight, mean=0.0, std=sigma / math.sqrt(2 * num_layers))
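Note (not part of the commit): the new nn.Linear branch gives plain (non tensor-parallel) linear layers the same treatment as the surrounding branches of init_model_randomly: weights drawn from N(0, std) and zero biases. A minimal sketch of that rule, with a placeholder std value (in nanotron the value comes from the model's init-method configuration):

from torch import nn

def init_linear(module: nn.Linear, std: float = 0.02) -> None:
    # Weight ~ N(0, std), bias = 0, mirroring the branch added above.
    nn.init.normal_(module.weight, mean=0.0, std=std)
    if module.bias is not None:
        nn.init.zeros_(module.bias)

layer = nn.Linear(16, 32)
init_linear(layer)
print(round(layer.weight.std().item(), 3), layer.bias.abs().sum().item())  # ~0.02, 0.0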
