Add back some codes about hetero training and recomputation (#141)
Co-authored-by: zhaoyinglia <[email protected]>
aoyulong and zhaoyinglia authored Jun 12, 2024
1 parent d05ca66 commit 3540eb4
Showing 13 changed files with 431 additions and 108 deletions.
26 changes: 26 additions & 0 deletions examples/aquila/conf/config_hetero.yaml
@@ -0,0 +1,26 @@
defaults:
  - train: demo_hetero
  - _self_

experiment:
  exp_name: aquila2
  exp_dir: ./outputs
  task:
    type: train
    backend: megatron
    entrypoint: ./flagscale/train/train_aquila.py
  runner:
    hostfile: xxxx # Please replace with your actual hostfile path
    rdzv_backend: "static" # hetero training only supports static
    envs:
      CUDA_VISIBLE_DEVICES: 0,1,2,3,4,5,6,7
      CUDA_DEVICE_MAX_CONNECTIONS: 1
  cmds:
    before_start: "ulimit -n 1048576"
    after_stop: ""

action: run

hydra:
  run:
    dir: ${experiment.exp_dir}/hydra
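Assuming FlagScale's usual Hydra-based launcher applies (worth verifying against the repository README), a config like this would be selected with something like `python run.py --config-path ./examples/aquila/conf --config-name config_hetero`.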
80 changes: 80 additions & 0 deletions examples/aquila/conf/train/demo_hetero.yaml
@@ -0,0 +1,80 @@
system:
  tensor_model_parallel_size: 2
  pipeline_model_parallel_size: 4
  disable_bias_linear: True
  use_flash_attn: True
  sequence_parallel: True
  use_distributed_optimizer: True
  use_mcore_models: true
  transformer_impl: transformer_engine
  hetero:
    hetero_mode: "pp"
    hetero_pipeline_stages: [4, 2, 2, 4, 4]
  recompute:
    recompute_granularity: "full"
    recompute_method: "uniform"
    recompute_num_layers: 1
    recompute_granularity_per_stage: [1, 0, 2, 1, 1, 1]
    recompute_method_per_stage: [1, 0, 2, 0, 1, 1]
    recompute_num_layers_per_stage: [1, 2, 2, 1, 1, 2]
  precision:
    bf16: True
    attention_softmax_in_fp32: True
    accumulate_allreduce_grads_in_fp32: True
  logging:
    log_interval: 1
    log_throughput: true
    tensorboard_log_interval: 1
    wandb_project: "aquila2"
    wandb_exp_name: "test"
  checkpoint:
    save_interval: 1000

model:
  num_layers: 12
  hidden_size: 4096
  num_attention_heads: 32
  seq_length: 2048
  max_position_embeddings: 2048
  norm_epsilon: 1e-5
  use_rotary_position_embeddings: true
  no_position_embedding: true
  swiglu: true
  multiple_of: 256
  normalization: RMSNorm
  rotary_interleaved_patch: true
  untie_embeddings_and_output_weights: true
  init_method_std: 0.0165
  attention_dropout: 0.0
  hidden_dropout: 0.0
  weight_decay: 0.1
  clip_grad: 1.0
  train_samples: 100000
  global_batch_size: 32
  micro_batch_size: 1
  # rampup_batch_size: [32, 32, 2000000]
  seed: 42

  optimizer:
    lr: 2e-4
    weight_decay: 0.01
    adam_beta1: 0.9
    adam_beta2: 0.95
    lr_scheduler:
      lr: 1.5e-4
      min_lr: 1.5e-5
      lr_warmup_samples: 500
      lr_decay_style: cosine

data:
  data_path: xxxx # Please replace with your actual data path
  split: 1
  tokenizer:
    tokenizer_type: xxxx # Please replace with your actual tokenizer type
    tokenizer_path: xxxx # Please replace with your actual tokenizer path
    vocab_file: null
    merge_file: null
    special_tokens_file: null
    vocab_size: xxxx # Please replace with your actual vocab size
    make_vocab_size_divisible_by: 64
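One plausible reading of `hetero_pipeline_stages: [4, 2, 2, 4, 4]` is a run-length layout: a leading stage count for each device type, followed by that many per-stage layer counts. The sketch below checks that reading against `pipeline_model_parallel_size: 4` and `num_layers: 12` above; the decoding itself is an assumption for illustration, not FlagScale's actual parser.

# Hypothetical decoding of hetero_pipeline_stages (assumed format:
# [num_stages, layers_0, ..., num_stages, layers_0, ...] per device type).
raw = [4, 2, 2, 4, 4]

stages, i = [], 0
while i < len(raw):
    n = raw[i]                             # stages for this device type
    stages.append(raw[i + 1 : i + 1 + n])  # layer counts for those stages
    i += 1 + n

print(stages)                                    # [[2, 2, 4, 4]]
assert sum(len(s) for s in stages) == 4          # pipeline_model_parallel_size
assert sum(x for s in stages for x in s) == 12   # model.num_layers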
2 changes: 1 addition & 1 deletion flagscale/train/global_vars.py
@@ -10,7 +10,7 @@ def get_extra_valid_datasets():


 def set_extra_valid_datasets(extra_valid_datasets):
-    """Initialize heterogenous context."""
+    """Set extra_valid datasets."""
     global _GLOBAL_EXTRA_VALID_DATASETS
     _GLOBAL_EXTRA_VALID_DATASETS = extra_valid_datasets

14 changes: 7 additions & 7 deletions flagscale/train/hetero/initialize.py
@@ -133,13 +133,13 @@ def _initialize_distributed():
         timeout=timedelta(minutes=args.distributed_timeout_minutes),
     )

-    if args.num_process_meshes == None:
-        if args.hetero_mode is not None:
-            # Build the heterogenous context after torch.distributed is initialized and
-            # before model parallel is initialized.
-            set_hetero_context(args)
-            if torch.distributed.get_rank() == 0:
-                print(get_hetero_context(), flush=True)
+    # if args.num_process_meshes == None:
+    #     if args.hetero_mode is not None:
+    #         # Build the heterogenous context after torch.distributed is initialized and
+    #         # before model parallel is initialized.
+    #         set_hetero_context(args)
+    #         if torch.distributed.get_rank() == 0:
+    #             print(get_hetero_context(), flush=True)

     # Set the tensor model-parallel, pipeline model-parallel, and
     # data-parallel communicators.
13 changes: 13 additions & 0 deletions megatron/megatron/core/model_parallel_config.py
@@ -272,6 +272,19 @@ class ModelParallelConfig:
     the user adds a level 1 timer that is not called by all ranks.
     """

+    ###################
+    # Heterogeneous Training
+    ###################
+    hetero_mode: str = None
+    """Specifies the mode of heterogeneous training. Currently only 'pp' is supported."""
+
+    hetero_pipeline_stages: list = None
+    """Defines the pipeline stages for different device types. Each element represents the number of pipeline stages for one device type."""
+
+    hetero_pipeline_stage_splits: list = None
+    """A list of lists; each sublist contains the numbers of layers to be processed in the corresponding pipeline stages for one device type."""
+
+
     def __post_init__(self):
         """ Python dataclass method that is used to modify attributes after initialization.
         See https://docs.python.org/3/library/dataclasses.html#post-init-processing for more details.
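For reference, a minimal sketch of populating the new fields directly (the values and the two-device-type split are assumed for illustration; `pipeline_dtype` is required by `__post_init__` whenever pipeline parallelism is enabled):

import torch
from megatron.core.model_parallel_config import ModelParallelConfig

# Assumed example: 4 pipeline stages over two device types, 12 layers total.
config = ModelParallelConfig(
    tensor_model_parallel_size=2,
    pipeline_model_parallel_size=4,
    pipeline_dtype=torch.bfloat16,            # required when pp size > 1
    hetero_mode="pp",
    hetero_pipeline_stages=[[2, 2], [4, 4]],  # per-stage layer counts, grouped by device type
    hetero_pipeline_stage_splits=[2, 2],      # assumed semantics: stages per device type
)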
67 changes: 66 additions & 1 deletion megatron/megatron/core/transformer/transformer_block.py
@@ -62,7 +62,14 @@ def get_num_layers_to_build(config: TransformerConfig) -> int:
     # Non-interleaved pipeline parallelism:
     # Each stage gets a contiguous set of layers.

-    num_layers_to_build = num_layers_per_pipeline_rank
+    if config.hetero_mode == "pp":
+        pipeline_rank = parallel_state.get_pipeline_model_parallel_rank()
+        pipeline_stages = [
+            item for sublist in config.hetero_pipeline_stages for item in sublist
+        ]
+        num_layers_to_build = pipeline_stages[pipeline_rank]
+    else:
+        num_layers_to_build = num_layers_per_pipeline_rank

     return num_layers_to_build
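As a standalone illustration (stage layout assumed), the lookup above amounts to flattening the per-device-type sublists and indexing by pipeline rank:

hetero_pipeline_stages = [[2, 2], [4, 4]]  # assumed parsed form: two device types
pipeline_stages = [item for sublist in hetero_pipeline_stages for item in sublist]
# pipeline_stages == [2, 2, 4, 4]
for pipeline_rank in range(len(pipeline_stages)):
    print(f"pipeline rank {pipeline_rank} builds {pipeline_stages[pipeline_rank]} layers")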

@@ -253,6 +260,64 @@ def checkpoint_handler(forward_func):
                     packed_seq_params,
                 )

+        if self.config.recompute_method_per_stage != None:
+            if self.config.virtual_pipeline_model_parallel_size != None:
+                if (
+                    self.config.recompute_method_per_stage[
+                        parallel_state.get_virtual_pipeline_model_parallel_rank()
+                        * self.config.pipeline_model_parallel_size
+                        + parallel_state.get_pipeline_model_parallel_rank()
+                    ]
+                    == 0
+                ):
+                    self.config.recompute_method = 'uniform'
+                elif (
+                    self.config.recompute_method_per_stage[
+                        parallel_state.get_virtual_pipeline_model_parallel_rank()
+                        * self.config.pipeline_model_parallel_size
+                        + parallel_state.get_pipeline_model_parallel_rank()
+                    ]
+                    == 1
+                ):
+                    self.config.recompute_method = 'block'
+            else:
+                if (
+                    self.config.recompute_method_per_stage[
+                        parallel_state.get_pipeline_model_parallel_rank()
+                    ]
+                    == 0
+                ):
+                    self.config.recompute_method = 'uniform'
+                elif (
+                    self.config.recompute_method_per_stage[
+                        parallel_state.get_pipeline_model_parallel_rank()
+                    ]
+                    == 1
+                ):
+                    self.config.recompute_method = 'block'
+
+        if self.config.recompute_num_layers_per_stage != None:
+            if self.config.virtual_pipeline_model_parallel_size != None:
+                self.config.recompute_num_layers = self.config.recompute_num_layers_per_stage[
+                    parallel_state.get_virtual_pipeline_model_parallel_rank()
+                    * self.config.pipeline_model_parallel_size
+                    + parallel_state.get_pipeline_model_parallel_rank()
+                ]
+            else:
+                self.config.recompute_num_layers = self.config.recompute_num_layers_per_stage[
+                    parallel_state.get_pipeline_model_parallel_rank()
+                ]
+
+        if (
+            self.config.recompute_granularity_per_stage != None
+            and self.config.recompute_granularity_per_stage[
+                parallel_state.get_pipeline_model_parallel_rank()
+            ]
+            == 0
+        ):
+            self.recompute_granularity = None
+            self.recompute_method = None
+
         if self.config.recompute_method == 'uniform':
             # Uniformly divide the total number of Transformer layers and checkpoint
             # the input activation of each divided chunk.
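When interleaved (virtual) pipeline parallelism is active, the per-stage lists above are indexed with `vpp_rank * pipeline_model_parallel_size + pp_rank`; a minimal sketch with assumed sizes:

pipeline_model_parallel_size = 4
virtual_pipeline_model_parallel_size = 2
# Assumed: one entry per (vpp_rank, pp_rank) pair, already expanded.
recompute_num_layers_per_stage = [1, 2, 2, 1, 1, 2, 0, 1]

for vpp_rank in range(virtual_pipeline_model_parallel_size):
    for pp_rank in range(pipeline_model_parallel_size):
        idx = vpp_rank * pipeline_model_parallel_size + pp_rank
        print(f"vpp {vpp_rank}, pp {pp_rank}: recompute {recompute_num_layers_per_stage[idx]} layers")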
10 changes: 9 additions & 1 deletion megatron/megatron/core/transformer/transformer_config.py
@@ -158,7 +158,6 @@ class TransformerConfig(ModelParallelConfig):
     # activation recomputation
     ####################
     recompute_granularity: str = None
-    recompute_granularity: str = None
     """Determines which type of activation recompute to use. Megatron-core supports 'selective'
     activation checkpointing where only the memory intensive part of attention is checkpointed.
     These memory intensive activations are also less compute intensive which makes activation

@@ -182,6 +181,15 @@
     the number of transformer layers to recompute within each pipeline stage. Must be None for
     'selective' activation checkpointing."""

+    recompute_granularity_per_stage: list = None
+    """Same as recompute_granularity but for each stage."""
+
+    recompute_method_per_stage: list = None
+    """Same as recompute_method but for each stage."""
+
+    recompute_num_layers_per_stage: list = None
+    """Same as recompute_num_layers but for each stage."""
+
     distribute_saved_activations: bool = None
     """If True, distribute recomputed activations across the model parallel group."""

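A minimal construction sketch for the new fields (all values assumed; when `recompute_granularity='full'`, `__post_init__` also requires `recompute_method` and `recompute_num_layers`):

import torch
from megatron.core.transformer.transformer_config import TransformerConfig

# Assumed values for a 12-layer model with a 4-stage pipeline.
config = TransformerConfig(
    num_layers=12,
    hidden_size=4096,
    num_attention_heads=32,
    pipeline_model_parallel_size=4,
    pipeline_dtype=torch.bfloat16,
    recompute_granularity='full',
    recompute_method='uniform',
    recompute_num_layers=1,
    recompute_granularity_per_stage=[1, 1, 0, 1],  # assumed: 0 disables recompute on stage 2
    recompute_method_per_stage=[0, 0, 1, 1],       # assumed: 0 -> 'uniform', 1 -> 'block'
    recompute_num_layers_per_stage=[1, 1, 1, 2],
)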
8 changes: 7 additions & 1 deletion megatron/megatron/core/transformer/transformer_layer.py
@@ -150,7 +150,13 @@ def _get_layer_offset(self):
         else:
             # Each stage gets a contiguous set of layers.
             if parallel_state.get_pipeline_model_parallel_world_size() > 1:
-                offset = pipeline_rank * num_layers_per_pipeline_rank
+                if self.config.hetero_mode == "pp":
+                    pipeline_stages = [
+                        item for sublist in self.config.hetero_pipeline_stages for item in sublist
+                    ]
+                    offset = sum(([0] + pipeline_stages)[: pipeline_rank + 1])
+                else:
+                    offset = pipeline_rank * num_layers_per_pipeline_rank
             else:
                 offset = 0

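The prefix-sum expression above gives each stage's starting layer index; a worked example under an assumed stage layout:

pipeline_stages = [2, 2, 4, 4]  # assumed flattened per-stage layer counts
for pipeline_rank in range(len(pipeline_stages)):
    offset = sum(([0] + pipeline_stages)[: pipeline_rank + 1])
    print(f"rank {pipeline_rank}: layers start at offset {offset}")
# rank 0 -> 0, rank 1 -> 2, rank 2 -> 4, rank 3 -> 8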
65 changes: 57 additions & 8 deletions megatron/megatron/legacy/model/transformer.py
@@ -1426,10 +1426,10 @@ def _get_num_layers(args, model_type, is_decoder=False):
     return num_layers


-def _get_layer_info(args):
-    assert args.hetero_mode == "pp", "Only pipeline parallelism is supported."
+def _get_layer_info(config):
+    assert config.hetero_mode == "pp", "Only pipeline parallelism is supported."
     pipeline_rank = mpu.get_pipeline_model_parallel_rank()
-    pipeline_stages = [item for sublist in args.hetero_pipeline_stages for item in sublist]
+    pipeline_stages = [item for sublist in config.hetero_pipeline_stages for item in sublist]
     offset = sum(([0] + pipeline_stages)[: pipeline_rank + 1])
     num_layers = pipeline_stages[pipeline_rank]
     torch.distributed.barrier()
@@ -1482,9 +1482,57 @@ def __init__(self, config,
         self.retro_add_retriever = args.retro_add_retriever

         # Store activation checkpoiting flag.
-        self.recompute_granularity = config.recompute_granularity
-        self.recompute_method = config.recompute_method
-        self.recompute_num_layers = config.recompute_num_layers
+        if config.recompute_method_per_stage != None:
+            if config.virtual_pipeline_model_parallel_size != None:
+                if (
+                    config.recompute_method_per_stage[
+                        mpu.get_virtual_pipeline_model_parallel_rank()
+                        * config.pipeline_model_parallel_size
+                        + mpu.get_pipeline_model_parallel_rank()
+                    ]
+                    == 0
+                ):
+                    self.recompute_method = 'uniform'
+                elif (
+                    config.recompute_method_per_stage[
+                        mpu.get_virtual_pipeline_model_parallel_rank()
+                        * config.pipeline_model_parallel_size
+                        + mpu.get_pipeline_model_parallel_rank()
+                    ]
+                    == 1
+                ):
+                    self.recompute_method = 'block'
+            else:
+                if config.recompute_method_per_stage[mpu.get_pipeline_model_parallel_rank()] == 0:
+                    self.recompute_method = 'uniform'
+                elif config.recompute_method_per_stage[mpu.get_pipeline_model_parallel_rank()] == 1:
+                    self.recompute_method = 'block'
+        else:
+            self.recompute_method = config.recompute_method
+
+        if config.recompute_num_layers_per_stage != None:
+            if config.virtual_pipeline_model_parallel_size != None:
+                self.recompute_num_layers = config.recompute_num_layers_per_stage[
+                    mpu.get_virtual_pipeline_model_parallel_rank()
+                    * config.pipeline_model_parallel_size
+                    + mpu.get_pipeline_model_parallel_rank()
+                ]
+            else:
+                self.recompute_num_layers = config.recompute_num_layers_per_stage[
+                    mpu.get_pipeline_model_parallel_rank()
+                ]
+        else:
+            self.recompute_num_layers = config.recompute_num_layers
+
+        if (
+            config.recompute_granularity_per_stage != None
+            and config.recompute_granularity_per_stage[mpu.get_pipeline_model_parallel_rank()] == 0
+        ):
+            self.recompute_granularity = None
+            self.recompute_method = None
+        else:
+            self.recompute_granularity = config.recompute_granularity

         self.distribute_saved_activations = \
             config.distribute_saved_activations and not config.sequence_parallel
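The encoding resolved above can be summarized in a hypothetical helper (not part of the commit; the per-stage lists are assumed to hold one entry per pipeline rank): method code 0 selects 'uniform', 1 selects 'block', and a granularity entry of 0 disables recomputation for that stage.

def resolve_stage_recompute(pp_rank, method_per_stage, granularity_per_stage,
                            default_granularity='full'):
    # Map the integer method code to Megatron's recompute_method strings.
    method = {0: 'uniform', 1: 'block'}.get(method_per_stage[pp_rank])
    # A granularity code of 0 turns recomputation off entirely for this stage.
    if granularity_per_stage[pp_rank] == 0:
        return None, None
    return method, default_granularity

print(resolve_stage_recompute(2, [0, 1, 0, 1], [1, 1, 0, 1]))  # -> (None, None)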

@@ -1645,10 +1693,10 @@ def build_layer(layer_number):
             num_ranks_in_enc = args.pipeline_model_parallel_split_rank
             offset = (pipeline_rank - num_ranks_in_enc) * self.num_layers
         else:
-            if args.hetero_mode != "pp":
+            if config.hetero_mode != "pp":
                 offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers
             else:
-                offset, self.num_layers = _get_layer_info(args)
+                offset, self.num_layers = _get_layer_info(config)

         if self.num_layers == 0:
             # When a standalone embedding stage is used (e.g.,

@@ -1661,6 +1709,7 @@ def build_layer(layer_number):
             # disconnect the input tensor from the output tensor.
             self.num_layers = 1
             self.layers = torch.nn.ModuleList([ NoopTransformerLayer(1) ])
+            self.recompute_granularity = None
         else:
             self.layers = torch.nn.ModuleList(
                 [build_layer(i + 1 + offset) for i in range(self.num_layers)])
(The remaining 4 changed files in this commit are not shown in this view.)
