From 3ec3ceeda47420d3c26f3333c9de5e8efc8368fc Mon Sep 17 00:00:00 2001 From: cennn <2523403608@qq.com> Date: Sun, 5 Jan 2025 16:29:02 +0800 Subject: [PATCH 1/3] pynccl rm self.stream --- tests/distributed/test_pynccl.py | 5 ++-- .../device_communicators/pynccl.py | 28 ++++++------------- 2 files changed, 11 insertions(+), 22 deletions(-) diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py index 36cfe42251384..573e5a20b411c 100644 --- a/tests/distributed/test_pynccl.py +++ b/tests/distributed/test_pynccl.py @@ -137,9 +137,8 @@ def worker_fn_with_cudagraph(): # run something in the default stream to initialize torch engine a = torch.ones((4, 4), device=f'cuda:{pynccl_comm.rank}') torch.cuda.synchronize() - with torch.cuda.graph( - graph, stream=pynccl_comm.stream), pynccl_comm.change_state( - enable=True): + with torch.cuda.graph(graph, stream=torch.cuda.current_stream()), \ + pynccl_comm.change_state(enable=True): a_out = pynccl_comm.all_reduce(a) torch.cuda.synchronize() graph.replay() diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py index a6800f93f167b..93d96fd8f5686 100644 --- a/vllm/distributed/device_communicators/pynccl.py +++ b/vllm/distributed/device_communicators/pynccl.py @@ -51,7 +51,6 @@ def __init__( if self.world_size == 1: self.available = False self.disabled = True - self.stream = None return try: self.nccl = NCCLLibrary(library_path) @@ -60,7 +59,6 @@ def __init__( # e.g. in a non-GPU environment self.available = False self.disabled = True - self.stream = None return self.available = True @@ -98,12 +96,12 @@ def __init__( with torch.cuda.device(device): self.comm: ncclComm_t = self.nccl.ncclCommInitRank( self.world_size, self.unique_id, self.rank) - self.stream = torch.cuda.Stream() + stream = torch.cuda.current_stream() # A small all_reduce for warmup. data = torch.zeros(1, device=device) self.all_reduce(data) - self.stream.synchronize() + stream.synchronize() del data def all_reduce(self, @@ -122,7 +120,7 @@ def all_reduce(self, out_tensor = torch.empty_like(in_tensor) if stream is None: - stream = self.stream + stream = torch.cuda.current_stream() self.nccl.ncclAllReduce(buffer_type(in_tensor.data_ptr()), buffer_type(out_tensor.data_ptr()), in_tensor.numel(), @@ -144,7 +142,7 @@ def all_gather(self, f"this nccl communicator is created to work on {self.device}, " f"but the input tensor is on {input_tensor.device}") if stream is None: - stream = self.stream + stream = torch.cuda.current_stream() self.nccl.ncclAllGather( buffer_type(input_tensor.data_ptr()), buffer_type(output_tensor.data_ptr()), input_tensor.numel(), @@ -165,7 +163,7 @@ def reduce_scatter(self, f"this nccl communicator is created to work on {self.device}, " f"but the input tensor is on {input_tensor.device}") if stream is None: - stream = self.stream + stream = torch.cuda.current_stream() self.nccl.ncclReduceScatter( buffer_type(input_tensor.data_ptr()), buffer_type(output_tensor.data_ptr()), output_tensor.numel(), @@ -180,7 +178,7 @@ def send(self, tensor: torch.Tensor, dst: int, stream=None): f"this nccl communicator is created to work on {self.device}, " f"but the input tensor is on {tensor.device}") if stream is None: - stream = self.stream + stream = torch.cuda.current_stream() self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(), ncclDataTypeEnum.from_torch(tensor.dtype), dst, self.comm, cudaStream_t(stream.cuda_stream)) @@ -192,7 +190,7 @@ def recv(self, tensor: torch.Tensor, src: int, stream=None): f"this nccl communicator is created to work on {self.device}, " f"but the input tensor is on {tensor.device}") if stream is None: - stream = self.stream + stream = torch.cuda.current_stream() self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(), ncclDataTypeEnum.from_torch(tensor.dtype), src, self.comm, cudaStream_t(stream.cuda_stream)) @@ -204,7 +202,7 @@ def broadcast(self, tensor: torch.Tensor, src: int, stream=None): f"this nccl communicator is created to work on {self.device}, " f"but the input tensor is on {tensor.device}") if stream is None: - stream = self.stream + stream = torch.cuda.current_stream() if src == self.rank: sendbuff = buffer_type(tensor.data_ptr()) # NCCL requires the sender also to have a receive buffer @@ -217,9 +215,7 @@ def broadcast(self, tensor: torch.Tensor, src: int, stream=None): self.comm, cudaStream_t(stream.cuda_stream)) @contextmanager - def change_state(self, - enable: Optional[bool] = None, - stream: Optional[torch.cuda.Stream] = None): + def change_state(self, enable: Optional[bool] = None): """ A context manager to change the state of the communicator. """ @@ -227,15 +223,9 @@ def change_state(self, # guess a default value when not specified enable = self.available - if stream is None: - stream = self.stream - old_disable = self.disabled - old_stream = self.stream - self.stream = stream self.disabled = not enable yield self.disabled = old_disable - self.stream = old_stream From 55fe0769858a2a51b5baed10c0d82d4e7de3b127 Mon Sep 17 00:00:00 2001 From: cennn <2523403608@qq.com> Date: Sun, 5 Jan 2025 16:39:04 +0800 Subject: [PATCH 2/3] maybe_pynccl_context = pynccl_comm.change_state() --- vllm/distributed/parallel_state.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index a0d4235460f3b..dccd3addbcb35 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -310,8 +310,7 @@ def graph_capture( if not pynccl_comm: maybe_pynccl_context = nullcontext() else: - maybe_pynccl_context = pynccl_comm.change_state( - stream=torch.cuda.current_stream()) + maybe_pynccl_context = pynccl_comm.change_state() with maybe_pynccl_context: yield graph_capture_context From a11a6de4b656ad6a36d816b1d0f805f896bfcdb4 Mon Sep 17 00:00:00 2001 From: cennn <2523403608@qq.com> Date: Sun, 5 Jan 2025 17:33:32 +0800 Subject: [PATCH 3/3] fix pynccl worker_fn_with_cudagraph --- tests/distributed/test_pynccl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py index 573e5a20b411c..a77b48d5e49f3 100644 --- a/tests/distributed/test_pynccl.py +++ b/tests/distributed/test_pynccl.py @@ -137,7 +137,7 @@ def worker_fn_with_cudagraph(): # run something in the default stream to initialize torch engine a = torch.ones((4, 4), device=f'cuda:{pynccl_comm.rank}') torch.cuda.synchronize() - with torch.cuda.graph(graph, stream=torch.cuda.current_stream()), \ + with torch.cuda.graph(graph), \ pynccl_comm.change_state(enable=True): a_out = pynccl_comm.all_reduce(a) torch.cuda.synchronize()