Commit 5f9b40b
Returning the use of the proper stream in allreduce (#382)
gshtras authored Jan 23, 2025
1 parent a600e9f commit 5f9b40b
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions vllm/distributed/parallel_state.py
@@ -39,7 +39,8 @@
 import vllm.envs as envs
 from vllm.distributed.utils import StatelessProcessGroup
 from vllm.logger import init_logger
-from vllm.utils import direct_register_custom_op, supports_custom_op
+from vllm.utils import (current_stream, direct_register_custom_op,
+                        supports_custom_op)
 
 if TYPE_CHECKING:
     from vllm.config import VllmConfig
@@ -365,7 +366,7 @@ def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
             return out
         pynccl_comm = self.pynccl_comm
         assert pynccl_comm is not None
-        out = pynccl_comm.all_reduce(input_)
+        out = pynccl_comm.all_reduce(input_, stream=current_stream())
         if out is None:
             # fall back to the default all-reduce using PyTorch.
             # this usually happens during testing.
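
Why the stream matters: a NCCL kernel enqueued on a CUDA stream other than the one that produced its input tensor is not automatically ordered after that producer, so it can read stale data unless the streams are explicitly synchronized. Passing stream=current_stream() keeps the collective on the caller's stream. The sketch below is plain PyTorch, not vLLM code; its variable names are illustrative only, and it assumes a CUDA device is available.

import torch

# Minimal sketch of the cross-stream ordering hazard this commit avoids.
# All names here are illustrative, not vLLM internals.
producer = torch.cuda.current_stream()  # stream running the model's kernels
side = torch.cuda.Stream()              # separate stream, e.g. one cached
                                        # inside a communicator object

x = torch.zeros(1 << 20, device="cuda")
x.add_(1)  # enqueued on `producer`

with torch.cuda.stream(side):
    # Kernels launched here are NOT ordered after x.add_(1); without an
    # explicit dependency they may observe x before the add completes.
    # Either make `side` wait on `producer`, or (as the patched code does)
    # run the collective on the producer stream itself.
    side.wait_stream(producer)
    y = x * 2

torch.cuda.synchronize()
assert y.sum().item() == 2 * (1 << 20)

Note that the patch imports current_stream from vllm.utils rather than calling torch.cuda.current_stream() at the call site; per the diff, vLLM routes stream lookups through its own helper.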
