[TPU] Support collective communications in XLA devices (vllm-project#…
WoosukKwon authored Jul 27, 2024
1 parent 7c1b18a commit 164bbdf
Showing 4 changed files with 70 additions and 2 deletions.
30 changes: 30 additions & 0 deletions vllm/distributed/device_communicators/tpu_communicator.py
@@ -0,0 +1,30 @@
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup

from vllm.platforms import current_platform

if current_platform.is_tpu():
    import torch_xla.core.xla_model as xm
    from torch_xla._internal import pjrt


class TpuCommunicator:

    def __init__(self, group: ProcessGroup):
        if not current_platform.is_tpu():
            self.disabled = True
            return
        self.disabled = False

        local_rank = dist.get_rank(group)
        world_size = dist.get_world_size(group)
        pjrt.initialize_multiprocess(local_rank, world_size)
        xm._init_world_size_ordinal()

    def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
        return xm.all_reduce(xm.REDUCE_SUM, x)

    def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
        assert dim == -1, "TPUs only support dim=-1 for all-gather."
        return xm.all_gather(x, dim=dim)
22 changes: 22 additions & 0 deletions vllm/distributed/parallel_state.py
@@ -133,6 +133,7 @@ def __init__(
        torch_distributed_backend: Union[str, Backend],
        use_pynccl: bool,
        use_custom_allreduce: bool,
        use_tpu_communicator: bool,
        use_message_queue_broadcaster: bool = False,
    ):

@@ -164,6 +165,7 @@ def __init__(

        self.use_pynccl = use_pynccl
        self.use_custom_allreduce = use_custom_allreduce
        self.use_tpu_communicator = use_tpu_communicator

        # lazy import to avoid documentation build error
        from vllm.distributed.device_communicators.custom_all_reduce import (
@@ -190,6 +192,12 @@ def __init__(
        else:
            self.ca_comm = None

        from vllm.distributed.device_communicators.tpu_communicator import (
            TpuCommunicator)
        self.tpu_communicator: Optional[TpuCommunicator]
        if use_tpu_communicator and self.world_size > 1:
            self.tpu_communicator = TpuCommunicator(group=self.cpu_group)

        from vllm.distributed.device_communicators.shm_broadcast import (
            MessageQueue)
        self.mq_broadcaster: Optional[MessageQueue] = None
@@ -289,6 +297,12 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return input_

        # For TPUs, use TPU communicator.
        tpu_comm = self.tpu_communicator
        if tpu_comm is not None and not tpu_comm.disabled:
            return tpu_comm.all_reduce(input_)

        if ca_comm is not None:
            out = ca_comm.custom_all_reduce(input_)
            if out is not None:
@@ -310,6 +324,12 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
            return input_
        assert -input_.dim() <= dim < input_.dim(), (
            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")

        # For TPUs, use TPU communicator.
        tpu_comm = self.tpu_communicator
        if tpu_comm is not None and not tpu_comm.disabled:
            return tpu_comm.all_gather(input_, dim)

        if dim < 0:
            # Convert negative dim to positive.
            dim += input_.dim()
@@ -727,6 +747,7 @@ def init_world_group(ranks: List[int], local_rank: int,
        torch_distributed_backend=backend,
        use_pynccl=False,
        use_custom_allreduce=False,
        use_tpu_communicator=False,
    )


@@ -745,6 +766,7 @@ def init_model_parallel_group(
        torch_distributed_backend=backend,
        use_pynccl=True,
        use_custom_allreduce=use_custom_allreduce,
        use_tpu_communicator=True,
        use_message_queue_broadcaster=use_message_queue_broadcaster,
    )

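Taken together, the collective methods now try the TPU communicator before any GPU-specific path. A simplified paraphrase of the resulting dispatch order in GroupCoordinator.all_reduce (the actual method also routes through the pynccl communicator in some configurations):

# Simplified paraphrase, not the exact vLLM implementation.
def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
    # Single-device groups have nothing to reduce.
    if self.world_size == 1:
        return input_
    # 1) XLA collective on TPU hosts.
    tpu_comm = self.tpu_communicator
    if tpu_comm is not None and not tpu_comm.disabled:
        return tpu_comm.all_reduce(input_)
    # 2) Custom all-reduce kernel on GPU, when it can handle the input.
    if self.ca_comm is not None:
        out = self.ca_comm.custom_all_reduce(input_)
        if out is not None:
            return out
    # 3) Fallback: regular torch.distributed all-reduce (in place).
    torch.distributed.all_reduce(input_, group=self.device_group)
    return input_
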
4 changes: 4 additions & 0 deletions vllm/lora/layers.py
@@ -1067,6 +1067,10 @@ def scale(self):
    def soft_cap(self):
        return self.base_layer.soft_cap

    @property
    def use_gather(self):
        return self.base_layer.use_gather

    @property
    def org_vocab_size(self):
        return self.base_layer.org_vocab_size
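
The new property simply mirrors the use_gather attribute introduced on LogitsProcessor, so the LoRA wrapper keeps exposing the same interface as the layer it wraps. A generic sketch of the delegation pattern (hypothetical class names, not vLLM code):

class BareLayer:
    use_gather = True              # attribute introduced on the wrapped layer

class Wrapper:
    def __init__(self, base_layer: BareLayer):
        self.base_layer = base_layer

    @property
    def use_gather(self) -> bool:
        # Forward the attribute so callers cannot tell wrapper and base apart.
        return self.base_layer.use_gather
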
16 changes: 14 additions & 2 deletions vllm/model_executor/layers/logits_processor.py
@@ -5,10 +5,12 @@
import torch
import torch.nn as nn

-from vllm.distributed import tensor_model_parallel_gather
+from vllm.distributed import (tensor_model_parallel_all_gather,
+                              tensor_model_parallel_gather)
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.platforms import current_platform


class LogitsProcessor(nn.Module):
@@ -39,6 +41,8 @@ def __init__(self,
        self.org_vocab_size = org_vocab_size or vocab_size
        # Soft cap the logits. Used in Gemma 2.
        self.soft_cap = soft_cap
        # Whether to use gather or all-gather to gather the logits.
        self.use_gather = not current_platform.is_tpu()

    def forward(
        self,
@@ -76,7 +80,15 @@ def _get_logits(self, hidden_states: torch.Tensor,
        logits = lm_head.linear_method.apply(lm_head,
                                             hidden_states,
                                             bias=embedding_bias)
-        logits = tensor_model_parallel_gather(logits)
+        if self.use_gather:
+            logits = tensor_model_parallel_gather(logits)
+        else:
+            # Gather is not supported for some devices such as TPUs.
+            # Use all-gather instead.
+            # NOTE(woosuk): Here, the outputs of every device should not be None
+            # because XLA requires strict SPMD among all devices. Every device
+            # should execute the same operations after gathering the logits.
+            logits = tensor_model_parallel_all_gather(logits)
        # Remove paddings in vocab (if any).
        if logits is not None:
            logits = logits[:, :self.org_vocab_size]
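
The practical difference, illustrated below with plain torch.distributed collectives (an illustration, not vLLM code): gather materializes the concatenated logits only on the destination rank, so every other rank sees None, whereas all-gather leaves an identical full copy on every rank, which matches XLA's requirement that all devices execute the same subsequent operations.

import torch
import torch.distributed as dist

def gather_to_rank0(shard: torch.Tensor, dim: int = -1):
    # Only rank 0 receives the shards; every other rank returns None.
    world_size = dist.get_world_size()
    if dist.get_rank() == 0:
        parts = [torch.empty_like(shard) for _ in range(world_size)]
        dist.gather(shard, gather_list=parts, dst=0)
        return torch.cat(parts, dim=dim)
    dist.gather(shard, dst=0)
    return None

def all_gather_everywhere(shard: torch.Tensor, dim: int = -1):
    # Every rank ends up with the same concatenated tensor.
    world_size = dist.get_world_size()
    parts = [torch.empty_like(shard) for _ in range(world_size)]
    dist.all_gather(parts, shard)
    return torch.cat(parts, dim=dim)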