diff --git a/megatron/core/dist_checkpointing/strategies/torch.py b/megatron/core/dist_checkpointing/strategies/torch.py index 8a7a965b1b..463ae11618 100644 --- a/megatron/core/dist_checkpointing/strategies/torch.py +++ b/megatron/core/dist_checkpointing/strategies/torch.py @@ -14,7 +14,7 @@ import numpy as np import torch -from pkg_resources import packaging +from pkg_resources.extern import packaging from torch.distributed import checkpoint from torch.distributed._shard._utils import narrow_tensor_by_index from torch.distributed._shard.metadata import ShardMetadata diff --git a/megatron/core/models/retro/config.py b/megatron/core/models/retro/config.py index 3e3d0b538a..c70c5813b7 100644 --- a/megatron/core/models/retro/config.py +++ b/megatron/core/models/retro/config.py @@ -7,7 +7,7 @@ from dataclasses import dataclass from importlib.metadata import version -from pkg_resources import packaging +from pkg_resources.extern import packaging from megatron.core.transformer import TransformerConfig diff --git a/megatron/core/tensor_parallel/random.py b/megatron/core/tensor_parallel/random.py index ee074df990..b76c591d1d 100644 --- a/megatron/core/tensor_parallel/random.py +++ b/megatron/core/tensor_parallel/random.py @@ -8,7 +8,7 @@ from importlib.metadata import version import torch -from pkg_resources import packaging +from pkg_resources.extern import packaging from torch import _C from torch.cuda import _lazy_call from torch.cuda import device as device_ctx_manager