Skip to content

Commit

Permalink
[Quality] Avoid torch.distributed imports at root
Browse files Browse the repository at this point in the history
ghstack-source-id: f773ed94d4b7ca13c603f32251f0735c751ebf94
Pull Request resolved: #1134
  • Loading branch information
vmoens committed Dec 9, 2024
1 parent 22da679 commit 0e9a854
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 17 deletions.
9 changes: 4 additions & 5 deletions tensordict/_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@

import orjson as json
import torch
import torch.distributed as dist

from tensordict.memmap import MemoryMappedTensor

Expand Down Expand Up @@ -2388,7 +2387,7 @@ def _send(
dst: int,
_tag: int = -1,
pseudo_rand: bool = False,
group: "dist.ProcessGroup" | None = None,
group: "torch.distributed.ProcessGroup" | None = None,
) -> int:
for td in self.tensordicts:
_tag = td._send(dst, _tag=_tag, pseudo_rand=pseudo_rand, group=group)
Expand All @@ -2400,7 +2399,7 @@ def _isend(
_tag: int = -1,
_futures: list[torch.Future] | None = None,
pseudo_rand: bool = False,
group: "dist.ProcessGroup" | None = None,
group: "torch.distributed.ProcessGroup" | None = None,
) -> int:
if _futures is None:
is_root = True
Expand All @@ -2421,7 +2420,7 @@ def _recv(
src: int,
_tag: int = -1,
pseudo_rand: bool = False,
group: "dist.ProcessGroup" | None = None,
group: "torch.distributed.ProcessGroup" | None = None,
) -> int:
for td in self.tensordicts:
_tag = td._recv(src, _tag=_tag, pseudo_rand=pseudo_rand, group=group)
Expand All @@ -2434,7 +2433,7 @@ def _irecv(
_tag: int = -1,
_future_list: list[torch.Future] = None,
pseudo_rand: bool = False,
group: "dist.ProcessGroup" | None = None,
group: "torch.distributed.ProcessGroup" | None = None,
) -> tuple[int, list[torch.Future]] | list[torch.Future] | None:
root = False
if _future_list is None:
Expand Down
34 changes: 24 additions & 10 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@
unravel_key,
unravel_key_list,
)
from torch import distributed as dist, multiprocessing as mp, nn, Tensor
from torch import multiprocessing as mp, nn, Tensor
from torch.nn.parameter import Parameter, UninitializedTensorMixin
from torch.utils._pytree import tree_map

Expand Down Expand Up @@ -7260,7 +7260,7 @@ def del_(self, key: NestedKey) -> T:

# Distributed functionality
def gather_and_stack(
self, dst: int, group: "dist.ProcessGroup" | None = None
self, dst: int, group: "torch.distributed.ProcessGroup" | None = None
) -> T | None:
"""Gathers tensordicts from various workers and stacks them onto self in the destination worker.
Expand Down Expand Up @@ -7319,6 +7319,8 @@ def gather_and_stack(
... main_worker.join()
... secondary_worker.join()
"""
from torch import distributed as dist

output = (
[None for _ in range(dist.get_world_size(group=group))]
if dst == dist.get_rank(group=group)
Expand All @@ -7336,7 +7338,7 @@ def send(
self,
dst: int,
*,
group: "dist.ProcessGroup" | None = None,
group: "torch.distributed.ProcessGroup" | None = None,
init_tag: int = 0,
pseudo_rand: bool = False,
) -> None: # noqa: D417
Expand Down Expand Up @@ -7426,8 +7428,10 @@ def _send(
dst: int,
_tag: int = -1,
pseudo_rand: bool = False,
group: "dist.ProcessGroup" | None = None,
group: "torch.distributed.ProcessGroup" | None = None,
) -> int:
from torch import distributed as dist

for key in self.sorted_keys:
value = self._get_str(key, NO_DEFAULT)
if isinstance(value, Tensor):
Expand All @@ -7449,7 +7453,7 @@ def recv(
self,
src: int,
*,
group: "dist.ProcessGroup" | None = None,
group: "torch.distributed.ProcessGroup" | None = None,
init_tag: int = 0,
pseudo_rand: bool = False,
) -> int: # noqa: D417
Expand Down Expand Up @@ -7481,9 +7485,11 @@ def _recv(
src: int,
_tag: int = -1,
pseudo_rand: bool = False,
group: "dist.ProcessGroup" | None = None,
group: "torch.distributed.ProcessGroup" | None = None,
non_blocking: bool = False,
) -> int:
from torch import distributed as dist

for key in self.sorted_keys:
value = self._get_str(key, NO_DEFAULT)
if isinstance(value, Tensor):
Expand All @@ -7508,7 +7514,7 @@ def isend(
self,
dst: int,
*,
group: "dist.ProcessGroup" | None = None,
group: "torch.distributed.ProcessGroup" | None = None,
init_tag: int = 0,
pseudo_rand: bool = False,
) -> int: # noqa: D417
Expand Down Expand Up @@ -7603,8 +7609,10 @@ def _isend(
_tag: int = -1,
_futures: list[torch.Future] | None = None,
pseudo_rand: bool = False,
group: "dist.ProcessGroup" | None = None,
group: "torch.distributed.ProcessGroup" | None = None,
) -> int:
from torch import distributed as dist

root = False
if _futures is None:
root = True
Expand Down Expand Up @@ -7639,7 +7647,7 @@ def irecv(
self,
src: int,
*,
group: "dist.ProcessGroup" | None = None,
group: "torch.distributed.ProcessGroup" | None = None,
return_premature: bool = False,
init_tag: int = 0,
pseudo_rand: bool = False,
Expand Down Expand Up @@ -7687,8 +7695,10 @@ def _irecv(
_tag: int = -1,
_future_list: list[torch.Future] = None,
pseudo_rand: bool = False,
group: "dist.ProcessGroup" | None = None,
group: "torch.distributed.ProcessGroup" | None = None,
) -> tuple[int, list[torch.Future]] | list[torch.Future] | None:
from torch import distributed as dist

root = False
if _future_list is None:
_future_list = []
Expand Down Expand Up @@ -7736,6 +7746,8 @@ def reduce(
Only the process with ``rank`` dst is going to receive the final result.
"""
from torch import distributed as dist

if op is None:
op = dist.ReduceOp.SUM
return self._reduce(dst, op, async_op, return_premature, group=group)
Expand All @@ -7749,6 +7761,8 @@ def _reduce(
_future_list=None,
group=None,
):
from torch import distributed as dist

if op is None:
op = dist.ReduceOp.SUM
root = False
Expand Down
4 changes: 2 additions & 2 deletions tensordict/tensorclass.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ from tensordict.utils import (
unravel_key as unravel_key,
unravel_key_list as unravel_key_list,
)
from torch import distributed as dist, multiprocessing as mp, nn, Tensor
from torch import multiprocessing as mp, nn, Tensor

class _NoDefault(enum.IntEnum):
ZERO = 0
Expand Down Expand Up @@ -663,7 +663,7 @@ class TensorClass:
) -> T: ...
def del_(self, key: NestedKey) -> T: ...
def gather_and_stack(
self, dst: int, group: dist.ProcessGroup | None = None
self, dst: int, group: "dist.ProcessGroup" | None = None
) -> T | None: ...
def send(
self,
Expand Down
9 changes: 9 additions & 0 deletions test/smoke_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import sys

import pytest

Expand All @@ -11,6 +12,14 @@ def test_imports():
from tensordict import TensorDict # noqa: F401
from tensordict.nn import TensorDictModule # noqa: F401

# # Check that distributed is not imported
# v = set(sys.modules.values())
# try:
# from torch import distributed
# except ImportError:
# return
# assert distributed not in v


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
Expand Down

0 comments on commit 0e9a854

Please sign in to comment.