Skip to content

Commit

Permalink
refactor: nccl's async error handling in pytorch (#59)
Browse files Browse the repository at this point in the history
By default, PyTorch handles errors in NCCL operations by tearing down the
process. In a multiworld setting, we need to handle errors gracefully
by tearing down only the world that encounters the error. To do so, we set
the environment variable TORCH_NCCL_ASYNC_ERROR_HANDLING to "2"
(i.e., CleanUpOnly) in the World Manager's init function.
In this way, users don't need to set the environment variable in their
applications. Thus, in the examples, we removed the code snippet that sets
the environment variable.
  • Loading branch information
myungjin authored Jul 27, 2024
1 parent b46e186 commit c2e4035
Show file tree
Hide file tree
Showing 7 changed files with 12 additions and 25 deletions.
4 changes: 0 additions & 4 deletions examples/all_gather/m8d.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,10 +159,6 @@ async def main(args):
# for example: `--worldinfo 1,0` means the world with index 1 will have rank 0
parser.add_argument("--worldinfo", type=str, action="append")

# https://github.com/pytorch/pytorch/blob/main/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp#L114-L126
# "2" is CleanUpOnly
os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "2"

args = parser.parse_args()

loop = asyncio.get_event_loop()
Expand Down
4 changes: 0 additions & 4 deletions examples/all_reduce/m8d.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,10 +150,6 @@ async def main(args):
parser.add_argument("--addr", default="127.0.0.1")
parser.add_argument("--worldinfo", type=str, action="append")

# https://github.com/pytorch/pytorch/blob/main/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp#L114-L126
# "2" is CleanUpOnly
os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "2"

args = parser.parse_args()

loop = asyncio.get_event_loop()
Expand Down
4 changes: 0 additions & 4 deletions examples/broadcast/m8d.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,10 +154,6 @@ async def main(args):
# for example: `--worldinfo 1,0` means the world with index 1 will have rank 0
parser.add_argument("--worldinfo", type=str, action="append")

# https://github.com/pytorch/pytorch/blob/main/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp#L114-L126
# "2" is CleanUpOnly
os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "2"

args = parser.parse_args()

loop = asyncio.get_event_loop()
Expand Down
11 changes: 6 additions & 5 deletions examples/reduce/m8d.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,12 @@ async def reduce(world_name, world_size, rank, backend):

if dst == rank:
print(
"Rank ", rank, " within world ", world_name, " has reduced tensor", tensor
"Rank ",
rank,
" within world ",
world_name,
" has reduced tensor",
tensor,
)

print(f"done with step: {step}")
Expand Down Expand Up @@ -154,10 +159,6 @@ async def main(args):
parser.add_argument("--addr", default="127.0.0.1")
parser.add_argument("--worldinfo", type=str, action="append")

# https://github.com/pytorch/pytorch/blob/main/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp#L114-L126
# "2" is CleanUpOnly
os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "2"

args = parser.parse_args()

loop = asyncio.get_event_loop()
Expand Down
4 changes: 0 additions & 4 deletions examples/resnet/m8d.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,10 +384,6 @@ async def multi_host(args):
"--multihost", action=argparse.BooleanOptionalAction, default=False
)

# https://github.com/pytorch/pytorch/blob/main/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp#L114-L126
# "2" is CleanUpOnly
os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "2"

args = parser.parse_args()
atexit.register(cleanup)

Expand Down
4 changes: 0 additions & 4 deletions examples/send_recv/m8d.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,10 +191,6 @@ async def main(args):
parser.add_argument("--addr", default="127.0.0.1")
parser.add_argument("--worldinfo", type=str, action="append")

# https://github.com/pytorch/pytorch/blob/main/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp#L114-L126
# "2" is CleanUpOnly
os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "2"

args = parser.parse_args()

loop = asyncio.get_event_loop()
Expand Down
6 changes: 6 additions & 0 deletions multiworld/world_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ class WorldManager:

def __init__(self, enable_monitor=True):
"""Initialize a world manager."""
# https://github.com/pytorch/pytorch/blob/v2.4.0/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp#L118-L130
# "2" is CleanUpOnly
# We use CleanUpOnly so that errors can be handled at the user-process
# level without tearing down the process.
os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "2"

self._worlds_stores: dict[str, dist.TCPStore] = dict()
self._communicator = WorldCommunicator(self)
self._current_world = ""
Expand Down

0 comments on commit c2e4035

Please sign in to comment.