From 4aee67c7449e263c12670c9c0d549ded25a3112e Mon Sep 17 00:00:00 2001 From: "chenxun.p" <759046501@qq.com> Date: Wed, 4 Dec 2024 15:18:58 +0800 Subject: [PATCH] feat(pipeline): add non-p2p-comm support --- configs/_base_/models/internlm2_1B.py | 1 + internlm/core/scheduler/comm/p2p.py | 104 ++++++++++----- .../core/scheduler/no_pipeline_scheduler.py | 2 +- .../core/scheduler/pipeline_scheduler_1f1b.py | 121 +++++++++++------- internlm/core/trainer_builder.py | 2 +- internlm/initialize/launch.py | 10 +- internlm/train/pipeline.py | 1 + 7 files changed, 158 insertions(+), 83 deletions(-) diff --git a/configs/_base_/models/internlm2_1B.py b/configs/_base_/models/internlm2_1B.py index 1b36da91..2b9275c9 100644 --- a/configs/_base_/models/internlm2_1B.py +++ b/configs/_base_/models/internlm2_1B.py @@ -63,6 +63,7 @@ 1. size: int, the size of pipeline parallel. 2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler, defaults to False. + 4. batch_p2p_comm: bool, enable/disable batch p2p communication, defaults to False. weight parallel (dict): 1. size: int, the size of weight parallel. 2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False. diff --git a/internlm/core/scheduler/comm/p2p.py b/internlm/core/scheduler/comm/p2p.py index 54fb587c..6236953f 100644 --- a/internlm/core/scheduler/comm/p2p.py +++ b/internlm/core/scheduler/comm/p2p.py @@ -44,6 +44,15 @@ def _get_tensor_shape(tensor_shape: TensorShape, chunk_tensor: bool = False) -> return tensor_chunk_shape, chunk_tensor +def _p2p_func(_comm_op, _obj, _comm_rank): + if gpc.config.parallel.pipeline.batch_p2p_comm is True: + op_or_handle = dist.P2POp(_comm_op, _obj, _comm_rank) + else: + op_or_handle = _comm_op(_obj, _comm_rank) + + return op_or_handle + + def create_recv_buffer_with_shapes(recv_shapes, dtype, scatter_gather_tensors): if isinstance(recv_shapes, torch.Size): recv_chunk_shape, recv_split = _get_tensor_shape(recv_shapes, scatter_gather_tensors) @@ -78,12 +87,10 @@ def process_object_to_send(object_send, scatter_gather_tensors): def filling_ops_queue(obj, comm_op, comm_rank, ops_queue): if isinstance(obj, torch.Tensor): - op_to_add = dist.P2POp(comm_op, obj, comm_rank) - ops_queue.append(op_to_add) + ops_queue.append(_p2p_func(comm_op, obj, comm_rank)) else: for tensor_to_comm in obj: - op_to_add = dist.P2POp(comm_op, tensor_to_comm, comm_rank) - ops_queue.append(op_to_add) + ops_queue.append(_p2p_func(comm_op, tensor_to_comm, comm_rank)) def _communicate( @@ -156,23 +163,42 @@ def _communicate( object_send_next = process_object_to_send(object_send_next, scatter_gather_tensors) ops = [] - if object_send_prev is not None: - filling_ops_queue(object_send_prev, dist.isend, prev_rank, ops) - if tensor_recv_prev is not None: - filling_ops_queue(tensor_recv_prev, dist.irecv, prev_rank, ops) + if gpc.get_local_rank(ParallelMode.PIPELINE) % 2 == 0: + if object_send_next is not None: + filling_ops_queue(object_send_next, dist.isend, next_rank, ops) - if tensor_recv_next is not None: - filling_ops_queue(tensor_recv_next, dist.irecv, next_rank, ops) + if tensor_recv_prev is not None: + filling_ops_queue(tensor_recv_prev, dist.irecv, prev_rank, ops) + + if object_send_prev is not None: + filling_ops_queue(object_send_prev, dist.isend, prev_rank, ops) + + if tensor_recv_next is not None: + filling_ops_queue(tensor_recv_next, dist.irecv, next_rank, ops) + else: + if tensor_recv_prev is not None: + filling_ops_queue(tensor_recv_prev, dist.irecv, 
prev_rank, ops) + + if object_send_next is not None: + filling_ops_queue(object_send_next, dist.isend, next_rank, ops) + + if tensor_recv_next is not None: + filling_ops_queue(tensor_recv_next, dist.irecv, next_rank, ops) + + if object_send_prev is not None: + filling_ops_queue(object_send_prev, dist.isend, prev_rank, ops) - if object_send_next is not None: - filling_ops_queue(object_send_next, dist.isend, next_rank, ops) if len(ops) > 0: - reqs = dist.batch_isend_irecv(ops) - for req in reqs: - req.wait() - # To protect against race condition when using batch_isend_irecv(). - internlm_accelerator.synchronize() + if gpc.config.parallel.pipeline.batch_p2p_comm is True: + reqs = dist.batch_isend_irecv(ops) + for req in reqs: + req.wait() + # To protect against race condition when using batch_isend_irecv(). + internlm_accelerator.synchronize() + else: + for req in ops: + req.wait() if recv_prev and recv_prev_split: if isinstance(tensor_recv_prev, torch.Tensor): @@ -265,29 +291,47 @@ def _communicate_async( object_send_next = process_object_to_send(object_send_next, scatter_gather_tensors) ops = [] - if object_send_prev is not None: - filling_ops_queue(object_send_prev, dist.isend, prev_rank, ops) - if tensor_recv_prev is not None: - filling_ops_queue(tensor_recv_prev, dist.irecv, prev_rank, ops) + if gpc.get_local_rank(ParallelMode.PIPELINE) % 2 == 0: + if object_send_next is not None: + filling_ops_queue(object_send_next, dist.isend, next_rank, ops) - if tensor_recv_next is not None: - filling_ops_queue(tensor_recv_next, dist.irecv, next_rank, ops) + if tensor_recv_prev is not None: + filling_ops_queue(tensor_recv_prev, dist.irecv, prev_rank, ops) - if object_send_next is not None: - filling_ops_queue(object_send_next, dist.isend, next_rank, ops) + if object_send_prev is not None: + filling_ops_queue(object_send_prev, dist.isend, prev_rank, ops) - if len(ops) > 0: + if tensor_recv_next is not None: + filling_ops_queue(tensor_recv_next, dist.irecv, next_rank, ops) + else: + if tensor_recv_prev is not None: + filling_ops_queue(tensor_recv_prev, dist.irecv, prev_rank, ops) + + if object_send_next is not None: + filling_ops_queue(object_send_next, dist.isend, next_rank, ops) + + if tensor_recv_next is not None: + filling_ops_queue(tensor_recv_next, dist.irecv, next_rank, ops) + + if object_send_prev is not None: + filling_ops_queue(object_send_prev, dist.isend, prev_rank, ops) + + if len(ops) > 0 and gpc.config.parallel.pipeline.batch_p2p_comm is True: reqs = dist.batch_isend_irecv(ops) # return and do other things yield if len(ops) > 0: - for req in reqs: # pylint: disable=E0601 - req.wait() - # To protect against race condition when using batch_isend_irecv(). - internlm_accelerator.synchronize() + if gpc.config.parallel.pipeline.batch_p2p_comm is True: + for req in reqs: + req.wait() + # To protect against race condition when using batch_isend_irecv(). 
+ internlm_accelerator.synchronize() + else: + for req in ops: + req.wait() if recv_prev and recv_prev_split: if isinstance(tensor_recv_prev, torch.Tensor): diff --git a/internlm/core/scheduler/no_pipeline_scheduler.py b/internlm/core/scheduler/no_pipeline_scheduler.py index 84b94dbf..3ec14067 100644 --- a/internlm/core/scheduler/no_pipeline_scheduler.py +++ b/internlm/core/scheduler/no_pipeline_scheduler.py @@ -230,7 +230,7 @@ def forward_backward_step( outputs, labels = None, None # Compatible for non-moe - if hasattr(gpc.config.model, "num_experts"): + if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1: return outputs, labels, loss, moe_loss else: return outputs, labels, loss diff --git a/internlm/core/scheduler/pipeline_scheduler_1f1b.py b/internlm/core/scheduler/pipeline_scheduler_1f1b.py index 4864c77f..1aef1a86 100644 --- a/internlm/core/scheduler/pipeline_scheduler_1f1b.py +++ b/internlm/core/scheduler/pipeline_scheduler_1f1b.py @@ -197,7 +197,7 @@ def _call_engine(engine, data): # pylint: disable=W0237 def load_batch(self, engine, data_iter): # Pipeline schedule just puts data in memory, - batch_data, actual_batch_size = engine.load_batch(data_iter, to_gpu=False) + batch_data, actual_batch_size = engine.load_batch(data_iter, to_gpu=True) # Even if 'use_flash_attn' is False, the data seen when the 'load_batch' is called is still packed, # because internlm's current train dataset is packed, even using dummy data. @@ -289,7 +289,7 @@ def _forward_step( data, label = self._get_data_label_for_current_step(input_obj, micro_batch_data) self._call_hooks("before_forward", data) - if hasattr(gpc.config.model, "num_experts"): + if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1: # moe is used output_obj, moe_losses = self._call_engine(engine.model, data) else: @@ -309,17 +309,18 @@ def _forward_step( accum_loss.add_(loss_reduced.detach()) output_obj = loss_reduced - moe_loss = ( - sum(moe_losses) * gpc.config.loss.moe_loss_coeff # pylint: disable=E0606 - if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1 - else torch.tensor(0.0, device=get_current_device(), dtype=gpc.config.model.get("dtype")) - ) - # the moe_loss is computed among the "tensor" group if sequence parallel is enabled, so we need to do allreduce - if gpc.config.parallel.sequence_parallel or gpc.config.parallel.expert.no_tp: - dist.all_reduce(moe_loss, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.TENSOR)) - moe_loss.div_(gpc.get_world_size(ParallelMode.TENSOR)) - moe_loss /= self.num_microbatches - accum_moe_loss.add_(moe_loss.detach()) + if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1: + moe_loss = sum(moe_losses) * gpc.config.loss.moe_loss_coeff + + # the moe_loss is computed among the "tensor" group if sequence parallel is enabled, + # so we need to do allreduce + if gpc.config.parallel.sequence_parallel or gpc.config.parallel.expert.no_tp: + dist.all_reduce(moe_loss, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.TENSOR)) + moe_loss.div_(gpc.get_world_size(ParallelMode.TENSOR)) + moe_loss /= self.num_microbatches + accum_moe_loss.add_(moe_loss.detach()) + else: + moe_loss = None return output_obj, moe_loss @@ -413,7 +414,11 @@ def _forward_only_step(self, engine, return_loss=True, return_output_label=True) if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True) else None ) - accum_moe_loss = torch.zeros(1, device=get_current_device()) + + if hasattr(gpc.config.model, 
"num_experts") and gpc.config.model.num_experts > 1: + accum_moe_loss = torch.zeros(1, device=get_current_device()) + else: + accum_moe_loss = None # Used for tensor meta information communication forward_recv_shapes = self.tensor_shape @@ -456,8 +461,8 @@ def _forward_only_step(self, engine, return_loss=True, return_output_label=True) if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1: dist.all_reduce(accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE)) - if accum_loss is not None: - accum_loss += accum_moe_loss + if accum_loss is not None: + accum_loss += accum_moe_loss return output, label, accum_loss, accum_moe_loss @@ -514,7 +519,11 @@ def _forward_backward_step(self, engine, return_loss=True, return_output_label=T if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True) else None ) - accum_moe_loss = torch.zeros(1, device=get_current_device()) + + if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1: + accum_moe_loss = torch.zeros(1, device=get_current_device()) + else: + accum_moe_loss = None # Used for tensor meta information communication forward_recv_shapes = self.tensor_shape @@ -660,8 +669,8 @@ def _forward_backward_step(self, engine, return_loss=True, return_output_label=T if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1: dist.all_reduce(accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE)) - if accum_loss is not None: - accum_loss += accum_moe_loss + if accum_loss is not None: + accum_loss += accum_moe_loss return output, label, accum_loss, accum_moe_loss @@ -699,7 +708,7 @@ def forward_backward_step(self, engine, data_iter, forward_only=False, return_lo ) # Compatible for non-moe - if hasattr(gpc.config.model, "num_experts"): + if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1: return output, label, accum_loss, accum_moe_loss else: return output, label, accum_loss @@ -776,6 +785,7 @@ def __init__( self._output_obj_grads = [[] for _ in range(num_chunks)] self._moe_losses = [[] for _ in range(num_chunks)] + self._preload_micro_data = [None for _ in range(self.num_microbatches)] self._input_obj_shapes = [self.tensor_shape for _ in range(num_chunks)] self._output_obj_shapes = [None for _ in range(num_chunks)] self._send_tensor_shape_flags = [self.tensor_shape is None for _ in range(num_chunks)] @@ -799,26 +809,37 @@ def _clear_state(self) -> None: self._output_obj_grads = [[] for _ in range(self._num_chunks)] self._moe_losses = [[] for _ in range(self._num_chunks)] + self._preload_micro_data = [None for _ in range(self.num_microbatches)] self._input_obj_shapes = [self.tensor_shape for _ in range(self._num_chunks)] self._output_obj_shapes = [None for _ in range(self._num_chunks)] self._send_tensor_shape_flags = [self.tensor_shape is None for _ in range(self._num_chunks)] def load_batch(self, engine, data_iter): super().load_batch(engine, data_iter) + + for mbs in range(self.num_microbatches): + micro_batch_data, micro_batch_label = self._load_micro_batch( + data=self.batch_data, + label=self.batch_label, + offset=mbs * self.bsz_stride, + bsz_stride=self.bsz_stride, + ) + + if self.data_process_func: + micro_batch_data, micro_batch_label = self.data_process_func(micro_batch_data, micro_batch_label) + + micro_batch_data["label"] = micro_batch_label + self._preload_micro_data[mbs] = micro_batch_data + # overwrite microbatch_offset, since model chunks load the same microbatch, and should tract the offset self.microbatch_offset = [0 for _ in 
range(self._num_chunks)] def load_micro_batch(self, model_chunk_id): - micro_batch_data, micro_batch_label = self._load_micro_batch( - data=self.batch_data, - label=self.batch_label, - offset=self.microbatch_offset[model_chunk_id], - bsz_stride=self.bsz_stride, - ) - if self.data_process_func: - micro_batch_data, micro_batch_label = self.data_process_func(micro_batch_data, micro_batch_label) - micro_batch_data["label"] = micro_batch_label - self.microbatch_offset[model_chunk_id] += self.bsz_stride + offset = self.microbatch_offset[model_chunk_id] + assert self._preload_micro_data[offset] is not None, "preload micro batch data is None" + + micro_batch_data = self._preload_micro_data[offset] + self.microbatch_offset[model_chunk_id] += 1 result = move_to_device(micro_batch_data) return result @@ -849,7 +870,7 @@ def _forward_step(self, engine, chunk_id, input_obj=None): data, label = self._get_data_label_for_current_step(input_obj, micro_batch_data) self._call_hooks("before_forward", data) - if hasattr(gpc.config.model, "num_experts"): + if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1: output_obj, moe_losses = self._call_engine(engine.model[chunk_id], data) else: output_obj = self._call_engine(engine.model[chunk_id], data) @@ -872,18 +893,19 @@ def _forward_step(self, engine, chunk_id, input_obj=None): self._accum_loss.add_(loss_reduced.detach()) output_obj = loss_reduced - moe_loss = ( - sum(moe_losses) * gpc.config.loss.moe_loss_coeff # pylint: disable=E0606 - if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1 - else torch.tensor(0.0, device=get_current_device(), dtype=gpc.config.model.get("dtype")) - ) - # the moe_loss is computed among the "tensor" group if sequence parallel is enabled, so we need to do allreduce - if gpc.config.parallel.sequence_parallel or gpc.config.parallel.expert.no_tp: - dist.all_reduce(moe_loss, op=dist.ReduceOp.AVG, group=gpc.get_group(ParallelMode.TENSOR)) - moe_loss /= self.num_microbatches + if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1: + moe_loss = sum(moe_losses) * gpc.config.loss.moe_loss_coeff - if self._accum_moe_loss is not None: - self._accum_moe_loss.add_(moe_loss.detach()) + # the moe_loss is computed among the "tensor" group if sequence parallel is enabled, + # so we need to do allreduce + if gpc.config.parallel.sequence_parallel or gpc.config.parallel.expert.no_tp: + dist.all_reduce(moe_loss, op=dist.ReduceOp.AVG, group=gpc.get_group(ParallelMode.TENSOR)) + moe_loss /= self.num_microbatches + + if self._accum_moe_loss is not None: + self._accum_moe_loss.add_(moe_loss.detach()) + else: + moe_loss = None self._output_objs[chunk_id].append(output_obj) self._moe_losses[chunk_id].append(moe_loss) @@ -1394,7 +1416,9 @@ def forward_backward_step(self, engine, data_iter, forward_only=False, return_lo if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True): self._accum_loss = torch.zeros(1, device=get_current_device()) - self._accum_moe_loss = torch.zeros(1, device=get_current_device()) + + if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1: + self._accum_moe_loss = torch.zeros(1, device=get_current_device()) if return_output_label: self._return_tensors = [] @@ -1409,18 +1433,19 @@ def forward_backward_step(self, engine, data_iter, forward_only=False, return_lo else: output, label = (None, None) + accum_loss = self._accum_loss + if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1: 
dist.all_reduce(self._accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE)) - accum_moe_loss = self._accum_moe_loss + accum_moe_loss = self._accum_moe_loss - accum_loss = self._accum_loss - if accum_loss is not None: - accum_loss += self._accum_moe_loss + if accum_loss is not None: + accum_loss += self._accum_moe_loss self._clear_state() # Compatible for non-moe - if hasattr(gpc.config.model, "num_experts"): + if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1: return output, label, accum_loss, accum_moe_loss else: return output, label, accum_loss diff --git a/internlm/core/trainer_builder.py b/internlm/core/trainer_builder.py index d0ef284d..c2ea75ff 100644 --- a/internlm/core/trainer_builder.py +++ b/internlm/core/trainer_builder.py @@ -306,7 +306,7 @@ def _load_and_prepare_batch(self, batch_count: int, train_iter): def _forward_backward(self, batch): self.zero_grad() - if hasattr(gpc.config.model, "num_experts"): + if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1: _, _, loss, moe_loss = self.execute_schedule( batch, forward_only=False, return_loss=True, return_output_label=False ) diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py index 1ac8ef31..7919e27b 100644 --- a/internlm/initialize/launch.py +++ b/internlm/initialize/launch.py @@ -87,9 +87,6 @@ def args_sanity_check(): if "pipeline" not in gpc.config.parallel: gpc.config.parallel._add_item("pipeline", dict(size=1, interleaved_overlap=False, mode="1F1B")) - if isinstance(gpc.config.parallel.pipeline, dict) and "mode" not in gpc.config.parallel.pipeline: - gpc.config.parallel.pipeline._add_item("mode", "1F1B") - if "tensor" not in gpc.config.parallel: gpc.config.parallel._add_item("tensor", dict(size=1, mode=TensorParallelMode.mtp.name)) @@ -104,9 +101,16 @@ def args_sanity_check(): if isinstance(gpc.config.parallel.pipeline, int): pp = gpc.config.parallel.pipeline + gpc.config.parallel._add_item("pipeline", dict(size=pp, interleaved_overlap=False)) else: pp = gpc.config.parallel.pipeline.size + if isinstance(gpc.config.parallel.pipeline, dict) and "mode" not in gpc.config.parallel.pipeline: + gpc.config.parallel.pipeline._add_item("mode", "1F1B") + + if "batch_p2p_comm" not in gpc.config.parallel.pipeline: + gpc.config.parallel.pipeline["batch_p2p_comm"] = False + if isinstance(gpc.config.parallel.pipeline, dict): gpc.config.parallel.pipeline["mode"] = gpc.config.parallel.pipeline["mode"].upper() assert gpc.config.parallel.pipeline["mode"] in [ diff --git a/internlm/train/pipeline.py b/internlm/train/pipeline.py index 5907a4e3..537a852c 100644 --- a/internlm/train/pipeline.py +++ b/internlm/train/pipeline.py @@ -533,6 +533,7 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList], isp_communicato if ( zero_cfg.overlap_sync_grad and gpc.is_using_parallel_mode(ParallelMode.PIPELINE) + and gpc.config.parallel.pipeline.batch_p2p_comm is True and gpc.is_pipeline_first_stage() is False ): # When pipeline parallelism is enabled, we prefer to only enable optimizer
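
Note on usage (not part of the patch; the key names are taken from the diff above, the other values are illustrative): batch_p2p_comm is read from gpc.config.parallel.pipeline and is defaulted to False in args_sanity_check, so existing configs fall back to eager isend/irecv. Re-enabling the batched path would look like:

parallel = dict(
    pipeline=dict(size=4, interleaved_overlap=True, mode="1F1B", batch_p2p_comm=True),
    # other parallel settings unchanged
)

For reviewers, the following is a minimal standalone sketch of the two communication paths this patch toggles between, written against plain torch.distributed rather than the repository's helpers (it assumes an already initialized process group and valid neighbour ranks; the function name is made up for illustration). The even/odd issue order mirrors the ordering added to _communicate and _communicate_async:

import torch.distributed as dist


def exchange(tensor_send, tensor_recv, prev_rank, next_rank, batch_p2p_comm):
    """Exchange one tensor with each pipeline neighbour (illustrative only)."""
    if batch_p2p_comm:
        # Batched path: queue P2POp descriptors and launch them together,
        # as _p2p_func does when batch_p2p_comm is True.
        ops = [
            dist.P2POp(dist.isend, tensor_send, next_rank),
            dist.P2POp(dist.irecv, tensor_recv, prev_rank),
        ]
        reqs = dist.batch_isend_irecv(ops)
    else:
        # Eager path: call isend/irecv directly and keep the returned work
        # handles; even and odd ranks issue the pair in opposite order so
        # that matching operations line up between neighbouring ranks.
        if dist.get_rank() % 2 == 0:
            reqs = [
                dist.isend(tensor_send, next_rank),
                dist.irecv(tensor_recv, prev_rank),
            ]
        else:
            reqs = [
                dist.irecv(tensor_recv, prev_rank),
                dist.isend(tensor_send, next_rank),
            ]
    for req in reqs:
        req.wait()

Only the batched path keeps the extra internlm_accelerator.synchronize() in the patch (guarding against the race noted in the existing comment); the eager path simply waits on the work handles that filling_ops_queue now stores when batch_p2p_comm is False.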