Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mamba fixes and cleaning #1262

Merged
merged 3 commits into from
Aug 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 67 additions & 1 deletion configs/mamba/mamba-1.4B.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,71 @@
"mamba_inner_func_fusion": true, # supersedes scan or conv fusion
"activation": "silu",

"output_layer_init_method": "single_residual_scaled_normal",
# init methods
"init_method": "small_init",
"output_layer_init_method": "single_residual_scaled_normal",

# optimizer settings
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.0002,
"betas": [0.9, 0.95],
"eps": 1.0e-8,
}
},
"min_lr": 0.00002,

# for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training
"zero_optimization": {
"stage": 1,
"allgather_partitions": True,
"allgather_bucket_size": 500000000,
"overlap_comm": True,
"reduce_scatter": True,
"reduce_bucket_size": 500000000,
"contiguous_gradients": True,
},

# batch / data settings
"train_micro_batch_size_per_gpu": 4,
"data_impl": "mmap",

# activation checkpointing
"checkpoint_activations": true,
"checkpoint_num_layers": 1,
"partition_activations": true,
"synchronize_each_layer": true,

# regularization
"gradient_clipping": 1.0,
"weight_decay": 0.1,
"hidden_dropout": 0,
"attention_dropout": 0,

# precision settings
"fp16": {
"fp16": true,
"enabled": true,
"loss_scale": 0,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},

# misc. training settings
"train_iters": 320000,
"lr_decay_iters": 320000,
"distributed_backend": "nccl",
"lr_decay_style": "cosine",
"warmup": 0.01,
"checkpoint_factor": 10000,
"eval_interval": 1000,
"eval_iters": 10,

# logging
"log_interval": 1,
"steps_per_print": 10,
"keep_last_n_checkpoints": 4,
"wall_clock_breakdown": true,
}
69 changes: 67 additions & 2 deletions configs/mamba/mamba-130M.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,70 @@
"mamba_inner_func_fusion": true, # supersedes scan or conv fusion
"activation": "silu",

"output_layer_init_method": "single_residual_scaled_normal",
}
# init methods
"init_method": "small_init",
"output_layer_init_method": "single_residual_scaled_normal",


# optimizer settings
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.0006,
"betas": [0.9, 0.95],
"eps": 1.0e-8,
}
},
"min_lr": 0.00006,

# for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training
"zero_optimization": {
"stage": 1,
"allgather_partitions": True,
"allgather_bucket_size": 500000000,
"overlap_comm": True,
"reduce_scatter": True,
"reduce_bucket_size": 500000000,
"contiguous_gradients": True,
},

# batch / data settings
"train_micro_batch_size_per_gpu": 4,
"data_impl": "mmap",

# activation checkpointing
"checkpoint_activations": true,
"checkpoint_num_layers": 1,
"partition_activations": true,
"synchronize_each_layer": true,

# regularization
"gradient_clipping": 1.0,
"weight_decay": 0.1,
"hidden_dropout": 0.0,
"attention_dropout": 0.0,

# precision settings
"fp16": {
"enabled": true,
"loss_scale": 0,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},

# misc. training settings
"train_iters": 320000,
"lr_decay_iters": 320000,
"distributed_backend": "nccl",
"lr_decay_style": "cosine",
"warmup": 0.01,
"checkpoint_factor": 10000,
"eval_interval": 1000,
"eval_iters": 10,

# logging
"log_interval": 100,
"steps_per_print": 10,
"keep_last_n_checkpoints": 4,
"wall_clock_breakdown": true,
68 changes: 67 additions & 1 deletion configs/mamba/mamba-2.8B.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,71 @@
"mamba_inner_func_fusion": true, # supersedes scan or conv fusion
"activation": "silu",

"output_layer_init_method": "single_residual_scaled_normal",
# init methods
"init_method": "small_init",
"output_layer_init_method": "single_residual_scaled_normal",

# optimizer settings
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.00016,
"betas": [0.9, 0.95],
"eps": 1.0e-8,
}
},
"min_lr": 0.000016,

# for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training
"zero_optimization": {
"stage": 1,
"allgather_partitions": True,
"allgather_bucket_size": 500000000,
"overlap_comm": True,
"reduce_scatter": True,
"reduce_bucket_size": 500000000,
"contiguous_gradients": True,
},

# batch / data settings
"train_micro_batch_size_per_gpu": 4,
"data_impl": "mmap",

# activation checkpointing
"checkpoint_activations": true,
"checkpoint_num_layers": 1,
"partition_activations": true,
"synchronize_each_layer": true,

# regularization
"gradient_clipping": 1.0,
"weight_decay": 0.1,
"hidden_dropout": 0,
"attention_dropout": 0,

# precision settings
"fp16": {
"fp16": true,
"enabled": true,
"loss_scale": 0,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},

# misc. training settings
"train_iters": 320000,
"lr_decay_iters": 320000,
"distributed_backend": "nccl",
"lr_decay_style": "cosine",
"warmup": 0.01,
"checkpoint_factor": 10000,
"eval_interval": 1000,
"eval_iters": 10,

# logging
"log_interval": 100,
"steps_per_print": 10,
"keep_last_n_checkpoints": 4,
"wall_clock_breakdown": true,
}
69 changes: 67 additions & 2 deletions configs/mamba/mamba-370M.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,77 @@
"norm": "rmsnorm",
"rms_norm_epsilon": 1.0e-5,

"attention_config": [[["mamba"], 64]],
"attention_config": [[["mamba"], 48]],

"mamba_selective_scan_fusion": true,
"mamba_causal_conv_fusion": true,
"mamba_inner_func_fusion": true, # supersedes scan or conv fusion
"activation": "silu",

"output_layer_init_method": "single_residual_scaled_normal",
# init methods
"init_method": "small_init",
"output_layer_init_method": "single_residual_scaled_normal",

# optimizer settings
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.0003,
"betas": [0.9, 0.95],
"eps": 1.0e-8,
}
},
"min_lr": 0.00003,

# for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training
"zero_optimization": {
"stage": 1,
"allgather_partitions": True,
"allgather_bucket_size": 500000000,
"overlap_comm": True,
"reduce_scatter": True,
"reduce_bucket_size": 500000000,
"contiguous_gradients": True,
},
# batch / data settings
"train_micro_batch_size_per_gpu": 4,
"data_impl": "mmap",

# activation checkpointing
"checkpoint_activations": true,
"checkpoint_num_layers": 1,
"partition_activations": true,
"synchronize_each_layer": true,

# regularization
"gradient_clipping": 1.0,
"weight_decay": 0.1,
"hidden_dropout": 0,
"attention_dropout": 0,

# precision settings
"fp16": {
"fp16": true,
"enabled": true,
"loss_scale": 0,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},

# misc. training settings
"train_iters": 320000,
"lr_decay_iters": 320000,
"distributed_backend": "nccl",
"lr_decay_style": "cosine",
"warmup": 0.01,
"checkpoint_factor": 10000,
"eval_interval": 1000,
"eval_iters": 10,

# logging
"log_interval": 100,
"steps_per_print": 10,
"keep_last_n_checkpoints": 4,
"wall_clock_breakdown": true,
}
70 changes: 68 additions & 2 deletions configs/mamba/mamba-790M.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,78 @@
"norm": "rmsnorm",
"rms_norm_epsilon": 1.0e-5,

"attention_config": [[["mamba"], 64]],
"attention_config": [[["mamba"], 48]],

"mamba_selective_scan_fusion": true,
"mamba_causal_conv_fusion": true,
"mamba_inner_func_fusion": true, # supersedes scan or conv fusion
"activation": "silu",

"output_layer_init_method": "single_residual_scaled_normal",
# init methods
"init_method": "small_init",
"output_layer_init_method": "single_residual_scaled_normal",

# optimizer settings
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.00025,
"betas": [0.9, 0.999],
"eps": 1.0e-8,
}
},
"min_lr": 0.000025,

# for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training
"zero_optimization": {
"stage": 1,
"allgather_partitions": True,
"allgather_bucket_size": 500000000,
"overlap_comm": True,
"reduce_scatter": True,
"reduce_bucket_size": 500000000,
"contiguous_gradients": True,
},

# batch / data settings
"train_micro_batch_size_per_gpu": 4,
"data_impl": "mmap",

# activation checkpointing
"checkpoint_activations": true,
"checkpoint_num_layers": 1,
"partition_activations": true,
"synchronize_each_layer": true,

# regularization
"gradient_clipping": 1.0,
"weight_decay": 0.1,
"hidden_dropout": 0,
"attention_dropout": 0,

# precision settings
"fp16": {
"fp16": true,
"enabled": true,
"loss_scale": 0,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},

# misc. training settings
"train_iters": 320000,
"lr_decay_iters": 320000,
"distributed_backend": "nccl",
"lr_decay_style": "cosine",
"warmup": 0.01,
"checkpoint_factor": 10000,
"eval_interval": 1000,
"eval_iters": 10,

# logging
"log_interval": 100,
"steps_per_print": 10,
"keep_last_n_checkpoints": 4,
"wall_clock_breakdown": true,
}
5 changes: 2 additions & 3 deletions megatron/model/mamba/mamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,8 @@
from causal_conv1d import causal_conv1d_fn
import einops
except ModuleNotFoundError:
print(
"Unable to import Mamba kernels. Install them from our requirements/requirements-mamba.txt, or directly from https://github.com/state-spaces/mamba"
)
print( "Unable to import Mamba kernels. Install them from our requirements/requirements-mamba.txt, \
or directly from https://github.com/state-spaces/mamba")
pass

from megatron.model.norms import get_norm
Expand Down
Loading
Loading