Skip to content

Commit

Permalink
feat(pipeline): add non-p2p-comm support
Browse files Browse the repository at this point in the history
  • Loading branch information
mwiacx committed Dec 4, 2024
1 parent 4a6b453 commit 4aee67c
Show file tree
Hide file tree
Showing 7 changed files with 158 additions and 83 deletions.
1 change: 1 addition & 0 deletions configs/_base_/models/internlm2_1B.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
1. size: int, the size of pipeline parallel.
2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler,
defaults to False.
3. batch_p2p_comm: bool, enable/disable batch p2p communication, defaults to False.
weight parallel (dict):
1. size: int, the size of weight parallel.
2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
Expand Down
104 changes: 74 additions & 30 deletions internlm/core/scheduler/comm/p2p.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,15 @@ def _get_tensor_shape(tensor_shape: TensorShape, chunk_tensor: bool = False) ->
return tensor_chunk_shape, chunk_tensor


def _p2p_func(_comm_op, _obj, _comm_rank):
    """Create a single point-to-point communication entry.

    Args:
        _comm_op: the communication callable (callers in this file pass
            ``dist.isend`` or ``dist.irecv``).
        _obj: the object handed to ``_comm_op`` as its first argument.
        _comm_rank: the peer rank handed to ``_comm_op`` as its second argument.

    Returns:
        A ``dist.P2POp`` describing the operation when
        ``gpc.config.parallel.pipeline.batch_p2p_comm`` is exactly ``True``;
        otherwise whatever ``_comm_op(_obj, _comm_rank)`` returns, since in that
        mode the operation is invoked right here.
    """
    # Identity check against True is intentional: any other value (including
    # truthy non-bool values) selects the immediate, non-batched call.
    if gpc.config.parallel.pipeline.batch_p2p_comm is not True:
        return _comm_op(_obj, _comm_rank)

    return dist.P2POp(_comm_op, _obj, _comm_rank)


def create_recv_buffer_with_shapes(recv_shapes, dtype, scatter_gather_tensors):
if isinstance(recv_shapes, torch.Size):
recv_chunk_shape, recv_split = _get_tensor_shape(recv_shapes, scatter_gather_tensors)
Expand Down Expand Up @@ -78,12 +87,10 @@ def process_object_to_send(object_send, scatter_gather_tensors):

def filling_ops_queue(obj, comm_op, comm_rank, ops_queue):
if isinstance(obj, torch.Tensor):
op_to_add = dist.P2POp(comm_op, obj, comm_rank)
ops_queue.append(op_to_add)
ops_queue.append(_p2p_func(comm_op, obj, comm_rank))
else:
for tensor_to_comm in obj:
op_to_add = dist.P2POp(comm_op, tensor_to_comm, comm_rank)
ops_queue.append(op_to_add)
ops_queue.append(_p2p_func(comm_op, tensor_to_comm, comm_rank))


def _communicate(
Expand Down Expand Up @@ -156,23 +163,42 @@ def _communicate(
object_send_next = process_object_to_send(object_send_next, scatter_gather_tensors)

ops = []
if object_send_prev is not None:
filling_ops_queue(object_send_prev, dist.isend, prev_rank, ops)

if tensor_recv_prev is not None:
filling_ops_queue(tensor_recv_prev, dist.irecv, prev_rank, ops)
if gpc.get_local_rank(ParallelMode.PIPELINE) % 2 == 0:
if object_send_next is not None:
filling_ops_queue(object_send_next, dist.isend, next_rank, ops)

if tensor_recv_next is not None:
filling_ops_queue(tensor_recv_next, dist.irecv, next_rank, ops)
if tensor_recv_prev is not None:
filling_ops_queue(tensor_recv_prev, dist.irecv, prev_rank, ops)

if object_send_prev is not None:
filling_ops_queue(object_send_prev, dist.isend, prev_rank, ops)

if tensor_recv_next is not None:
filling_ops_queue(tensor_recv_next, dist.irecv, next_rank, ops)
else:
if tensor_recv_prev is not None:
filling_ops_queue(tensor_recv_prev, dist.irecv, prev_rank, ops)

if object_send_next is not None:
filling_ops_queue(object_send_next, dist.isend, next_rank, ops)

if tensor_recv_next is not None:
filling_ops_queue(tensor_recv_next, dist.irecv, next_rank, ops)

if object_send_prev is not None:
filling_ops_queue(object_send_prev, dist.isend, prev_rank, ops)

if object_send_next is not None:
filling_ops_queue(object_send_next, dist.isend, next_rank, ops)
if len(ops) > 0:
reqs = dist.batch_isend_irecv(ops)
for req in reqs:
req.wait()
# To protect against race condition when using batch_isend_irecv().
internlm_accelerator.synchronize()
if gpc.config.parallel.pipeline.batch_p2p_comm is True:
reqs = dist.batch_isend_irecv(ops)
for req in reqs:
req.wait()
# To protect against race condition when using batch_isend_irecv().
internlm_accelerator.synchronize()
else:
for req in ops:
req.wait()

if recv_prev and recv_prev_split:
if isinstance(tensor_recv_prev, torch.Tensor):
Expand Down Expand Up @@ -265,29 +291,47 @@ def _communicate_async(
object_send_next = process_object_to_send(object_send_next, scatter_gather_tensors)

ops = []
if object_send_prev is not None:
filling_ops_queue(object_send_prev, dist.isend, prev_rank, ops)

if tensor_recv_prev is not None:
filling_ops_queue(tensor_recv_prev, dist.irecv, prev_rank, ops)
if gpc.get_local_rank(ParallelMode.PIPELINE) % 2 == 0:
if object_send_next is not None:
filling_ops_queue(object_send_next, dist.isend, next_rank, ops)

if tensor_recv_next is not None:
filling_ops_queue(tensor_recv_next, dist.irecv, next_rank, ops)
if tensor_recv_prev is not None:
filling_ops_queue(tensor_recv_prev, dist.irecv, prev_rank, ops)

if object_send_next is not None:
filling_ops_queue(object_send_next, dist.isend, next_rank, ops)
if object_send_prev is not None:
filling_ops_queue(object_send_prev, dist.isend, prev_rank, ops)

if len(ops) > 0:
if tensor_recv_next is not None:
filling_ops_queue(tensor_recv_next, dist.irecv, next_rank, ops)
else:
if tensor_recv_prev is not None:
filling_ops_queue(tensor_recv_prev, dist.irecv, prev_rank, ops)

if object_send_next is not None:
filling_ops_queue(object_send_next, dist.isend, next_rank, ops)

if tensor_recv_next is not None:
filling_ops_queue(tensor_recv_next, dist.irecv, next_rank, ops)

if object_send_prev is not None:
filling_ops_queue(object_send_prev, dist.isend, prev_rank, ops)

if len(ops) > 0 and gpc.config.parallel.pipeline.batch_p2p_comm is True:
reqs = dist.batch_isend_irecv(ops)

# return and do other things
yield

if len(ops) > 0:
for req in reqs: # pylint: disable=E0601
req.wait()
# To protect against race condition when using batch_isend_irecv().
internlm_accelerator.synchronize()
if gpc.config.parallel.pipeline.batch_p2p_comm is True:
for req in reqs:
req.wait()
# To protect against race condition when using batch_isend_irecv().
internlm_accelerator.synchronize()
else:
for req in ops:
req.wait()

if recv_prev and recv_prev_split:
if isinstance(tensor_recv_prev, torch.Tensor):
Expand Down
2 changes: 1 addition & 1 deletion internlm/core/scheduler/no_pipeline_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def forward_backward_step(
outputs, labels = None, None

# Compatible for non-moe
if hasattr(gpc.config.model, "num_experts"):
if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1:
return outputs, labels, loss, moe_loss
else:
return outputs, labels, loss
Loading

0 comments on commit 4aee67c

Please sign in to comment.