Skip to content

Commit

Permalink
Add comment
Browse files Browse the repository at this point in the history
  • Loading branch information
bclyang committed Aug 22, 2024
1 parent 8e7400f commit 53d0ae8
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions megatron/mpu/mappings.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def _reduce_scatter_along_seq_dim(input_, seq_dim):
assert dim_size[seq_dim] % world_size == 0

if seq_dim == 0:
# reduce_scatter_tensor is faster but only works correctly on dimension 0
dim_size[seq_dim] = dim_size[seq_dim] // world_size
output = torch.empty(
dim_size, dtype=input_.dtype, device=torch.cuda.current_device()
Expand Down Expand Up @@ -144,6 +145,7 @@ def _gather_along_seq_dim(input_, seq_dim):
dim_size[seq_dim] = dim_size[seq_dim] * world_size

if seq_dim == 0:
# all_gather_into_tensor is faster but only works correctly on dimension 0
output = torch.empty(
dim_size, dtype=input_.dtype, device=torch.cuda.current_device()
)
Expand Down

0 comments on commit 53d0ae8

Please sign in to comment.