From adc4aeccc93831509dc910d5ea580898616b4fda Mon Sep 17 00:00:00 2001
From: jahatef
Date: Tue, 20 Aug 2024 16:57:20 +0000
Subject: [PATCH] mamba fixes and cleaning

---
 configs/mamba/mamba-1.4B.yml         | 68 ++++++++++++++++++++++++++-
 configs/mamba/mamba-130M.yml         | 69 ++++++++++++++++++++++++++-
 configs/mamba/mamba-2.8B.yml         | 68 ++++++++++++++++++++++++++-
 configs/mamba/mamba-370M.yml         | 69 ++++++++++++++++++++++++++-
 configs/mamba/mamba-790M.yml         | 70 +++++++++++++++++++++++++++-
 megatron/model/mamba/mamba.py        |  7 +--
 megatron/neox_arguments/arguments.py |  2 +-
 7 files changed, 339 insertions(+), 14 deletions(-)

diff --git a/configs/mamba/mamba-1.4B.yml b/configs/mamba/mamba-1.4B.yml
index 2898a72fd..eae467d0e 100644
--- a/configs/mamba/mamba-1.4B.yml
+++ b/configs/mamba/mamba-1.4B.yml
@@ -19,5 +19,71 @@
   "mamba_inner_func_fusion": true, # supersedes scan or conv fusion
   "activation": "silu",
 
-  "output_layer_init_method": "single_residual_scaled_normal",
+  # init methods
+  "init_method": "small_init",
+  "output_layer_init_method": "single_residual_scaled_normal",
+
+  # optimizer settings
+  "optimizer": {
+    "type": "Adam",
+    "params": {
+      "lr": 0.0002,
+      "betas": [0.9, 0.95],
+      "eps": 1.0e-8,
+    }
+  },
+  "min_lr": 0.00002,
+
+  # for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training
+  "zero_optimization": {
+    "stage": 1,
+    "allgather_partitions": True,
+    "allgather_bucket_size": 500000000,
+    "overlap_comm": True,
+    "reduce_scatter": True,
+    "reduce_bucket_size": 500000000,
+    "contiguous_gradients": True,
+  },
+
+  # batch / data settings
+  "train_micro_batch_size_per_gpu": 4,
+  "data_impl": "mmap",
+
+  # activation checkpointing
+  "checkpoint_activations": true,
+  "checkpoint_num_layers": 1,
+  "partition_activations": true,
+  "synchronize_each_layer": true,
+
+  # regularization
+  "gradient_clipping": 1.0,
+  "weight_decay": 0.1,
+  "hidden_dropout": 0,
+  "attention_dropout": 0,
+
+  # precision settings
+  "fp16": {
+    "fp16": true,
+    "enabled": true,
+    "loss_scale": 0,
+    "loss_scale_window": 1000,
+    "hysteresis": 2,
+    "min_loss_scale": 1
+  },
+
+  # misc. training settings
+  "train_iters": 320000,
+  "lr_decay_iters": 320000,
+  "distributed_backend": "nccl",
+  "lr_decay_style": "cosine",
+  "warmup": 0.01,
+  "checkpoint_factor": 10000,
+  "eval_interval": 1000,
+  "eval_iters": 10,
+
+  # logging
+  "log_interval": 1,
+  "steps_per_print": 10,
+  "keep_last_n_checkpoints": 4,
+  "wall_clock_breakdown": true,
 }
diff --git a/configs/mamba/mamba-130M.yml b/configs/mamba/mamba-130M.yml
index d9a6ab92e..7187048e6 100644
--- a/configs/mamba/mamba-130M.yml
+++ b/configs/mamba/mamba-130M.yml
@@ -19,5 +19,70 @@
   "mamba_inner_func_fusion": true, # supersedes scan or conv fusion
   "activation": "silu",
 
-  "output_layer_init_method": "single_residual_scaled_normal",
-}
+  # init methods
+  "init_method": "small_init",
+  "output_layer_init_method": "single_residual_scaled_normal",
+
+
+  # optimizer settings
+  "optimizer": {
+    "type": "Adam",
+    "params": {
+      "lr": 0.0006,
+      "betas": [0.9, 0.95],
+      "eps": 1.0e-8,
+    }
+  },
+  "min_lr": 0.00006,
+
+  # for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training
+  "zero_optimization": {
+    "stage": 1,
+    "allgather_partitions": True,
+    "allgather_bucket_size": 500000000,
+    "overlap_comm": True,
+    "reduce_scatter": True,
+    "reduce_bucket_size": 500000000,
+    "contiguous_gradients": True,
+  },
+
+  # batch / data settings
+  "train_micro_batch_size_per_gpu": 4,
+  "data_impl": "mmap",
+
+  # activation checkpointing
+  "checkpoint_activations": true,
+  "checkpoint_num_layers": 1,
+  "partition_activations": true,
+  "synchronize_each_layer": true,
+
+  # regularization
+  "gradient_clipping": 1.0,
+  "weight_decay": 0.1,
+  "hidden_dropout": 0.0,
+  "attention_dropout": 0.0,
+
+  # precision settings
+  "fp16": {
+    "enabled": true,
+    "loss_scale": 0,
+    "loss_scale_window": 1000,
+    "hysteresis": 2,
+    "min_loss_scale": 1
+  },
+
+  # misc. training settings
+  "train_iters": 320000,
+  "lr_decay_iters": 320000,
+  "distributed_backend": "nccl",
+  "lr_decay_style": "cosine",
+  "warmup": 0.01,
+  "checkpoint_factor": 10000,
+  "eval_interval": 1000,
+  "eval_iters": 10,
+
+  # logging
+  "log_interval": 100,
+  "steps_per_print": 10,
+  "keep_last_n_checkpoints": 4,
+  "wall_clock_breakdown": true,
diff --git a/configs/mamba/mamba-2.8B.yml b/configs/mamba/mamba-2.8B.yml
index 1aacb264b..d5afef368 100644
--- a/configs/mamba/mamba-2.8B.yml
+++ b/configs/mamba/mamba-2.8B.yml
@@ -19,5 +19,71 @@
   "mamba_inner_func_fusion": true, # supersedes scan or conv fusion
   "activation": "silu",
 
-  "output_layer_init_method": "single_residual_scaled_normal",
+  # init methods
+  "init_method": "small_init",
+  "output_layer_init_method": "single_residual_scaled_normal",
+
+  # optimizer settings
+  "optimizer": {
+    "type": "Adam",
+    "params": {
+      "lr": 0.00016,
+      "betas": [0.9, 0.95],
+      "eps": 1.0e-8,
+    }
+  },
+  "min_lr": 0.000016,
+
+  # for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training
+  "zero_optimization": {
+    "stage": 1,
+    "allgather_partitions": True,
+    "allgather_bucket_size": 500000000,
+    "overlap_comm": True,
+    "reduce_scatter": True,
+    "reduce_bucket_size": 500000000,
+    "contiguous_gradients": True,
+  },
+
+  # batch / data settings
+  "train_micro_batch_size_per_gpu": 4,
+  "data_impl": "mmap",
+
+  # activation checkpointing
+  "checkpoint_activations": true,
+  "checkpoint_num_layers": 1,
+  "partition_activations": true,
+  "synchronize_each_layer": true,
+
+  # regularization
+  "gradient_clipping": 1.0,
+  "weight_decay": 0.1,
+  "hidden_dropout": 0,
+  "attention_dropout": 0,
+
+  # precision settings
+  "fp16": {
+    "fp16": true,
+    "enabled": true,
+    "loss_scale": 0,
+    "loss_scale_window": 1000,
+    "hysteresis": 2,
+    "min_loss_scale": 1
+  },
+
+  # misc. training settings
+  "train_iters": 320000,
+  "lr_decay_iters": 320000,
+  "distributed_backend": "nccl",
+  "lr_decay_style": "cosine",
+  "warmup": 0.01,
+  "checkpoint_factor": 10000,
+  "eval_interval": 1000,
+  "eval_iters": 10,
+
+  # logging
+  "log_interval": 100,
+  "steps_per_print": 10,
+  "keep_last_n_checkpoints": 4,
+  "wall_clock_breakdown": true,
 }
diff --git a/configs/mamba/mamba-370M.yml b/configs/mamba/mamba-370M.yml
index 5e5a78cca..0058f1c0e 100644
--- a/configs/mamba/mamba-370M.yml
+++ b/configs/mamba/mamba-370M.yml
@@ -12,12 +12,77 @@
   "norm": "rmsnorm",
   "rms_norm_epsilon": 1.0e-5,
 
-  "attention_config": [[["mamba"], 64]],
+  "attention_config": [[["mamba"], 48]],
 
   "mamba_selective_scan_fusion": true,
   "mamba_causal_conv_fusion": true,
   "mamba_inner_func_fusion": true, # supersedes scan or conv fusion
   "activation": "silu",
 
-  "output_layer_init_method": "single_residual_scaled_normal",
+  # init methods
+  "init_method": "small_init",
+  "output_layer_init_method": "single_residual_scaled_normal",
+
+  # optimizer settings
+  "optimizer": {
+    "type": "Adam",
+    "params": {
+      "lr": 0.0003,
+      "betas": [0.9, 0.95],
+      "eps": 1.0e-8,
+    }
+  },
+  "min_lr": 0.00003,
+
+  # for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training
+  "zero_optimization": {
+    "stage": 1,
+    "allgather_partitions": True,
+    "allgather_bucket_size": 500000000,
+    "overlap_comm": True,
+    "reduce_scatter": True,
+    "reduce_bucket_size": 500000000,
+    "contiguous_gradients": True,
+  },
+  # batch / data settings
+  "train_micro_batch_size_per_gpu": 4,
+  "data_impl": "mmap",
+
+  # activation checkpointing
+  "checkpoint_activations": true,
+  "checkpoint_num_layers": 1,
+  "partition_activations": true,
+  "synchronize_each_layer": true,
+
+  # regularization
+  "gradient_clipping": 1.0,
+  "weight_decay": 0.1,
+  "hidden_dropout": 0,
+  "attention_dropout": 0,
+
+  # precision settings
+  "fp16": {
+    "fp16": true,
+    "enabled": true,
+    "loss_scale": 0,
+    "loss_scale_window": 1000,
+    "hysteresis": 2,
+    "min_loss_scale": 1
+  },
+
+  # misc. training settings
+  "train_iters": 320000,
+  "lr_decay_iters": 320000,
+  "distributed_backend": "nccl",
+  "lr_decay_style": "cosine",
+  "warmup": 0.01,
+  "checkpoint_factor": 10000,
+  "eval_interval": 1000,
+  "eval_iters": 10,
+
+  # logging
+  "log_interval": 100,
+  "steps_per_print": 10,
+  "keep_last_n_checkpoints": 4,
+  "wall_clock_breakdown": true,
 }
diff --git a/configs/mamba/mamba-790M.yml b/configs/mamba/mamba-790M.yml
index fcd324d9d..4aef7e813 100644
--- a/configs/mamba/mamba-790M.yml
+++ b/configs/mamba/mamba-790M.yml
@@ -12,12 +12,78 @@
   "norm": "rmsnorm",
   "rms_norm_epsilon": 1.0e-5,
 
-  "attention_config": [[["mamba"], 64]],
+  "attention_config": [[["mamba"], 48]],
 
   "mamba_selective_scan_fusion": true,
   "mamba_causal_conv_fusion": true,
   "mamba_inner_func_fusion": true, # supersedes scan or conv fusion
   "activation": "silu",
 
-  "output_layer_init_method": "single_residual_scaled_normal",
+  # init methods
+  "init_method": "small_init",
+  "output_layer_init_method": "single_residual_scaled_normal",
+
+  # optimizer settings
+  "optimizer": {
+    "type": "Adam",
+    "params": {
+      "lr": 0.00025,
+      "betas": [0.9, 0.999],
+      "eps": 1.0e-8,
+    }
+  },
+  "min_lr": 0.000025,
+
+  # for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training
+  "zero_optimization": {
+    "stage": 1,
+    "allgather_partitions": True,
+    "allgather_bucket_size": 500000000,
+    "overlap_comm": True,
+    "reduce_scatter": True,
+    "reduce_bucket_size": 500000000,
+    "contiguous_gradients": True,
+  },
+
+  # batch / data settings
+  "train_micro_batch_size_per_gpu": 4,
+  "data_impl": "mmap",
+
+  # activation checkpointing
+  "checkpoint_activations": true,
+  "checkpoint_num_layers": 1,
+  "partition_activations": true,
+  "synchronize_each_layer": true,
+
+  # regularization
+  "gradient_clipping": 1.0,
+  "weight_decay": 0.1,
+  "hidden_dropout": 0,
+  "attention_dropout": 0,
+
+  # precision settings
+  "fp16": {
+    "fp16": true,
+    "enabled": true,
+    "loss_scale": 0,
+    "loss_scale_window": 1000,
+    "hysteresis": 2,
+    "min_loss_scale": 1
+  },
+
+  # misc. training settings
+  "train_iters": 320000,
+  "lr_decay_iters": 320000,
+  "distributed_backend": "nccl",
+  "lr_decay_style": "cosine",
+  "warmup": 0.01,
+  "checkpoint_factor": 10000,
+  "eval_interval": 1000,
+  "eval_iters": 10,
+
+  # logging
+  "log_interval": 100,
+  "steps_per_print": 10,
+  "keep_last_n_checkpoints": 4,
+  "wall_clock_breakdown": true,
 }
diff --git a/megatron/model/mamba/mamba.py b/megatron/model/mamba/mamba.py
index d5d6b336f..3d8243ab2 100644
--- a/megatron/model/mamba/mamba.py
+++ b/megatron/model/mamba/mamba.py
@@ -3,7 +3,6 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-
 try:
     from mamba_ssm.ops.selective_scan_interface import (
         selective_scan_ref,
@@ -13,10 +12,8 @@
     from causal_conv1d import causal_conv1d_fn
     import einops
 except ModuleNotFoundError:
-    print(
-        "Unable to import Mamba kernels. Install them from our requirements/requirements-mamba.txt, or directly from https://github.com/state-spaces/mamba"
-    )
-    pass
+    assert False, "Unable to import Mamba kernels. Install them from our requirements/requirements-mamba.txt, \
+        or directly from https://github.com/state-spaces/mamba"
 
 from megatron.model.norms import get_norm
 from megatron import mpu
diff --git a/megatron/neox_arguments/arguments.py b/megatron/neox_arguments/arguments.py
index 054689eda..bccb02910 100644
--- a/megatron/neox_arguments/arguments.py
+++ b/megatron/neox_arguments/arguments.py
@@ -1185,7 +1185,7 @@ def validate_values(self):
             return False
 
         # Checks.
-        if self.hidden_size % self.num_attention_heads != 0:
+        if self.hidden_size % self.num_attention_heads != 0 and not ("mamba" in self.attention_config):
             error_message = (
                 self.__class__.__name__
                 + ".validate_values() hidden_size must be divisible by num_attention_heads"
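
Note (illustrative, not part of the patch): the arguments.py hunk above relaxes the hidden_size / num_attention_heads divisibility check whenever "mamba" appears in the attention config, presumably because Mamba blocks do not split hidden_size across attention heads. A minimal standalone sketch of the new rule, using made-up example values rather than anything from the repo:

def hidden_size_check_passes(hidden_size, num_attention_heads, attention_config):
    # Mirror of the patched condition: divisibility is only enforced when no
    # "mamba" block is present in the (flattened) attention config.
    if hidden_size % num_attention_heads != 0 and not ("mamba" in attention_config):
        return False
    return True

print(hidden_size_check_passes(768, 12, ["mamba"]))    # True: divisible anyway
print(hidden_size_check_passes(2560, 48, ["mamba"]))   # True: not divisible, but mamba skips the check
print(hidden_size_check_passes(2560, 48, ["global"]))  # False: attention layers still need divisibility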