
Restoring the use of the proper stream in allreduce
gshtras committed Jan 23, 2025 · commit 1395ff4 · 1 parent b5839a1
Showing 1 changed file with 3 additions and 2 deletions.
vllm/distributed/parallel_state.py (3 additions, 2 deletions)

@@ -39,7 +39,8 @@
 import vllm.envs as envs
 from vllm.distributed.utils import StatelessProcessGroup
 from vllm.logger import init_logger
-from vllm.utils import direct_register_custom_op, supports_custom_op
+from vllm.utils import (current_stream, direct_register_custom_op,
+                        supports_custom_op)
 
 if TYPE_CHECKING:
     from vllm.config import VllmConfig

@@ -365,7 +366,7 @@ def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
             return out
         pynccl_comm = self.pynccl_comm
         assert pynccl_comm is not None
-        out = pynccl_comm.all_reduce(input_)
+        out = pynccl_comm.all_reduce(input_, stream=current_stream())
         if out is None:
             # fall back to the default all-reduce using PyTorch.
             # this usually happens during testing.
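
The functional change is the stream=current_stream() argument: the NCCL all-reduce kernel is now enqueued on the stream the caller is actually running on, instead of whatever stream the communicator would otherwise use. CUDA orders kernels only within a single stream, so a collective launched on some other stream is not automatically ordered after the kernel that produced its input and needs explicit cross-stream synchronization to be safe. Below is a minimal sketch of that ordering argument in plain PyTorch, assuming a CUDA device is available; it is an illustration rather than vLLM code, and the "side" stream is a hypothetical stand-in for a communicator-private stream.

    import torch

    assert torch.cuda.is_available()

    compute = torch.cuda.current_stream()  # the caller's current stream
    side = torch.cuda.Stream()             # stand-in for a private NCCL stream

    x = torch.randn(1 << 20, device="cuda")
    x.mul_(2.0)  # producer kernel, queued on the compute stream

    # Unsafe in general: a kernel launched on `side` is not ordered after the
    # producer above, so an explicit dependency must be recorded first.
    side.wait_stream(compute)
    with torch.cuda.stream(side):
        y = x.sum()  # stand-in for the NCCL reduction kernel

    # What the commit restores: launching on the current stream orders the
    # kernel after the producer implicitly, with no extra events required.
    z = x.sum()

    torch.cuda.synchronize()

Judging from the import added in the diff, current_stream() from vllm.utils supplies that active stream; forwarding it to pynccl_comm.all_reduce(input_, stream=...) keeps the collective ordered with the surrounding work even when vLLM is running on a non-default stream.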
