Commit 63ac5b7 (parent: 8b1cfa7)
Quentin-Anthony committed Sep 9, 2024

    remove incorrect type hint and dead imports
1 changed file: megatron/model/utils.py (3 additions, 6 deletions)

@@ -24,14 +24,11 @@
 from types import GeneratorType
 import torch.distributed as dist


-import importlib
-from typing import List, Dict, Any, Union
+from typing import List, Dict, Any


-def get_params_for_weight_decay_optimization(
-    module: Any, neox_args: Any
-) -> List[Dict[str, Any]]:
+def get_params_for_weight_decay_optimization(module: Any, neox_args: Any):
     """
     Divide params into with-weight-decay and without-weight-decay groups.
     Layernorms and biases will have no weight decay but the rest will.
@@ -381,7 +378,7 @@ def reduce_weight_grads_from_model_parallel_region(input_):
         input_ = input_.float()

     # All-reduce.
-    torch.distributed.all_reduce(input_, group=mpu.get_model_parallel_group())
+    dist.all_reduce(input_, group=mpu.get_model_parallel_group())

     # Bf16 convert
     if dt == torch.bfloat16 and mpu.get_fp32_allreduce():
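
For context on the first hunk: per the docstring, get_params_for_weight_decay_optimization splits parameters into a with-weight-decay group and a without-weight-decay group (layernorms and biases get none). A minimal, self-contained sketch of that param-group pattern with plain torch.optim; the toy model and the 1-D-parameter heuristic are assumptions standing in for NeoX's module/neox_args interface, so this is an illustration, not the NeoX implementation:

import torch

# Toy model: a Linear (2-D weight + 1-D bias) and a LayerNorm (1-D params).
model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.LayerNorm(4))

decay, no_decay = [], []
for p in model.parameters():
    # Assumed heuristic for this sketch: 1-D params (biases, norm weights)
    # get no weight decay; everything else does.
    (no_decay if p.ndim == 1 else decay).append(p)

# torch.optim consumes a list of param-group dicts, one per policy:
# the kind of structure the NeoX helper builds.
optimizer = torch.optim.AdamW(
    [
        {"params": decay, "weight_decay": 0.01},
        {"params": no_decay, "weight_decay": 0.0},
    ],
    lr=1e-3,
)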
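
For the second hunk: the change simply routes the call through the dist alias already imported at the top of the file (dist is torch.distributed, so behavior is identical). The surrounding code upcasts bf16 tensors to fp32 before the all-reduce and casts back afterwards. A minimal sketch of that pattern, assuming an initialized process group and substituting the default group for mpu.get_model_parallel_group():

import torch
import torch.distributed as dist

def allreduce_fp32(t: torch.Tensor, fp32_allreduce: bool = True) -> torch.Tensor:
    dt = t.dtype
    if dt == torch.bfloat16 and fp32_allreduce:
        # Accumulate in fp32 to avoid bf16 rounding error in the sum.
        t = t.float()
    # In-place sum across all ranks of the (default) process group.
    dist.all_reduce(t)
    if dt == torch.bfloat16 and fp32_allreduce:
        t = t.bfloat16()
    return t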
