From c2e4035b2c990894533ed4cf37856bbe139b5faa Mon Sep 17 00:00:00 2001 From: Myungjin Lee Date: Sat, 27 Jul 2024 11:56:16 -0700 Subject: [PATCH] refactor: nccl's async error handling in pytorch (#59) The default option to handle errors of nccl operations in PyTorch is to tear down the process. In a multiworld setting, we need to handle error gracefully by tearing down the world that encounter the errors. In order to do so, we set the environment variable, TORCH_NCCL_ASYNC_ERROR_HANDLING, with "2" (i.e., CleanupOnly) in the World Manager''s init function. In this way, users don't need to set the environment variable in their applications. Thus, in the examples, we removed the code snippet that sets the environment variable. --- examples/all_gather/m8d.py | 4 ---- examples/all_reduce/m8d.py | 4 ---- examples/broadcast/m8d.py | 4 ---- examples/reduce/m8d.py | 11 ++++++----- examples/resnet/m8d.py | 4 ---- examples/send_recv/m8d.py | 4 ---- multiworld/world_manager.py | 6 ++++++ 7 files changed, 12 insertions(+), 25 deletions(-) diff --git a/examples/all_gather/m8d.py b/examples/all_gather/m8d.py index e61c8a4..f984297 100644 --- a/examples/all_gather/m8d.py +++ b/examples/all_gather/m8d.py @@ -159,10 +159,6 @@ async def main(args): # for example: --worldinfo 1,0` means world with the index 1 will have a rank 0 parser.add_argument("--worldinfo", type=str, action="append") - # https://github.com/pytorch/pytorch/blob/main/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp#L114-L126 - # "2" is CleanUpOnly - os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "2" - args = parser.parse_args() loop = asyncio.get_event_loop() diff --git a/examples/all_reduce/m8d.py b/examples/all_reduce/m8d.py index 32b8bd1..0286d04 100644 --- a/examples/all_reduce/m8d.py +++ b/examples/all_reduce/m8d.py @@ -150,10 +150,6 @@ async def main(args): parser.add_argument("--addr", default="127.0.0.1") parser.add_argument("--worldinfo", type=str, action="append") - # https://github.com/pytorch/pytorch/blob/main/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp#L114-L126 - # "2" is CleanUpOnly - os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "2" - args = parser.parse_args() loop = asyncio.get_event_loop() diff --git a/examples/broadcast/m8d.py b/examples/broadcast/m8d.py index 497bbf3..efa7482 100644 --- a/examples/broadcast/m8d.py +++ b/examples/broadcast/m8d.py @@ -154,10 +154,6 @@ async def main(args): # for example: --worldinfo 1,0` means world with the index 1 will have a rank 0 parser.add_argument("--worldinfo", type=str, action="append") - # https://github.com/pytorch/pytorch/blob/main/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp#L114-L126 - # "2" is CleanUpOnly - os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "2" - args = parser.parse_args() loop = asyncio.get_event_loop() diff --git a/examples/reduce/m8d.py b/examples/reduce/m8d.py index 3d137c2..9b9c1d0 100644 --- a/examples/reduce/m8d.py +++ b/examples/reduce/m8d.py @@ -95,7 +95,12 @@ async def reduce(world_name, world_size, rank, backend): if dst == rank: print( - "Rank ", rank, " within world ", world_name, " has reduced tensor", tensor + "Rank ", + rank, + " within world ", + world_name, + " has reduced tensor", + tensor, ) print(f"done with step: {step}") @@ -154,10 +159,6 @@ async def main(args): parser.add_argument("--addr", default="127.0.0.1") parser.add_argument("--worldinfo", type=str, action="append") - # https://github.com/pytorch/pytorch/blob/main/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp#L114-L126 - # "2" is CleanUpOnly - os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "2" - args = parser.parse_args() loop = asyncio.get_event_loop() diff --git a/examples/resnet/m8d.py b/examples/resnet/m8d.py index e96a486..b9c63c8 100644 --- a/examples/resnet/m8d.py +++ b/examples/resnet/m8d.py @@ -384,10 +384,6 @@ async def multi_host(args): "--multihost", action=argparse.BooleanOptionalAction, default=False ) - # https://github.com/pytorch/pytorch/blob/main/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp#L114-L126 - # "2" is CleanUpOnly - os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "2" - args = parser.parse_args() atexit.register(cleanup) diff --git a/examples/send_recv/m8d.py b/examples/send_recv/m8d.py index 71e8beb..1c21cd6 100644 --- a/examples/send_recv/m8d.py +++ b/examples/send_recv/m8d.py @@ -191,10 +191,6 @@ async def main(args): parser.add_argument("--addr", default="127.0.0.1") parser.add_argument("--worldinfo", type=str, action="append") - # https://github.com/pytorch/pytorch/blob/main/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp#L114-L126 - # "2" is CleanUpOnly - os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "2" - args = parser.parse_args() loop = asyncio.get_event_loop() diff --git a/multiworld/world_manager.py b/multiworld/world_manager.py index 834852e..b1824fd 100644 --- a/multiworld/world_manager.py +++ b/multiworld/world_manager.py @@ -38,6 +38,12 @@ class WorldManager: def __init__(self, enable_monitor=True): """Initialize a world manager.""" + # https://github.com/pytorch/pytorch/blob/v2.4.0/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp#L118-L130 + # "2" is CleanUpOnly + # We use CleanupOnly in order to allow error handling at user process + # level without tearing down the process. + os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "2" + self._worlds_stores: dict[str, dist.TCPStore] = dict() self._communicator = WorldCommunicator(self) self._current_world = ""