From 164bbdf19b02db096c0b1261ceb0802eeab3ad4d Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 26 Jul 2024 18:45:57 -0700 Subject: [PATCH] [TPU] Support collective communications in XLA devices (#6813) --- .../device_communicators/tpu_communicator.py | 30 +++++++++++++++++++ vllm/distributed/parallel_state.py | 22 ++++++++++++++ vllm/lora/layers.py | 4 +++ .../model_executor/layers/logits_processor.py | 16 ++++++++-- 4 files changed, 70 insertions(+), 2 deletions(-) create mode 100644 vllm/distributed/device_communicators/tpu_communicator.py diff --git a/vllm/distributed/device_communicators/tpu_communicator.py b/vllm/distributed/device_communicators/tpu_communicator.py new file mode 100644 index 0000000000000..69a9a516f3ebe --- /dev/null +++ b/vllm/distributed/device_communicators/tpu_communicator.py @@ -0,0 +1,30 @@ +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup + +from vllm.platforms import current_platform + +if current_platform.is_tpu(): + import torch_xla.core.xla_model as xm + from torch_xla._internal import pjrt + + +class TpuCommunicator: + + def __init__(self, group: ProcessGroup): + if not current_platform.is_tpu(): + self.disabled = True + return + self.disabled = False + + local_rank = dist.get_rank(group) + world_size = dist.get_world_size(group) + pjrt.initialize_multiprocess(local_rank, world_size) + xm._init_world_size_ordinal() + + def all_reduce(self, x: torch.Tensor) -> torch.Tensor: + return xm.all_reduce(xm.REDUCE_SUM, x) + + def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor: + assert dim == -1, "TPUs only support dim=-1 for all-gather." + return xm.all_gather(x, dim=dim) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 58cae46d9af27..4116b1729d188 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -133,6 +133,7 @@ def __init__( torch_distributed_backend: Union[str, Backend], use_pynccl: bool, use_custom_allreduce: bool, + use_tpu_communicator: bool, use_message_queue_broadcaster: bool = False, ): @@ -164,6 +165,7 @@ def __init__( self.use_pynccl = use_pynccl self.use_custom_allreduce = use_custom_allreduce + self.use_tpu_communicator = use_tpu_communicator # lazy import to avoid documentation build error from vllm.distributed.device_communicators.custom_all_reduce import ( @@ -190,6 +192,12 @@ def __init__( else: self.ca_comm = None + from vllm.distributed.device_communicators.tpu_communicator import ( + TpuCommunicator) + self.tpu_communicator: Optional[TpuCommunicator] + if use_tpu_communicator and self.world_size > 1: + self.tpu_communicator = TpuCommunicator(group=self.cpu_group) + from vllm.distributed.device_communicators.shm_broadcast import ( MessageQueue) self.mq_broadcaster: Optional[MessageQueue] = None @@ -289,6 +297,12 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: # Bypass the function if we are using only 1 GPU. if self.world_size == 1: return input_ + + # For TPUs, use TPU communicator. + tpu_comm = self.tpu_communicator + if tpu_comm is not None and not tpu_comm.disabled: + return tpu_comm.all_reduce(input_) + if ca_comm is not None: out = ca_comm.custom_all_reduce(input_) if out is not None: @@ -310,6 +324,12 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: return input_ assert -input_.dim() <= dim < input_.dim(), ( f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") + + # For TPUs, use TPU communicator. + tpu_comm = self.tpu_communicator + if tpu_comm is not None and not tpu_comm.disabled: + return tpu_comm.all_gather(input_, dim) + if dim < 0: # Convert negative dim to positive. dim += input_.dim() @@ -727,6 +747,7 @@ def init_world_group(ranks: List[int], local_rank: int, torch_distributed_backend=backend, use_pynccl=False, use_custom_allreduce=False, + use_tpu_communicator=False, ) @@ -745,6 +766,7 @@ def init_model_parallel_group( torch_distributed_backend=backend, use_pynccl=True, use_custom_allreduce=use_custom_allreduce, + use_tpu_communicator=True, use_message_queue_broadcaster=use_message_queue_broadcaster, ) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 40de134c0a5ee..87de285a373a2 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -1067,6 +1067,10 @@ def scale(self): def soft_cap(self): return self.base_layer.soft_cap + @property + def use_gather(self): + return self.base_layer.use_gather + @property def org_vocab_size(self): return self.base_layer.org_vocab_size diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py index f6fcf49ef464b..bd3e7e114204f 100644 --- a/vllm/model_executor/layers/logits_processor.py +++ b/vllm/model_executor/layers/logits_processor.py @@ -5,10 +5,12 @@ import torch import torch.nn as nn -from vllm.distributed import tensor_model_parallel_gather +from vllm.distributed import (tensor_model_parallel_all_gather, + tensor_model_parallel_gather) from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.platforms import current_platform class LogitsProcessor(nn.Module): @@ -39,6 +41,8 @@ def __init__(self, self.org_vocab_size = org_vocab_size or vocab_size # Soft cap the logits. Used in Gemma 2. self.soft_cap = soft_cap + # Whether to use gather or all-gather to gather the logits. + self.use_gather = not current_platform.is_tpu() def forward( self, @@ -76,7 +80,15 @@ def _get_logits(self, hidden_states: torch.Tensor, logits = lm_head.linear_method.apply(lm_head, hidden_states, bias=embedding_bias) - logits = tensor_model_parallel_gather(logits) + if self.use_gather: + logits = tensor_model_parallel_gather(logits) + else: + # Gather is not supported for some devices such as TPUs. + # Use all-gather instead. + # NOTE(woosuk): Here, the outputs of every device should not be None + # because XLA requires strict SPMD among all devices. Every device + # should execute the same operations after gathering the logits. + logits = tensor_model_parallel_all_gather(logits) # Remove paddings in vocab (if any). if logits is not None: logits = logits[:, :self.org_vocab_size]