[TPU] Support collective communications in XLA devices #6813

Merged (6 commits) on Jul 27, 2024
30 changes: 30 additions & 0 deletions vllm/distributed/device_communicators/tpu_communicator.py
@@ -0,0 +1,30 @@
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup

from vllm.platforms import current_platform

if current_platform.is_tpu():
    import torch_xla.core.xla_model as xm
    from torch_xla._internal import pjrt


class TpuCommunicator:

    def __init__(self, group: ProcessGroup):
        if not current_platform.is_tpu():
            self.disabled = True
            return
        self.disabled = False

        local_rank = dist.get_rank(group)
        world_size = dist.get_world_size(group)
        pjrt.initialize_multiprocess(local_rank, world_size)
Member: you can get rank and world size from the group

Collaborator (Author): Thanks for letting me know! Updated the PR. PTAL.

Member: Do you need to remove the code inside the TPU worker? I'm not sure whether pjrt and xm support being initialized multiple times.

Collaborator (Author): @youkaichao Good point. That was actually updated in #5871. In the current main branch, there is no code initializing XLA's distributed runtime.

        xm._init_world_size_ordinal()

    def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
        return xm.all_reduce(xm.REDUCE_SUM, x)

    def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
        assert dim == -1, "TPUs only support dim=-1 for all-gather."
        return xm.all_gather(x, dim=dim)
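
Not part of the diff, but for context: a minimal sketch of the usage pattern this class enables, assuming torch.distributed has already been initialized (vLLM does this during its distributed setup). The helper name make_all_reduce is invented for illustration; the point is that the communicator is built once per process group and every collective call is guarded on disabled, so the same call site runs unchanged on TPU and non-TPU hosts.

# Illustrative sketch only -- not vLLM code.
import torch
import torch.distributed as dist

from vllm.distributed.device_communicators.tpu_communicator import (
    TpuCommunicator)


def make_all_reduce(group: dist.ProcessGroup):
    # Build the communicator once per process group. On non-TPU hosts,
    # __init__ returns early with disabled=True and never touches torch_xla.
    tpu_comm = TpuCommunicator(group=group)

    def all_reduce(x: torch.Tensor) -> torch.Tensor:
        if not tpu_comm.disabled:
            # TPU path: dispatch to the XLA collective.
            return tpu_comm.all_reduce(x)
        # Fallback path: regular torch.distributed all-reduce (in-place).
        dist.all_reduce(x, group=group)
        return x

    return all_reduce
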
22 changes: 22 additions & 0 deletions vllm/distributed/parallel_state.py
@@ -133,6 +133,7 @@ def __init__(
        torch_distributed_backend: Union[str, Backend],
        use_pynccl: bool,
        use_custom_allreduce: bool,
        use_tpu_communicator: bool,
        use_message_queue_broadcaster: bool = False,
    ):

@@ -164,6 +165,7 @@ def __init__(

        self.use_pynccl = use_pynccl
        self.use_custom_allreduce = use_custom_allreduce
        self.use_tpu_communicator = use_tpu_communicator

        # lazy import to avoid documentation build error
        from vllm.distributed.device_communicators.custom_all_reduce import (
@@ -190,6 +192,12 @@ def __init__(
        else:
            self.ca_comm = None

        from vllm.distributed.device_communicators.tpu_communicator import (
            TpuCommunicator)
        self.tpu_communicator: Optional[TpuCommunicator]
        if use_tpu_communicator and self.world_size > 1:
            self.tpu_communicator = TpuCommunicator(group=self.cpu_group)

        from vllm.distributed.device_communicators.shm_broadcast import (
            MessageQueue)
        self.mq_broadcaster: Optional[MessageQueue] = None
@@ -289,6 +297,12 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return input_

        # For TPUs, use TPU communicator.
        tpu_comm = self.tpu_communicator
        if tpu_comm is not None and not tpu_comm.disabled:
            return tpu_comm.all_reduce(input_)

        if ca_comm is not None:
            out = ca_comm.custom_all_reduce(input_)
            if out is not None:
@@ -307,6 +321,12 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
            return input_
        assert -input_.dim() <= dim < input_.dim(), (
            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")

        # For TPUs, use TPU communicator.
        tpu_comm = self.tpu_communicator
        if tpu_comm is not None and not tpu_comm.disabled:
            return tpu_comm.all_gather(input_, dim)

        if dim < 0:
            # Convert negative dim to positive.
            dim += input_.dim()
@@ -724,6 +744,7 @@ def init_world_group(ranks: List[int], local_rank: int,
        torch_distributed_backend=backend,
        use_pynccl=False,
        use_custom_allreduce=False,
        use_tpu_communicator=False,
    )


@@ -742,6 +763,7 @@ def init_model_parallel_group(
        torch_distributed_backend=backend,
        use_pynccl=True,
        use_custom_allreduce=use_custom_allreduce,
        use_tpu_communicator=True,
        use_message_queue_broadcaster=use_message_queue_broadcaster,
    )

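To make the semantics of the two hooks above concrete, here is a self-contained sketch (not from the PR) that reproduces them with the CPU gloo backend standing in for the TPU group: all-reduce leaves every rank holding the elementwise sum, and all-gather on dim=-1 leaves every rank holding all shards concatenated along the last dimension, which is the contract tpu_comm.all_reduce and tpu_comm.all_gather provide on TPU.

# Standalone illustration -- gloo on CPU as a stand-in for the TPU group.
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank: int, world_size: int):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29511"  # arbitrary free port
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    x = torch.full((2, 4), float(rank + 1))

    # all_reduce: every rank ends up with the elementwise sum (1 + 2 = 3 here).
    summed = x.clone()
    dist.all_reduce(summed)

    # all_gather along dim=-1: every rank ends up with shape (2, 8), i.e.
    # each rank's shard concatenated on the last dimension.
    shards = [torch.empty_like(x) for _ in range(world_size)]
    dist.all_gather(shards, x)
    gathered = torch.cat(shards, dim=-1)

    if rank == 0:
        print("all_reduce:", summed[0].tolist())           # [3.0, 3.0, 3.0, 3.0]
        print("all_gather shape:", tuple(gathered.shape))  # (2, 8)
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)
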
4 changes: 4 additions & 0 deletions vllm/lora/layers.py
@@ -1067,6 +1067,10 @@ def scale(self):
    def soft_cap(self):
        return self.base_layer.soft_cap

    @property
    def use_gather(self):
        return self.base_layer.use_gather

WoosukKwon marked this conversation as resolved.
    @property
    def org_vocab_size(self):
        return self.base_layer.org_vocab_size
16 changes: 14 additions & 2 deletions vllm/model_executor/layers/logits_processor.py
@@ -5,10 +5,12 @@
import torch
import torch.nn as nn

from vllm.distributed import tensor_model_parallel_gather
from vllm.distributed import (tensor_model_parallel_all_gather,
                              tensor_model_parallel_gather)
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.platforms import current_platform


class LogitsProcessor(nn.Module):
@@ -39,6 +41,8 @@ def __init__(self,
        self.org_vocab_size = org_vocab_size or vocab_size
        # Soft cap the logits. Used in Gemma 2.
        self.soft_cap = soft_cap
        # Whether to use gather or all-gather to gather the logits.
        self.use_gather = not current_platform.is_tpu()

    def forward(
        self,
@@ -76,7 +80,15 @@ def _get_logits(self, hidden_states: torch.Tensor,
        logits = lm_head.linear_method.apply(lm_head,
                                             hidden_states,
                                             bias=embedding_bias)
        logits = tensor_model_parallel_gather(logits)
        if self.use_gather:
            logits = tensor_model_parallel_gather(logits)
        else:
            # Gather is not supported for some devices such as TPUs.
            # Use all-gather instead.
            # NOTE(woosuk): Here, the outputs of every device should not be None
            # because XLA requires strict SPMD among all devices. Every device
            # should execute the same operations after gathering the logits.
            logits = tensor_model_parallel_all_gather(logits)
youkaichao marked this conversation as resolved.
        # Remove paddings in vocab (if any).
        if logits is not None:
            logits = logits[:, :self.org_vocab_size]
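
To spell out why the TPU path switches from gather to all-gather, here is a single-process sketch (not from the PR, with an invented tensor-parallel size and vocab sizes): with gather, only the destination rank receives the concatenated logits and every other rank gets None, which downstream CUDA code tolerates via the `if logits is not None` check but XLA's SPMD model does not; with all-gather, every rank materializes the full logits and can proceed with the same padding-removal slice that _get_logits performs.

# Single-process illustration of the shapes involved; values are random and
# the tensor-parallel size and vocab sizes are made up.
import torch

tp_size = 2
num_tokens, org_vocab_size, padded_vocab_size = 3, 32000, 32256
shard_size = padded_vocab_size // tp_size

# Each tensor-parallel rank computes logits only for its vocab shard.
shards = [torch.randn(num_tokens, shard_size) for _ in range(tp_size)]

# tensor_model_parallel_gather: only the destination rank would see the full
# tensor; the other ranks get None, so they cannot keep executing the same
# program, which XLA's strict SPMD execution requires.
logits_on_rank0 = torch.cat(shards, dim=-1)      # (3, 32256)
logits_on_other_ranks = None

# tensor_model_parallel_all_gather: every rank materializes the full logits
# and then strips the vocab padding, exactly as _get_logits does above.
full_logits = torch.cat(shards, dim=-1)          # (3, 32256) on every rank
logits = full_logits[:, :org_vocab_size]         # (3, 32000)
print(tuple(logits.shape))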