From 0d921f79e5c1f73f5e4c9418afcf5de05196f1e4 Mon Sep 17 00:00:00 2001 From: lintangsutawika Date: Fri, 1 Dec 2023 03:02:02 +0000 Subject: [PATCH 1/5] changed ordering for setting up norm_factor --- megatron/model/transformer.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index 63f4122e2..eeb141fa1 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -295,14 +295,14 @@ def __init__( bias=neox_args.use_bias_in_attn_linear, ) - coeff = None - self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) - if self.apply_query_key_layer_scaling: - coeff = max(1, self.layer_number) - self.norm_factor *= coeff - if neox_args.use_mup: self.norm_factor = self.hidden_size_per_attention_head + else: + coeff = None + self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) + if self.apply_query_key_layer_scaling: + coeff = max(1, self.layer_number) + self.norm_factor *= coeff self.rpe = rpe @@ -956,6 +956,12 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, bias=Non else: logits_parallel = F.linear(input_parallel, word_embeddings_weight, bias) + + # if self.neox_args.use_mup: + # # Since we're using pipeline parallelism, we can't directly use MuReadout. Instead, use this workaround that does the same thing as MuReadout. + # # https://github.com/microsoft/mup/issues/6#issuecomment-1082156274 + # logits_parallel /= self.tied_modules.embed.word_embeddings.weight.infshape.width_mult() + # Gather if needed. if parallel_output: return logits_parallel From abee54daef5a0ca7e27a7f143ca8d93111dea54c Mon Sep 17 00:00:00 2001 From: github-actions Date: Fri, 1 Dec 2023 03:02:58 +0000 Subject: [PATCH 2/5] Update NeoXArgs docs automatically --- configs/neox_arguments.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/neox_arguments.md b/configs/neox_arguments.md index bc2e8fc57..aa7b72d29 100644 --- a/configs/neox_arguments.md +++ b/configs/neox_arguments.md @@ -111,7 +111,7 @@ Logging Arguments - **git_hash**: str - Default = 2da1083 + Default = 0d921f7 current git hash of repository From a08c3efbf1688e9e46ea654b2f0a8195a0ae404e Mon Sep 17 00:00:00 2001 From: lintangsutawika Date: Fri, 1 Dec 2023 03:49:53 +0000 Subject: [PATCH 3/5] updated muP args to the minimum required --- megatron/model/gpt2_model.py | 12 +++----- megatron/model/init_functions.py | 43 ++++++---------------------- megatron/model/transformer.py | 9 ++---- megatron/neox_arguments/neox_args.py | 33 +++++++-------------- 4 files changed, 26 insertions(+), 71 deletions(-) diff --git a/megatron/model/gpt2_model.py b/megatron/model/gpt2_model.py index 2725614cd..5fd70c49f 100644 --- a/megatron/model/gpt2_model.py +++ b/megatron/model/gpt2_model.py @@ -119,6 +119,9 @@ def __init__( self.init_method, self.output_layer_init_method = get_init_methods( self.neox_args ) + self.init_method, self.output_layer_init_method = get_init_methods( + self.neox_args + ) self.__topology__ = topology self.specs = [] @@ -268,16 +271,9 @@ def init_specs(self): def _logits_helper(embedding, lm_output): """Just a wrapper to massage inputs/outputs from pipeline.""" - if self.neox_args.use_mup: - # Since we're using pipeline parallelism, we can't directly use MuReadout. Instead, use this workaround that does the same thing as MuReadout. 
- # https://github.com/microsoft/mup/issues/6#issuecomment-1082156274 - lm_output = ( - lm_output - / self.tied_modules.embed.word_embeddings.weight.infshape.width_mult() - ) logits = parallel_lm_logits( - lm_output, embedding.word_embeddings_weight, self.parallel_output + lm_output, embedding.word_embeddings_weight, self.parallel_output, self.neox_args ) return logits diff --git a/megatron/model/init_functions.py b/megatron/model/init_functions.py index 11bcdc310..ff4c36b53 100644 --- a/megatron/model/init_functions.py +++ b/megatron/model/init_functions.py @@ -16,41 +16,22 @@ import torch -try: - import mup -except ImportError: - pass - -def init_method_normal(sigma, use_mup_outer=False, mup_init_scale=1.0): +def init_method_normal(sigma): """Init method based on N(0, sigma).""" - def init_(tensor, use_mup=use_mup_outer): - if use_mup: - mup.init.normal_(tensor, mean=0.0, std=sigma) - with torch.no_grad(): - tensor.mul_(mup_init_scale) - return tensor - else: - return torch.nn.init.normal_(tensor, mean=0.0, std=sigma) + def init_(tensor): + return torch.nn.init.normal_(tensor, mean=0.0, std=sigma) return init_ -def scaled_init_method_normal( - sigma, num_layers, use_mup_outer=False, mup_init_scale=1.0 -): +def scaled_init_method_normal(sigma, num_layers): """Init method based on N(0, sigma/sqrt(2*num_layers).""" std = sigma / math.sqrt(2.0 * num_layers) - def init_(tensor, use_mup=use_mup_outer): - if use_mup: - mup.init.normal_(tensor, mean=0.0, std=std) - with torch.no_grad(): - tensor.mul_(mup_init_scale) - return tensor - else: - return torch.nn.init.normal_(tensor, mean=0.0, std=std) + def init_(tensor): + return torch.nn.init.normal_(tensor, mean=0.0, std=std) return init_ @@ -169,21 +150,15 @@ def init_(tensor, use_mup=use_mup_outer): def get_init_methods(args): - if args.use_mup: - try: - import mup - except ModuleNotFoundError: - print("Please install mup https://github.com/microsoft/mup") - raise Exception - def _get(name): if name == "normal": return init_method_normal( - args.init_method_std, args.use_mup, args.mup_init_scale + sigma=args.init_method_std*args.mup_init_scale ) elif name == "scaled_normal": return scaled_init_method_normal( - args.init_method_std, args.num_layers, args.use_mup, args.mup_init_scale + sigma=args.init_method_std*args.mup_init_scale, + num_layers=args.num_layers ) elif name == "orthogonal": return orthogonal_init_method(args.use_mup, args.mup_init_scale) diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index eeb141fa1..0785561cb 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -945,7 +945,7 @@ def forward(self, args): return self.norm(args) -def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, bias=None): +def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, bias=None, args=None): """LM logits using word embedding weights.""" # Parallel logits. input_parallel = mpu.copy_to_model_parallel_region(input_) @@ -956,11 +956,8 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, bias=Non else: logits_parallel = F.linear(input_parallel, word_embeddings_weight, bias) - - # if self.neox_args.use_mup: - # # Since we're using pipeline parallelism, we can't directly use MuReadout. Instead, use this workaround that does the same thing as MuReadout. 
- # # https://github.com/microsoft/mup/issues/6#issuecomment-1082156274 - # logits_parallel /= self.tied_modules.embed.word_embeddings.weight.infshape.width_mult() + if args is not None and args.use_mup: + logits_parallel *= args.mup_output_logit_multiplier # Gather if needed. if parallel_output: diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py index 957960832..58780881b 100644 --- a/megatron/neox_arguments/neox_args.py +++ b/megatron/neox_arguments/neox_args.py @@ -263,6 +263,7 @@ class NeoXArgsModel(NeoXArgsTemplate): init_method_std: float = 0.02 """ Standard deviation of the zero mean normal distribution used for weight initialization. + When using muP this is the base std """ apply_query_key_layer_scaling: bool = False @@ -427,6 +428,7 @@ class NeoXArgsOptimizer(NeoXArgsTemplate): lr: float = None """ Max Learning rate during training + When using muP, this is the base learning rate """ @@ -1015,7 +1017,7 @@ class NeoXArgsTraining(NeoXArgsTemplate): use_mup: bool = False """ - Whether to use Microsoft's Mup https://github.com/microsoft/mup + Whether to use muP """ coord_check: bool = False @@ -1033,35 +1035,20 @@ class NeoXArgsTraining(NeoXArgsTemplate): Path to the base shapes to save to/load from """ - mup_init_scale: float = 1.0 + mup_emb: int = 1 """ - Initialization scale: All the parameters are multiplied by this value + Embedding output multiplier """ - mup_attn_temp: float = 1.0 + mup_m_width: int = 1 """ - Attention temperature: Reciprocal of the multiplier applied to the input to attention softmax + Manually set the layer width multiplier (d_model/d_model,base) """ - mup_output_temp: float = 1.0 + mup_d_model_base: int = 64 """ - Output temperature: Reciprocal of the multiplier applied to the input to softmax that - produces the distribution over output tokens. 
- """ - - mup_embedding_mult: float = 1.0 - """ - Scalar by which we multiply the output of the embedding layer - """ - - mup_rp_embedding_mult: float = 1.0 - """ - Scalar by which we multiply vectors representing relative position - """ - - mup_width_scale: int = 2 - """ - What to scale width by when creating the delta model for mup + d_model,base + Proxy (base) model's layer width """ From c35e8309a6f5b1e73f8d1dd888c23c481011b818 Mon Sep 17 00:00:00 2001 From: lintangsutawika Date: Fri, 1 Dec 2023 03:55:29 +0000 Subject: [PATCH 4/5] calculate m_width --- megatron/training.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/megatron/training.py b/megatron/training.py index ed9c0bcd0..0dea5ab17 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -439,11 +439,9 @@ def get_model(neox_args, use_cache=False): neox_args.use_mup = old_use_mup if neox_args.use_mup: - try: - import mup - except ModuleNotFoundError: - print("Please install mup https://github.com/microsoft/mup") - raise Exception + + if neox_args.mup_m_width == 1: + neox_args.mup_m_width = neox_args.hidden_size / neox_args.mup_d_model_base base_shapes = f"{neox_args.base_shapes_file}.{torch.distributed.get_rank()}" From 81fdc4d1f7b7558aa55c97ad9adc04cd2e7bf693 Mon Sep 17 00:00:00 2001 From: github-actions Date: Fri, 1 Dec 2023 09:30:44 +0000 Subject: [PATCH 5/5] Update NeoXArgs docs automatically --- configs/neox_arguments.md | 50 +++++++++++---------------------------- 1 file changed, 14 insertions(+), 36 deletions(-) diff --git a/configs/neox_arguments.md b/configs/neox_arguments.md index aa7b72d29..93c0328fe 100644 --- a/configs/neox_arguments.md +++ b/configs/neox_arguments.md @@ -111,7 +111,7 @@ Logging Arguments - **git_hash**: str - Default = 0d921f7 + Default = 2d127df current git hash of repository @@ -452,6 +452,7 @@ Model Arguments Default = 0.02 Standard deviation of the zero mean normal distribution used for weight initialization. + When using muP this is the base std @@ -663,6 +664,7 @@ Optimizer Arguments Default = None Max Learning rate during training + When using muP, this is the base learning rate @@ -1521,7 +1523,7 @@ Training Arguments Default = False - Whether to use Microsoft's Mup https://github.com/microsoft/mup + Whether to use muP @@ -1549,52 +1551,28 @@ Training Arguments -- **mup_init_scale**: float +- **mup_emb**: int - Default = 1.0 - - Initialization scale: All the parameters are multiplied by this value - - - -- **mup_attn_temp**: float - - Default = 1.0 - - Attention temperature: Reciprocal of the multiplier applied to the input to attention softmax - - - -- **mup_output_temp**: float - - Default = 1.0 - - Output temperature: Reciprocal of the multiplier applied to the input to softmax that - produces the distribution over output tokens. - - - -- **mup_embedding_mult**: float - - Default = 1.0 + Default = 1 - Scalar by which we multiply the output of the embedding layer + Embedding output multiplier -- **mup_rp_embedding_mult**: float +- **mup_m_width**: int - Default = 1.0 + Default = 1 - Scalar by which we multiply vectors representing relative position + Manually set the layer width multiplier (d_model/d_model,base) -- **mup_width_scale**: int +- **mup_d_model_base**: int - Default = 2 + Default = 64 - What to scale width by when creating the delta model for mup + d_model,base + Proxy (base) model's layer width
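
Taken together, the commits above reduce the muP configuration to three knobs (`mup_emb`, `mup_m_width`, `mup_d_model_base`) plus the base `init_method_std` and `lr`. The minimal sketch below, which is not part of the patch series, shows how those pieces would fit together at runtime. It is illustrative only: the config object, the helper names, and the choice of `1 / m_width` as the output-logit multiplier are assumptions, since the diff multiplies logits by `args.mup_output_logit_multiplier` without showing where that value is defined.

```python
# Illustrative sketch only; not code from the patch series. Names other than the
# NeoX arguments shown in the diffs above are assumptions.
import math
from types import SimpleNamespace

# Stand-in for NeoXArgs with the muP-related fields introduced above.
args = SimpleNamespace(
    use_mup=True,
    hidden_size=1024,        # d_model of the target (wide) model
    mup_d_model_base=64,     # d_model,base: proxy (base) model's width
    mup_m_width=1,           # width multiplier; recomputed below when left at 1
    mup_emb=1,               # embedding output multiplier
)

# Patch 4 (megatron/training.py): derive m_width = d_model / d_model,base
# when it has not been set manually.
if args.use_mup and args.mup_m_width == 1:
    args.mup_m_width = args.hidden_size / args.mup_d_model_base  # 16.0 here

# Patch 1 (megatron/model/transformer.py): under muP, attention scores are
# divided by d_head itself rather than sqrt(d_head).
def norm_factor(hidden_size_per_attention_head: int, use_mup: bool) -> float:
    if use_mup:
        return float(hidden_size_per_attention_head)
    return math.sqrt(hidden_size_per_attention_head)

# Patch 3 (parallel_lm_logits): logits are scaled by a muP output multiplier.
# Assumption: a 1 / m_width value, mirroring the MuReadout workaround cited in
# the removed comments (microsoft/mup issue #6, divide by width_mult); the diff
# itself only shows `logits_parallel *= args.mup_output_logit_multiplier`.
mup_output_logit_multiplier = 1.0 / args.mup_m_width

print(norm_factor(64, args.use_mup))   # 64.0 instead of 8.0
print(args.mup_m_width)                # 16.0
print(mup_output_logit_multiplier)     # 0.0625
```

Under this reading, `mup_emb` would multiply the embedding output in the same spirit, although none of the hunks shown here apply it yet.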