update fp32_allreduce to handle fp16; don't cast to fp32 for gathers
haileyschoelkopf committed Aug 19, 2024
1 parent 9ce982e commit ab11a6a
Showing 1 changed file with 12 additions and 34 deletions.
46 changes: 12 additions & 34 deletions megatron/mpu/mappings.py
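
For context, the change makes the fp32-allreduce path dtype-agnostic: the input is upcast to fp32 whenever the option is enabled and then cast back to whatever reduced-precision dtype it started in (fp16 or bf16), instead of special-casing bf16. Below is a minimal sketch of that pattern, assuming an already-initialized torch.distributed process group; `group` and `use_fp32_allreduce` are stand-ins for the repo's `get_model_parallel_group()` and `get_fp32_allreduce()` helpers, and the function name is hypothetical.

    import torch
    import torch.distributed as dist

    def allreduce_maybe_fp32(tensor: torch.Tensor, group, use_fp32_allreduce: bool) -> torch.Tensor:
        # remember the caller's dtype so it can be restored afterwards
        dt = tensor.dtype
        if use_fp32_allreduce:
            # upcast any reduced-precision input (fp16 or bf16) before reducing
            tensor = tensor.float()
        dist.all_reduce(tensor, group=group)
        if use_fp32_allreduce:
            # cast back to the original dtype, whatever it was
            tensor = tensor.to(dt)
        return tensor

Because the cast back uses `.to(dt)` rather than a hard-coded `.bfloat16()`, the same path now works for fp16 runs as well.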
@@ -33,17 +33,17 @@ def _reduce(input_):
     if get_model_parallel_world_size() == 1:
         return input_

-    # Bf16 convert
+    # upcast to fp32 if using fp32 allreduce
     dt = input_.dtype
-    if dt == torch.bfloat16 and get_fp32_allreduce():
+    if get_fp32_allreduce():
         input_ = input_.float()

     # All-reduce.
     torch.distributed.all_reduce(input_, group=get_model_parallel_group())

-    # Bf16 convert
-    if dt == torch.bfloat16 and get_fp32_allreduce():
-        input_ = input_.bfloat16()
+    # reconvert to original Bf16/Fp16 dtype
+    if get_fp32_allreduce():
+        input_ = input_.to(dt)

     return input_

@@ -75,11 +75,6 @@ def _gather(input_):
     if world_size == 1:
         return input_

-    # Bf16 convert
-    dt = input_.dtype
-    if dt == torch.bfloat16 and get_fp32_allreduce():
-        input_ = input_.float()
-
     # Size and dimension.
     last_dim = input_.dim() - 1
     rank = get_model_parallel_rank()
@@ -91,10 +86,6 @@ def _gather(input_):
     # Note: torch.cat already creates a contiguous tensor.
     output = torch.cat(tensor_list, dim=last_dim).contiguous()

-    # Bf16 convert
-    if dt == torch.bfloat16 and get_fp32_allreduce():
-        output = output.bfloat16()
-
     return output


@@ -105,9 +96,9 @@ def _reduce_scatter_along_seq_dim(input_, seq_dim):
     if world_size == 1:
         return input_

-    # Bf16 convert
+    # upcast to fp32 if using fp32 allreduce
     dt = input_.dtype
-    if dt == torch.bfloat16 and get_fp32_allreduce():
+    if get_fp32_allreduce():
         input_ = input_.float()

     assert input_.shape[seq_dim] % world_size == 0
@@ -117,11 +108,9 @@ def _reduce_scatter_along_seq_dim(input_, seq_dim):
     output = torch.empty_like(tensor_list[0])
     torch.distributed.reduce_scatter(output, tensor_list)

-    # Bf16 convert
-    if dt == torch.bfloat16 and get_fp32_allreduce():
-        output = (
-            output.bfloat16()
-        )
+    # reconvert to original Bf16/Fp16 dtype
+    if get_fp32_allreduce():
+        output = output.to(dt)

     return output

@@ -134,24 +123,13 @@ def _gather_along_seq_dim(input_, seq_dim):
     if world_size == 1:
         return input_

-    # Bf16 convert
-    dt = input_.dtype
-    if dt == torch.bfloat16 and get_fp32_allreduce():
-        input_ = input_.float()
-
     input_ = input_.contiguous()
     rank = get_model_parallel_rank()
     tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
     tensor_list[rank] = input_
     torch.distributed.all_gather(tensor_list, input_, group=get_model_parallel_group())
     output = torch.cat(tensor_list, dim=seq_dim)

-    # Bf16 convert
-    if dt == torch.bfloat16 and get_fp32_allreduce():
-        output = (
-            output.bfloat16()
-        )
-
     return output


@@ -259,8 +237,7 @@ def backward(ctx, grad_output):


 class _GatherFromSequenceParallelRegion(torch.autograd.Function):
-    """All-Gather across sequence parallel region (same region as model parallel region.)
-    """
+    """All-Gather across sequence parallel region (same region as model parallel region.)"""

     @staticmethod
     def symbolic(graph, input_, seq_dim):
@@ -297,6 +274,7 @@ def backward(ctx, grad_output):
             None,
         )

+
 # -----------------
 # Helper functions.
 # -----------------
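
The gather paths, by contrast, drop the fp32 round trip entirely: an all-gather only copies and concatenates shards, so there is no accumulation that extra precision would help, and upcasting would double the communication volume. A minimal sketch in the spirit of `_gather` after this change, again assuming an initialized process group, with `group` standing in for `get_model_parallel_group()` and the function name being hypothetical:

    import torch
    import torch.distributed as dist

    def gather_last_dim(tensor: torch.Tensor, group) -> torch.Tensor:
        # collect every rank's shard and concatenate along the last dimension;
        # the input dtype (fp16, bf16, or fp32) is preserved end to end
        world_size = dist.get_world_size(group=group)
        shards = [torch.empty_like(tensor) for _ in range(world_size)]
        shards[dist.get_rank(group=group)] = tensor
        dist.all_gather(shards, tensor, group=group)
        return torch.cat(shards, dim=-1).contiguous()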