diff --git a/src/nanotron/config/models_config.py b/src/nanotron/config/models_config.py index 72158651..d882b2ef 100644 --- a/src/nanotron/config/models_config.py +++ b/src/nanotron/config/models_config.py @@ -262,4 +262,97 @@ def hidden_act(self): return self.activation_function -NanotronConfigs = LlamaConfig | Starcoder2Config | GPT3Config | GPT3MoEConfig +@dataclass +class GPT3LangMoEConfig: + """Configuration for a GPT3 __MoE__ model with language aware gating""" + + activation_function: str = "gelu" + attn_pdrop: float = 0.1 + embd_pdrop: float = 0.1 + eos_token_id: int = 49152 + hidden_size: int = 2048 + intermediate_size: Optional[int] = None + layer_norm_epsilon: float = 1e-05 + max_position_embeddings: int = 4096 + num_attention_heads: int = 16 + num_hidden_layers: int = 24 + resid_pdrop: float = 0.1 + scale_attention_softmax_in_fp32: bool = True + scale_attn_weights: bool = True + vocab_size: int = 49280 + sinusoidal_position_embedding: bool = True + position_embedding_offset: int = 2 + use_spda: bool = False + act_pdrop: float = 0.0 + scale_embedding: bool = True + # MoE specific + is_moe: bool = True + moe_num_experts: int = 1 + num_experts_per_tok: int = 1 + moe_loss_weight: float = 0.01 + moe_z_loss_weight: float = 0.001 + moe_glu: bool = False + + # Language aware gating + num_languages: int = 100 + language_embedding_size: int = 128 + + def as_gpt3(self) -> GPT3Config: + config = dict(**vars(self)) + + # Moe + del config["is_moe"] + del config["moe_num_experts"] + del config["num_experts_per_tok"] + del config["moe_loss_weight"] + del config["moe_z_loss_weight"] + del config["moe_glu"] + + # language aware gating + del config["num_languages"] + del config["language_embedding_size"] + + if "_is_using_mup" in config: + del config["_is_using_mup"] + return GPT3Config(**config) + + def as_starcoder2(self) -> Starcoder2Config: + # same as gpt3 conversion above + config = dict(**vars(self)) + del config["sinusoidal_position_embedding"] + del config["use_spda"] + del config["position_embedding_offset"] + del config["act_pdrop"] + del config["scale_embedding"] + + # Moe + del config["is_moe"] + del config["moe_num_experts"] + del config["num_experts_per_tok"] + del config["moe_loss_weight"] + del config["moe_z_loss_weight"] + del config["moe_glu"] + + # language aware gating + del config["num_languages"] + del config["language_embedding_size"] + + if "_is_using_mup" in config: + del config["_is_using_mup"] + return Starcoder2Config( + grouped_query=True, + num_kv_heads=self.num_attention_heads, + use_rotary_embeddings=False, + **config, + ) + + @property + def n_inner(self): + return self.intermediate_size + + @property + def hidden_act(self): + return self.activation_function + + +NanotronConfigs = LlamaConfig | Starcoder2Config | GPT3Config | GPT3MoEConfig | GPT3LangMoEConfig diff --git a/src/nanotron/models/gpt3_langmoe.py b/src/nanotron/models/gpt3_langmoe.py new file mode 100644 index 00000000..f8fa8c07 --- /dev/null +++ b/src/nanotron/models/gpt3_langmoe.py @@ -0,0 +1,460 @@ +"""PyTorch GPT-3 MoE model.""" + +from contextlib import contextmanager +from typing import Dict, Optional, Union + +import torch +from torch import nn + +from nanotron import distributed as dist +from nanotron.config import GPT3Config, GPT3LangMoEConfig, ParallelismArgs +from nanotron.generation.generate_store import AttachableStore +from nanotron.models import gpt3 +from nanotron.models.gpt3 import ( + CausalSelfAttention, + GPT3ForTraining, + GPT3Model, + dropout_add_fused_train, +) +from 
nanotron.models.gpt3 import GPTBlock as GPT3Block +from nanotron.models.moe import ( + dLangMoE, +) +from nanotron.nn.layer_norm import TritonLayerNorm +from nanotron.parallel import ParallelContext +from nanotron.parallel.pipeline_parallel.block import PipelineBlock +from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer +from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode +from nanotron.parallel.tensor_parallel.nn import ( + TensorParallelColumnLinear, + TensorParallelEmbedding, +) +from nanotron.random import RandomStates, branch_random_state + + +@contextmanager +def replace_moe_decoder(gpt3config: GPT3LangMoEConfig): + orig = gpt3.PipelineBlock + try: + + def create_pp_block(module_builder, module_kwargs, **kwargs): + if module_builder is GPT3Block: + # GPT3's GPT module is trying to instantiate a GPT3 GPTBlock. + # Let's return a PipelineBlock with a GPT3Block instead. + # This also requires to replace starcoders2's config with gpt3's config. + module_kwargs["config"] = gpt3config + return orig( + module_builder=GPT3LangMoEBlock, + module_kwargs=module_kwargs, + **kwargs, + ) + # Else, they are setting up other modules, which we also want unchanged. + return orig(module_builder=module_builder, module_kwargs=module_kwargs, **kwargs) + + gpt3.PipelineBlock = create_pp_block + yield + finally: + gpt3.PipelineBlock = orig + + +@contextmanager +def replace_gpt3_moe_model(gpt3moeconfig: GPT3LangMoEConfig): + orig = gpt3.GPT3Model + try: + + def create_moe_model( + config: GPT3Config, + parallel_context: ParallelContext, + parallel_config: Optional[ParallelismArgs], + random_states: RandomStates, + ): + return GPT3LangMoEModel(gpt3moeconfig, parallel_context, parallel_config, random_states) + + gpt3.GPT3Model = create_moe_model + yield + finally: + gpt3.GPT3Model = orig + + +class LanguageEmbedding(nn.Module, AttachableStore): + def __init__( + self, + tp_pg: dist.ProcessGroup, + config: GPT3LangMoEConfig, + parallel_config: Optional[ParallelismArgs], + ): + super().__init__() + self.language_embedding = TensorParallelEmbedding( + num_embeddings=config.num_languages, + embedding_dim=config.language_embedding_size, + pg=tp_pg, + mode=(parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE), + ) + self.pg = tp_pg + + def forward( + self, + lang_code: torch.Tensor, # [batch_size, 1] + ): + lang_code = lang_code.transpose(0, 1) + lang_emb = self.language_embedding(lang_code) + return {"lang_emb": lang_emb} + + +class GPT3LangMoEBlock(nn.Module): + def __init__( + self, + config: GPT3LangMoEConfig, + parallel_config: Optional[ParallelismArgs], + parallel_context: ParallelContext, + tp_pg: dist.ProcessGroup, + random_states: RandomStates, + layer_idx: int, + ): + super(GPT3LangMoEBlock, self).__init__() + self.ln_1 = TritonLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) + self.attn = CausalSelfAttention( + config=config, + parallel_config=parallel_config, + tp_pg=tp_pg, + layer_idx=layer_idx, + ) + self.attn_dropout = config.attn_pdrop + + self.ln_2 = TritonLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) + + self.ff = dLangMoE( + config=config, + parallel_config=parallel_config, + parallel_context=parallel_context, + ) + self.ff_dropout = config.resid_pdrop + self.random_states = random_states + self.tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE + + def forward( + self, + hidden_states: torch.Tensor | TensorPointer, + 
sequence_mask: torch.Tensor | TensorPointer, + lang_emb: torch.Tensor | TensorPointer, + aux_losses: Dict[str, Union[torch.Tensor, TensorPointer]], + ) -> dict[str, torch.Tensor | TensorPointer]: + + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + output = self.attn(hidden_states=hidden_states, sequence_mask=sequence_mask) + hidden_states = output["hidden_states"] + + if self.training: + with branch_random_state( + self.random_states, + "tp_synced", + enabled=self.tp_mode is TensorParallelLinearMode.ALL_REDUCE, + ): + hidden_states = dropout_add_fused_train(hidden_states, residual=residual, prob=self.attn_dropout) + else: + # No need for random state context manager + hidden_states = hidden_states + residual + + residual = hidden_states + hidden_states = self.ln_2(hidden_states) + mlp_output = self.ff(hidden_states=hidden_states, lang_hidden_states=lang_emb) + hidden_states = mlp_output["hidden_states"] + + for key, value in mlp_output.items(): + if key != "hidden_states": + aux_losses[key] = aux_losses[key] + value + + if self.training: + with branch_random_state( + self.random_states, + "tp_synced", + enabled=self.tp_mode is TensorParallelLinearMode.ALL_REDUCE, + ): + hidden_states = dropout_add_fused_train(hidden_states, residual=residual, prob=self.ff_dropout) + else: + # No need for random state context manager + hidden_states = hidden_states + residual + + return { + "hidden_states": hidden_states, + "sequence_mask": output["sequence_mask"], + "aux_losses": aux_losses, + } + + +class GPT3LangMoEModel(GPT3Model): + def __init__( + self, + config: GPT3LangMoEConfig, + parallel_context: ParallelContext, + parallel_config: Optional[ParallelismArgs], + random_states: RandomStates, + ): + with replace_moe_decoder(config): + super().__init__(config.as_gpt3(), parallel_context, parallel_config, random_states) + + # need to adapt the decoder list because we pass the aux_losses around + self.decoder = nn.ModuleList( + [ + PipelineBlock( + p2p=self.p2p, + module_builder=GPT3LangMoEBlock, + module_kwargs={ + "config": config, + "parallel_config": parallel_config, + "tp_pg": parallel_context.tp_pg, + "random_states": random_states, + "parallel_context": parallel_context, + "layer_idx": layer_idx, + }, + module_input_keys={ + "hidden_states", + "sequence_mask", + "lang_emb", + "aux_losses", + }, + module_output_keys={"hidden_states", "sequence_mask", "aux_losses"}, + ) + for layer_idx in range(config.num_hidden_layers) + ] + ) + + self.language_embeddings = PipelineBlock( + p2p=self.p2p, + module_builder=LanguageEmbedding, + module_kwargs={ + "tp_pg": parallel_context.tp_pg, + "config": config, + "parallel_config": parallel_config, + }, + module_input_keys={"lang_code"}, + module_output_keys={"lang_emb"}, + ) + + def forward( + self, + input_ids: torch.Tensor | TensorPointer, # [batch_size, seq_length] + input_mask: torch.Tensor | TensorPointer, # [batch_size, seq_length] + lang_code: torch.Tensor | TensorPointer, # [batch_size, 1] + aux_losses: Dict[str, Union[torch.Tensor, TensorPointer]], + ): + # all tensors are optional as most ranks don't need anything from the dataloader. + + input_embeds = ( + self.token_embeddings(input_ids=input_ids, input_mask=input_mask)["input_embeds"] * self.embed_scale + ) + # TODO: position_ids could be cached. 
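+        # position_ids has shape [batch_size, seq_length]: an arange over the sequence length,
+        # tiled across the batch to match input_ids.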
+ position_ids = torch.arange(input_ids.size(1), device="cuda").repeat(input_ids.size(0)).view(*input_ids.size()) + position_embeds = self.position_embeddings(position_ids=position_ids)["position_embeds"] + hidden_states = input_embeds + position_embeds + + # language embedding for MoE + lang_emb = self.language_embeddings(lang_code=lang_code)["lang_emb"] + + with branch_random_state( + self.random_states, + "tp_synced", + enabled=self.tp_mode == TensorParallelLinearMode.ALL_REDUCE, + ): + hidden_states = self.embeds_dropout(input=hidden_states)["hidden_states"] + + hidden_encoder_states = { + "hidden_states": hidden_states, + "sequence_mask": input_mask, + "aux_losses": aux_losses, + } + for encoder_block in self.decoder: + hidden_encoder_states = encoder_block(**hidden_encoder_states, lang_emb=lang_emb) + # return hidden_encoder_states["hidden_states"] + + hidden_states = self.final_layer_norm(input=hidden_encoder_states["hidden_states"])["hidden_states"] + + sharded_logits = self.lm_head(x=hidden_states)["logits"] + + fp32_sharded_logits = self.cast_to_fp32(x=sharded_logits)["output"] + + return { + "sharded_logits": fp32_sharded_logits, + "aux_losses": hidden_encoder_states["aux_losses"], + } + + +class GPT3LangMoEForTraining(GPT3ForTraining): + def __init__( + self, + config: GPT3LangMoEConfig, + parallel_context: ParallelContext, + parallel_config: Optional[ParallelismArgs], + random_states: RandomStates, + ): + with replace_gpt3_moe_model(config): + super().__init__(config.as_gpt3(), parallel_context, parallel_config, random_states) + + self.config = config + + def forward( + self, + input_ids: Union[torch.Tensor, TensorPointer], + input_mask: Union[torch.Tensor, TensorPointer], + lang_code: Union[torch.Tensor, TensorPointer], # [batch_size, 1] TODO + label_ids: Union[torch.Tensor, TensorPointer], + label_mask: Union[torch.Tensor, TensorPointer], + ) -> Dict[str, Union[torch.Tensor, TensorPointer]]: + # aux_losses are used for load balancing in case of MoEs + aux_losses = { + "load_balancing_loss": ( + torch.zeros(1, device=input_ids.device) + if not isinstance(input_ids, TensorPointer) + else TensorPointer(self.input_pp_rank) + ), + "z_loss": ( + torch.zeros(1, device=input_ids.device) + if not isinstance(input_ids, TensorPointer) + else TensorPointer(self.input_pp_rank) + ), + } + model_output = self.model( + input_ids=input_ids, + input_mask=input_mask, + lang_code=lang_code, + aux_losses=aux_losses, + ) + outputs = self.loss( + sharded_logits=model_output["sharded_logits"], + label_ids=label_ids, + label_mask=label_mask, + ) + + outputs["loss"] = torch.mean(outputs["sample_loss"]) + if isinstance(model_output["aux_losses"], dict): + for key, value in model_output["aux_losses"].items(): + outputs[key] = value + return outputs + + def get_block_compute_costs(self): + """Computes the compute cost of each block in the model so that we can do a better job of load balancing.""" + model_config = self.config + d_ff = model_config.n_inner if model_config.intermediate_size is not None else 4 * model_config.hidden_size + d_qkv = model_config.hidden_size // model_config.num_attention_heads + # active experts + routing + mlp_cost = ( + 2 * d_ff * model_config.hidden_size * model_config.num_experts_per_tok + + model_config.hidden_size * model_config.moe_num_experts + ) + att_cost = 4 * model_config.num_attention_heads * d_qkv * model_config.hidden_size + block_compute_costs = { + # CausalSelfAttention (qkv proj + attn out) + MLP + GPT3LangMoEBlock: att_cost + mlp_cost, + # This is the last 
lm_head + TensorParallelColumnLinear: model_config.vocab_size * model_config.hidden_size, + } + return block_compute_costs + + def get_flops_per_sec(self, iteration_time_in_sec, sequence_length, global_batch_size): + """Get flops per second for a given model""" + world_size = self.parallel_context.world_pg.size() + model_flops, hardware_flops = get_flops( + num_layers=self.config.num_hidden_layers, + hidden_size=self.config.hidden_size, + num_heads=self.config.num_attention_heads, + vocab_size=self.config.vocab_size, + ffn_hidden_size=(self.config.n_inner if self.config.n_inner is not None else 4 * self.config.hidden_size), + seq_len=sequence_length, + batch_size=global_batch_size, + kv_channels=None, + glu_activation=False, + num_experts=self.config.moe_num_experts, + num_experts_per_tok=self.config.num_experts_per_tok, + ) + model_flops_per_s = model_flops / (iteration_time_in_sec * world_size * 1e12) + hardware_flops_per_s = hardware_flops / (iteration_time_in_sec * world_size * 1e12) + return model_flops_per_s, hardware_flops_per_s + + +def get_flops( + num_layers, + hidden_size, + num_heads, + vocab_size, + seq_len, + kv_channels=None, + ffn_hidden_size=None, + batch_size=1, + glu_activation=False, + num_experts=1, + num_experts_per_tok=1, +): + """Counts flops in an decoder-only model + Args: + num_layers: number of decoder layers + hidden_size: hidden size of the model + num_heads: number of heads in the model + kv_channels: hidden size of the key and value heads + ffn_hidden_size: hidden size of the FFN + vocab_size: size of the vocabulary + seq_len: sequence length of the decoder + batch_size: batch size + glu_activation: Whether to use GLU activation in FFN. Check T5 v1.1 for more info. + num_experts_per_tok: number of experts per token in the MoE layer + Returns: + model_flops: flops in the model (should be independent of the hardware and model implementation) + hardware_flops: flops in the hardware (actual flops performed on the hardware). Check 6.3 in https://arxiv.org/pdf/2205.05198.pdf + """ + + if kv_channels is None: + assert hidden_size % num_heads == 0 + kv_channels = hidden_size // num_heads + if ffn_hidden_size is None: + ffn_hidden_size = 4 * hidden_size + + # In the following we mark the reduced dimension with parentheses + # decoder + # self attention (MQA) + ## q projection + decoder_q_proj_flops_fwd = 2 * num_layers * batch_size * seq_len * (hidden_size) * num_heads * kv_channels + ## kv projection, shared across heads + decoder_kv_proj_flops_fwd = 2 * num_layers * batch_size * seq_len * (hidden_size) * 2 * kv_channels + ## qk logits + decoder_qk_logits_flops_fwd = 2 * num_layers * batch_size * num_heads * seq_len * (kv_channels) * seq_len + ### SWA (sliding window attention / local attention) + # window_size = 4096 + # decoder_qk_logits_flops_fwd = 2 * num_layers * batch_size * num_heads * seq_len * (kv_channels) * window_size + ## v logits + decoder_v_logits_flops_fwd = 2 * num_layers * batch_size * num_heads * seq_len * (seq_len) * kv_channels + # decoder_v_logits_flops_fwd = 2 * num_layers * batch_size * num_heads * seq_len * (window_size) * kv_channels + ## attn out + decoder_attn_out_flops_fwd = 2 * num_layers * batch_size * num_heads * seq_len * (kv_channels) * hidden_size + # FF + ## 1st layer + decoder_ffn_1_flops_fwd = 2 * num_layers * batch_size * seq_len * (hidden_size) * ffn_hidden_size + if glu_activation: + # 3 matmuls instead of 2 in FFN + # ref. 
https://arxiv.org/pdf/2002.05202.pdf + # Used for example in T5 v1.1 + decoder_ffn_1_flops_fwd = 4 * num_layers * batch_size * seq_len * (hidden_size) * ffn_hidden_size + ## 2nd layer + decoder_ffn_2_flops_fwd = 2 * num_layers * batch_size * seq_len * (ffn_hidden_size) * hidden_size + # MoE router + decoder_ffn_router_flops_fwd = 2 * num_layers * batch_size * seq_len * (hidden_size) * num_experts + + decoder_flops_fwd = ( + decoder_q_proj_flops_fwd + + decoder_kv_proj_flops_fwd + + decoder_qk_logits_flops_fwd + + decoder_v_logits_flops_fwd + + decoder_attn_out_flops_fwd + + decoder_ffn_1_flops_fwd * num_experts_per_tok + + decoder_ffn_2_flops_fwd * num_experts_per_tok + + decoder_ffn_router_flops_fwd + ) + + # lm head + lm_head_flops_fwd = 2 * batch_size * seq_len * (hidden_size) * vocab_size + + # the bwd pass requires double the flops in case of matmuls to calculate the gradients with respect to + # both input and weight tensors + model_flops = 3 * (decoder_flops_fwd + lm_head_flops_fwd) # 1 for fwd + 2 for bwd + + hardware_flops = model_flops # TODO @nouamanetazi: This is a placeholder for now + return model_flops, hardware_flops diff --git a/src/nanotron/models/moe.py b/src/nanotron/models/moe.py index f7bb07bd..8b25a5f3 100644 --- a/src/nanotron/models/moe.py +++ b/src/nanotron/models/moe.py @@ -157,14 +157,68 @@ def forward(self, hidden_states: torch.Tensor): """ # Compute the expert scores and assignments. # TODO: support sequence parallelism - batch_size, sequence_length, _ = hidden_states.size() + sequence_length, batch_size, _ = hidden_states.size() x = hidden_states.view(-1, self.config.hidden_size) router_logits, expert_weights, top_experts = self.gate(x) # Compute the experts. x, lbl_loss, z_loss = self.experts(x, router_logits, expert_weights, top_experts) return { - "hidden_states": x.reshape(batch_size, sequence_length, -1), + "hidden_states": x.reshape(sequence_length, batch_size, -1), + "load_balancing_loss": lbl_loss, + "z_loss": z_loss, + } + + +class dLangMoE(torch.nn.Module): + def __init__( + self, + config: Config, + parallel_context: "ParallelContext", + parallel_config: Optional[ParallelismArgs], + ): + super().__init__() + self.config = config + self.tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE + if self.tp_mode == TensorParallelLinearMode.REDUCE_SCATTER: + logging.warn_once( + logger=logger, + msg="TensorParallelLinearMode.REDUCE_SCATTER is still experimental for MoEs. Use at your own risk.", + rank=0, + ) + + # Token router. + self.gate = LearnedRouter(config, meta_dim=config.language_embedding_size) + + # Expert computation helper. + self.experts = ParallelDroplessMLP( + config, + use_bias=False, + parallel_context=parallel_context, + parallel_config=parallel_config, + ) + + def forward(self, hidden_states: torch.Tensor, lang_hidden_states: torch.Tensor): + """ + Args: + x: input tensor of shape [sequence_length, batch_size, hidden_size] + lang_hidden_states: input tensor of shape [1, batch_size, hidden_size] + """ + # Compute the expert scores and assignments. 
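+        # Language-aware gating: each token's hidden state is concatenated with its sample's
+        # language embedding before the router, so expert selection can condition on the language.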
+ # TODO: support sequence parallelism + sequence_length, batch_size, _ = hidden_states.size() + x = hidden_states.view(-1, self.config.hidden_size) + + # Repeat the language embedding to go to [batch_size * sequence_length, hidden_size] + lang_x = lang_hidden_states.repeat(sequence_length, 1, 1) + lang_x = lang_x.view(-1, self.config.language_embedding_size) + + router_logits, expert_weights, top_experts = self.gate(x, lang_x) + + # Compute the experts. + x, lbl_loss, z_loss = self.experts(x, router_logits, expert_weights, top_experts) + return { + "hidden_states": x.reshape(sequence_length, batch_size, -1), "load_balancing_loss": lbl_loss, "z_loss": z_loss, } @@ -172,12 +226,16 @@ def forward(self, hidden_states: torch.Tensor): # Adapted from megablocks.layers.router.LearnedRouter class LearnedRouter(torch.nn.Module): - def __init__(self, config: Config): + def __init__(self, config: Config, meta_dim: int = 0): super().__init__() - self.layer = torch.nn.Linear(config.hidden_size, config.moe_num_experts, bias=False) + self.layer = torch.nn.Linear(config.hidden_size + meta_dim, config.moe_num_experts, bias=False) self.config = config - def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def forward( + self, x: torch.Tensor, meta_x: Optional[torch.Tensor] + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if meta_x is not None: + x = torch.cat([x, meta_x], dim=-1) router_logits = self.layer(x) # (batch * sequence_length, n_experts) scores = F.softmax(router_logits, dim=-1, dtype=torch.float32) # TODO: fuse? @@ -652,6 +710,7 @@ def forward(self, hidden_states, topo): # [seq_length, batch_size, hidden_dim] hidden_states = self.w2(self.act(merged_states)) return hidden_states + class GLU(MLP): def __init__( self, @@ -676,11 +735,12 @@ def __init__( expert_parallel_size=self.expert_pg_size, ) - def forward(self, x, topo): + def forward(self, hidden_states, topo): merged_states = self.w1(hidden_states) hidden_states = self.w2(self.act(merged_states) * self.w3(hidden_states)) return hidden_states + def inclusive_cumsum(x, dim): scalar = ops.inclusive_cumsum(x, dim) return scalar.view(1) if not len(scalar.size()) else scalar @@ -718,4 +778,4 @@ def forward(self, x, topo): x1 = self.sdd(x, self.w1.module.weight, topo) x2 = self.sdd(x, self.w3.module.weight, topo) x = stk.ops.mul(act_fn(x1, self.act), x2) - return self.dsd(x, self.w2.module.weight) \ No newline at end of file + return self.dsd(x, self.w2.module.weight) diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 91b9a29b..1f5286bf 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -57,6 +57,7 @@ from nanotron.models import NanotronModel, build_model from nanotron.models.base import check_model_has_grad from nanotron.models.gpt3 import GPT3ForTraining +from nanotron.models.gpt3_langmoe import GPT3LangMoEForTraining from nanotron.models.gpt3_moe import GPT3MoEForTraining from nanotron.models.llama import LlamaForTraining, RotaryEmbedding from nanotron.models.starcoder2 import Starcoder2ForTraining @@ -108,6 +109,7 @@ "Starcoder2Config": Starcoder2ForTraining, "GPT3Config": GPT3ForTraining, "GPT3MoEConfig": GPT3MoEForTraining, + "GPT3LangMoEConfig": GPT3LangMoEForTraining, } try:
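
For reference, a minimal standalone sketch of the language-aware routing path implemented by dLangMoE and LearnedRouter above. The sizes are made up purely for illustration and this is not the nanotron code path itself, only the same shape manipulation and router input layout:

import torch
import torch.nn.functional as F

# Illustrative (hypothetical) sizes only.
seq_len, batch, hidden, lang_dim, n_experts, top_k = 8, 2, 16, 4, 4, 1

hidden_states = torch.randn(seq_len, batch, hidden)  # [seq_len, batch, hidden_size]
lang_emb = torch.randn(1, batch, lang_dim)           # [1, batch, language_embedding_size]

# Flatten tokens and broadcast the per-sample language embedding to every token,
# mirroring dLangMoE.forward (lang_emb.repeat(seq_len, 1, 1).view(-1, lang_dim)).
x = hidden_states.view(-1, hidden)                          # [seq_len * batch, hidden]
lang_x = lang_emb.repeat(seq_len, 1, 1).view(-1, lang_dim)  # [seq_len * batch, lang_dim]

# LearnedRouter with meta_dim: a single linear layer over the concatenated [token ; language] features.
router = torch.nn.Linear(hidden + lang_dim, n_experts, bias=False)
router_logits = router(torch.cat([x, lang_x], dim=-1))      # [seq_len * batch, n_experts]
scores = F.softmax(router_logits, dim=-1, dtype=torch.float32)
expert_weights, top_experts = torch.topk(scores, k=top_k, dim=-1)
print(expert_weights.shape, top_experts.shape)  # torch.Size([16, 1]) torch.Size([16, 1])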
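A hedged sanity check of the MoE FLOP accounting added in get_flops. It assumes this patch is applied and nanotron is importable; the numbers are the GPT3LangMoEConfig defaults, not measured results:

from nanotron.models.gpt3_langmoe import get_flops

# Dense baseline vs. 2 active experts per token out of 8: the attention terms are unchanged,
# the two FFN matmuls scale with num_experts_per_tok, and the router adds a small extra term.
dense_flops, _ = get_flops(
    num_layers=24, hidden_size=2048, num_heads=16, vocab_size=49280,
    seq_len=4096, num_experts=1, num_experts_per_tok=1,
)
moe_flops, _ = get_flops(
    num_layers=24, hidden_size=2048, num_heads=16, vocab_size=49280,
    seq_len=4096, num_experts=8, num_experts_per_tok=2,
)
print(dense_flops, moe_flops)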