From 1395ff41b3c9b700e16c28db08175a2597c2a007 Mon Sep 17 00:00:00 2001
From: Gregory Shtrasberg
Date: Thu, 23 Jan 2025 00:25:10 +0000
Subject: [PATCH] Returning the use of the proper stream in allreduce

---
 vllm/distributed/parallel_state.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index e41e669571d81..d8017909bab4e 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -39,7 +39,8 @@
 import vllm.envs as envs
 from vllm.distributed.utils import StatelessProcessGroup
 from vllm.logger import init_logger
-from vllm.utils import direct_register_custom_op, supports_custom_op
+from vllm.utils import (current_stream, direct_register_custom_op,
+                        supports_custom_op)
 
 if TYPE_CHECKING:
     from vllm.config import VllmConfig
@@ -365,7 +366,7 @@ def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
             return out
         pynccl_comm = self.pynccl_comm
         assert pynccl_comm is not None
-        out = pynccl_comm.all_reduce(input_)
+        out = pynccl_comm.all_reduce(input_, stream=current_stream())
         if out is None:
             # fall back to the default all-reduce using PyTorch.
             # this usually happens during testing.
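
For context, a minimal sketch (not vLLM source) of the behavior this patch restores: the pynccl all-reduce is enqueued on the CUDA stream PyTorch is currently using, so the collective is ordered after the kernels that produced its input rather than running on an unrelated stream. The helper name below is hypothetical; it only assumes a communicator whose all_reduce accepts a stream keyword, as used in the hunk above.

    import torch

    def all_reduce_on_current_stream(comm, tensor: torch.Tensor) -> torch.Tensor:
        # Pass the stream PyTorch is currently recording work on, so the
        # NCCL collective sees the results of the kernels that wrote `tensor`
        # instead of racing with them on a separate stream.
        return comm.all_reduce(tensor, stream=torch.cuda.current_stream())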