diff --git a/megatron/mpu/mappings.py b/megatron/mpu/mappings.py
index 1c5e6fb10..d95c5a3df 100644
--- a/megatron/mpu/mappings.py
+++ b/megatron/mpu/mappings.py
@@ -33,17 +33,17 @@ def _reduce(input_):
     if get_model_parallel_world_size() == 1:
         return input_
 
-    # Bf16 convert
+    # Upcast to fp32 if using fp32 allreduce
     dt = input_.dtype
-    if dt == torch.bfloat16 and get_fp32_allreduce():
+    if get_fp32_allreduce():
         input_ = input_.float()
 
     # All-reduce.
     torch.distributed.all_reduce(input_, group=get_model_parallel_group())
 
-    # Bf16 convert
-    if dt == torch.bfloat16 and get_fp32_allreduce():
-        input_ = input_.bfloat16()
+    # Convert back to the original bf16/fp16 dtype
+    if get_fp32_allreduce():
+        input_ = input_.to(dt)
 
     return input_
 
@@ -75,11 +75,6 @@ def _gather(input_):
     if world_size == 1:
         return input_
 
-    # Bf16 convert
-    dt = input_.dtype
-    if dt == torch.bfloat16 and get_fp32_allreduce():
-        input_ = input_.float()
-
     # Size and dimension.
     last_dim = input_.dim() - 1
     rank = get_model_parallel_rank()
@@ -91,10 +86,6 @@ def _gather(input_):
     # Note: torch.cat already creates a contiguous tensor.
     output = torch.cat(tensor_list, dim=last_dim).contiguous()
 
-    # Bf16 convert
-    if dt == torch.bfloat16 and get_fp32_allreduce():
-        output = output.bfloat16()
-
     return output
 
 
@@ -105,9 +96,9 @@ def _reduce_scatter_along_seq_dim(input_, seq_dim):
     if world_size == 1:
         return input_
 
-    # Bf16 convert
+    # Upcast to fp32 if using fp32 allreduce
    dt = input_.dtype
-    if dt == torch.bfloat16 and get_fp32_allreduce():
+    if get_fp32_allreduce():
         input_ = input_.float()
 
     assert input_.shape[seq_dim] % world_size == 0
@@ -117,11 +108,9 @@ def _reduce_scatter_along_seq_dim(input_, seq_dim):
     output = torch.empty_like(tensor_list[0])
     torch.distributed.reduce_scatter(output, tensor_list)
 
-    # Bf16 convert
-    if dt == torch.bfloat16 and get_fp32_allreduce():
-        output = (
-            output.bfloat16()
-        )
+    # Convert back to the original bf16/fp16 dtype
+    if get_fp32_allreduce():
+        output = output.to(dt)
 
     return output
 
@@ -134,11 +123,6 @@ def _gather_along_seq_dim(input_, seq_dim):
     if world_size == 1:
         return input_
 
-    # Bf16 convert
-    dt = input_.dtype
-    if dt == torch.bfloat16 and get_fp32_allreduce():
-        input_ = input_.float()
-
     input_ = input_.contiguous()
     rank = get_model_parallel_rank()
     tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
@@ -146,12 +130,6 @@ def _gather_along_seq_dim(input_, seq_dim):
     torch.distributed.all_gather(tensor_list, input_, group=get_model_parallel_group())
     output = torch.cat(tensor_list, dim=seq_dim)
 
-    # Bf16 convert
-    if dt == torch.bfloat16 and get_fp32_allreduce():
-        output = (
-            output.bfloat16()
-        )
-
     return output
 
 
@@ -259,8 +237,7 @@ def backward(ctx, grad_output):
 
 
 class _GatherFromSequenceParallelRegion(torch.autograd.Function):
-    """All-Gather across sequence parallel region (same region as model parallel region.)
-    """
+    """All-Gather across sequence parallel region (same region as model parallel region)."""
 
     @staticmethod
     def symbolic(graph, input_, seq_dim):
@@ -297,6 +274,7 @@ def backward(ctx, grad_output):
             None,
         )
 
+
 # -----------------
 # Helper functions.
 # -----------------
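# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the patch above): the upcast/allreduce/
# downcast pattern the change generalizes. With fp32 allreduce enabled, the
# communication buffer is promoted to fp32 and the result is cast back to the
# original dtype via `.to(dt)`, so the same path now covers both bf16 and fp16
# inputs. The helper name `reduce_fp32` and the single-process gloo group are
# hypothetical, used only to make the example self-contained and runnable.
# ---------------------------------------------------------------------------
import os
import torch
import torch.distributed as dist


def reduce_fp32(tensor, group=None, fp32_allreduce=True):
    """All-reduce `tensor`, upcasting to fp32 for the reduction and casting
    the result back to the original dtype afterwards."""
    orig_dtype = tensor.dtype
    if fp32_allreduce:
        tensor = tensor.float()
    dist.all_reduce(tensor, group=group)
    if fp32_allreduce:
        tensor = tensor.to(orig_dtype)
    return tensor


if __name__ == "__main__":
    # Single-process gloo group, just enough to run the collective locally.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    for dtype in (torch.float16, torch.bfloat16):
        x = torch.randn(4, dtype=dtype)
        y = reduce_fp32(x.clone())
        assert y.dtype == dtype  # dtype round-trips after the fp32 reduction

    dist.destroy_process_group()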