From 5f9b40b9b2184878566e0b424fa338b760c79775 Mon Sep 17 00:00:00 2001
From: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com>
Date: Wed, 22 Jan 2025 19:35:56 -0500
Subject: [PATCH] Returning the use of the proper stream in allreduce (#382)

---
 vllm/distributed/parallel_state.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index e41e669571d81..d8017909bab4e 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -39,7 +39,8 @@
 import vllm.envs as envs
 from vllm.distributed.utils import StatelessProcessGroup
 from vllm.logger import init_logger
-from vllm.utils import direct_register_custom_op, supports_custom_op
+from vllm.utils import (current_stream, direct_register_custom_op,
+                        supports_custom_op)
 
 if TYPE_CHECKING:
     from vllm.config import VllmConfig
@@ -365,7 +366,7 @@ def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
             return out
         pynccl_comm = self.pynccl_comm
         assert pynccl_comm is not None
-        out = pynccl_comm.all_reduce(input_)
+        out = pynccl_comm.all_reduce(input_, stream=current_stream())
         if out is None:
             # fall back to the default all-reduce using PyTorch.
             # this usually happens during testing.
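
For context on why the stream argument matters, below is a minimal sketch (not part of the patch) of the call-site pattern the change restores: the pynccl all-reduce is told explicitly to enqueue on the stream that produced its input, instead of whatever default stream the communicator would otherwise use. `get_tp_group` is an assumed accessor for the tensor-parallel coordinator; the diff itself only confirms `current_stream` in `vllm.utils` and the `stream=` keyword on `all_reduce`.

    import torch

    from vllm.distributed.parallel_state import get_tp_group  # assumed accessor
    from vllm.utils import current_stream

    def tp_all_reduce(input_: torch.Tensor) -> torch.Tensor:
        # Grab the tensor-parallel group's pynccl communicator, as in the diff.
        pynccl_comm = get_tp_group().pynccl_comm
        assert pynccl_comm is not None
        # Enqueue the NCCL all-reduce on the stream that produced `input_`, so
        # the collective is ordered after the kernels that wrote the tensor
        # and no extra cross-stream synchronization is needed.
        return pynccl_comm.all_reduce(input_, stream=current_stream())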