[TPU] Support collective communications in XLA devices #6813
Conversation
👋 Hi! Thank you for contributing to the vLLM project. Once the PR is approved and ready to go, please make sure to run full CI as it is required to merge (or just use auto-merge). To run full CI, you can do one of these:
This PR was part of #5871, but separated out to get a quick review.
vllm/distributed/parallel_state.py (Outdated)

```python
pynccl_comm: Optional[Any]  # PyNccl communicator
ca_comm: Optional[Any]  # Custom allreduce communicator
mq_broadcaster: Optional[Any]  # shared memory broadcaster
use_xla: bool  # Whether to use PyTorch XLA communicator
```
Does the TPU platform support NCCL? If not, creating these communicators might lead to errors.
TPU doesn't support NCCL, but I didn't see any error with the other communicators.
The TPU backend uses the gloo backend in addition to the distributed runtime in `xm`. Maybe that's the reason.
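For context, a gloo process group runs entirely on CPU, so it needs no NCCL on TPU hosts. A minimal, self-contained sketch (single process; the address, port, and tensor here are arbitrary, not from the PR):

```python
import torch
import torch.distributed as dist

# The gloo backend runs on CPU, so it works on hosts without NCCL.
dist.init_process_group(
    backend="gloo",
    init_method="tcp://127.0.0.1:29500",
    rank=0,
    world_size=1,
)

t = torch.ones(4)
dist.all_reduce(t)  # trivially a no-op with world_size=1
print(t)
dist.destroy_process_group()
```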
vllm/distributed/parallel_state.py (Outdated)

```diff
@@ -125,6 +129,7 @@ class GroupCoordinator:
     pynccl_comm: Optional[Any]  # PyNccl communicator
     ca_comm: Optional[Any]  # Custom allreduce communicator
     mq_broadcaster: Optional[Any]  # shared memory broadcaster
+    use_xla: bool  # Whether to use PyTorch XLA communicator
```
`use_xxx` is an initialization parameter, and we usually hold the communicator itself inside the group coordinator. Can you add a `tpu_communicator` under https://github.com/vllm-project/vllm/tree/main/vllm/distributed/device_communicators ? One additional benefit is that you can implement the `gather` logic via `all-gather` without an intrusive change to `logits_processor.py`.
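A minimal sketch of what such a `tpu_communicator` could look like (the class name, constructor arguments, and structure are assumptions for illustration, not the PR's actual code):

```python
# Hypothetical sketch of a TPU communicator under
# vllm/distributed/device_communicators/ -- names and structure are assumed.
from typing import Optional

import torch
import torch.distributed as dist
import torch_xla.core.xla_model as xm


class TpuCommunicator:

    def __init__(self, group: Optional[dist.ProcessGroup] = None):
        # Disable the communicator when no group is given (e.g. non-TPU runs).
        self.disabled = group is None
        self.group = group

    def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
        # xm.all_reduce returns a new tensor summed across all TPU replicas.
        return xm.all_reduce(xm.REDUCE_SUM, x)

    def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
        # xm.all_gather concatenates per-replica tensors along `dim`.
        if dim < 0:
            dim += x.dim()
        return xm.all_gather(x, dim=dim)
```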
> One additional benefit is that you can implement the `gather` logic via `all-gather` without an intrusive change to `logits_processor.py`.

This is actually not the case, because the TPU backend explicitly requires all-gather, which means each device's output should not be None. If we implemented `gather` by using `all-gather` and returned None for non-root ranks, XLA would raise an error.
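To make the constraint concrete, here is a hedged sketch (function and parameter names are illustrative, not vLLM's actual API): with `torch.distributed`, `gather` returns the result only on the destination rank, whereas under XLA's SPMD model every rank must run the same all-gather and keep a real tensor.

```python
import torch
import torch.distributed as dist
import torch_xla.core.xla_model as xm


def gather_logits(logits: torch.Tensor, group, rank: int, use_xla: bool):
    if use_xla:
        # SPMD: every rank executes the same program and keeps the full
        # result; returning None on non-root ranks is not allowed.
        dim = logits.dim() - 1
        return xm.all_gather(logits, dim=dim)
    # Conventional path: only the destination rank receives the tensors.
    world_size = dist.get_world_size(group=group)
    out = ([torch.empty_like(logits) for _ in range(world_size)]
           if rank == 0 else None)
    dist.gather(logits, gather_list=out, dst=0, group=group)
    return torch.cat(out, dim=-1) if rank == 0 else None
```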
```python
        return
    self.disabled = False

    pjrt.initialize_multiprocess(local_rank, world_size)
```
You can get the rank and world size from the group.
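For reference, with `torch.distributed` both values can be read from the process group directly (standard API; shown here as a small sketch):

```python
import torch.distributed as dist


def group_rank_and_size(group: dist.ProcessGroup):
    # Rank of the calling process within `group` and the group's size.
    return dist.get_rank(group=group), dist.get_world_size(group=group)
```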
Thanks for letting me know! Updated the PR. PTAL.
Do you need to remove the code inside the TPU worker? I don't know if `pjrt` and `xm` support being initialized multiple times.
@youkaichao Good point. That's actually updated in #5871. In the current main branch, there's no code initializing XLA's distributed runtime.
Thanks for addressing my comments!
@youkaichao Thanks for your review and suggestions on the PR!
This PR adds support for collective communications on XLA devices (TPU). It is implemented simply by falling back to `xm.all_reduce` and `xm.all_gather` for TPU devices. One difference is the gather operation in the logits processor, where `gather` is replaced by `all-gather` to meet the SPMD restriction in XLA.
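A minimal sketch of the fallback described above (assuming `torch_xla`; the wrapper names and the `use_xla` flag are illustrative, not vLLM's exact API):

```python
import torch
import torch.distributed as dist
import torch_xla.core.xla_model as xm


def all_reduce_sum(x: torch.Tensor, group=None, use_xla: bool = False) -> torch.Tensor:
    if use_xla:
        # torch_xla collective: returns a new tensor summed over all replicas.
        return xm.all_reduce(xm.REDUCE_SUM, x)
    dist.all_reduce(x, group=group)  # in-place sum across the process group
    return x


def all_gather_cat(x: torch.Tensor, dim: int, group=None, use_xla: bool = False) -> torch.Tensor:
    if dim < 0:
        dim += x.dim()
    if use_xla:
        # torch_xla collective: concatenates per-replica tensors along `dim`.
        return xm.all_gather(x, dim=dim)
    world_size = dist.get_world_size(group=group)
    out = [torch.empty_like(x) for _ in range(world_size)]
    dist.all_gather(out, x, group=group)
    return torch.cat(out, dim=dim)
```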