[core] platform agnostic executor via collective_rpc (vllm-project#11256)

Signed-off-by: youkaichao <[email protected]>
youkaichao authored Jan 15, 2025
1 parent f218f9c commit ad34c0d
Showing 43 changed files with 852 additions and 2,642 deletions.
28 changes: 12 additions & 16 deletions tests/engine/test_custom_executor.py
@@ -1,34 +1,34 @@
import asyncio
import os
+from typing import Any, Dict, List, Optional, Tuple

import pytest

from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.llm_engine import LLMEngine
-from vllm.executor.gpu_executor import GPUExecutor, GPUExecutorAsync
+from vllm.executor.uniproc_executor import UniProcExecutor
from vllm.sampling_params import SamplingParams


class Mock:
    ...


-class CustomGPUExecutor(GPUExecutor):
+class CustomUniExecutor(UniProcExecutor):

-    def execute_model(self, *args, **kwargs):
+    def collective_rpc(self,
+                       method: str,
+                       timeout: Optional[float] = None,
+                       args: Tuple = (),
+                       kwargs: Optional[Dict] = None) -> List[Any]:
        # Drop marker to show that this was ran
        with open(".marker", "w"):
            ...
-        return super().execute_model(*args, **kwargs)
+        return super().collective_rpc(method, timeout, args, kwargs)


-class CustomGPUExecutorAsync(GPUExecutorAsync):
-
-    async def execute_model_async(self, *args, **kwargs):
-        with open(".marker", "w"):
-            ...
-        return await super().execute_model_async(*args, **kwargs)
+CustomUniExecutorAsync = CustomUniExecutor


@pytest.mark.parametrize("model", ["facebook/opt-125m"])
@@ -41,10 +41,6 @@ def test_custom_executor_type_checking(model):
        engine_args = AsyncEngineArgs(model=model,
                                      distributed_executor_backend=Mock)
        AsyncLLMEngine.from_engine_args(engine_args)
-    with pytest.raises(TypeError):
-        engine_args = AsyncEngineArgs(
-            model=model, distributed_executor_backend=CustomGPUExecutor)
-        AsyncLLMEngine.from_engine_args(engine_args)


@pytest.mark.parametrize("model", ["facebook/opt-125m"])
@@ -55,7 +51,7 @@ def test_custom_executor(model, tmp_path):
    assert not os.path.exists(".marker")

    engine_args = EngineArgs(
-        model=model, distributed_executor_backend=CustomGPUExecutor)
+        model=model, distributed_executor_backend=CustomUniExecutor)
    engine = LLMEngine.from_engine_args(engine_args)
    sampling_params = SamplingParams(max_tokens=1)

@@ -75,7 +71,7 @@ def test_custom_executor_async(model, tmp_path):
    assert not os.path.exists(".marker")

    engine_args = AsyncEngineArgs(
-        model=model, distributed_executor_backend=CustomGPUExecutorAsync)
+        model=model, distributed_executor_backend=CustomUniExecutorAsync)
    engine = AsyncLLMEngine.from_engine_args(engine_args)
    sampling_params = SamplingParams(max_tokens=1)

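For readers integrating with the new interface, a minimal sketch of what the test above exercises: instead of subclassing a GPU-specific executor and overriding execute_model, a custom executor now subclasses a platform-agnostic executor and overrides collective_rpc, the single hook through which the engine talks to its workers. The collective_rpc signature and the EngineArgs wiring are taken from the diff above; the class name and the print statement are illustrative.

from typing import Any, Dict, List, Optional, Tuple

from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.executor.uniproc_executor import UniProcExecutor


class LoggingUniExecutor(UniProcExecutor):
    """Illustrative custom executor that observes every worker RPC."""

    def collective_rpc(self,
                       method: str,
                       timeout: Optional[float] = None,
                       args: Tuple = (),
                       kwargs: Optional[Dict] = None) -> List[Any]:
        # Per-call bookkeeping goes here; the platform-specific work stays
        # inside the worker that receives the RPC.
        print(f"dispatching {method} to workers")
        return super().collective_rpc(method, timeout, args, kwargs)


# Any ExecutorBase subclass can be passed directly as the backend.
engine_args = EngineArgs(model="facebook/opt-125m",
                         distributed_executor_backend=LoggingUniExecutor)
engine = LLMEngine.from_engine_args(engine_args)
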
12 changes: 6 additions & 6 deletions tests/engine/test_multiproc_workers.py
@@ -6,16 +6,15 @@

import pytest

+from vllm.config import VllmConfig
from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
                                                   ResultHandler, WorkerMonitor)
+from vllm.worker.worker_base import WorkerWrapperBase


-class DummyWorker:
+class DummyWorkerWrapper(WorkerWrapperBase):
    """Dummy version of vllm.worker.worker.Worker"""

-    def __init__(self, rank: int):
-        self.rank = rank
-
    def worker_method(self, worker_input: Any) -> Tuple[int, Any]:
        sleep(0.05)

@@ -28,9 +27,10 @@ def worker_method(self, worker_input: Any) -> Tuple[int, Any]:

def _start_workers() -> Tuple[List[ProcessWorkerWrapper], WorkerMonitor]:
    result_handler = ResultHandler()
+    vllm_config = VllmConfig()
    workers = [
-        ProcessWorkerWrapper(result_handler, partial(DummyWorker, rank=rank))
-        for rank in range(8)
+        ProcessWorkerWrapper(result_handler, DummyWorkerWrapper, vllm_config,
+                             rank) for rank in range(8)
    ]

    worker_monitor = WorkerMonitor(workers, result_handler)
6 changes: 5 additions & 1 deletion tests/test_utils.py
@@ -2,6 +2,7 @@
import os
import socket
from typing import AsyncIterator, Tuple
+from unittest.mock import patch

import pytest
import torch
@@ -390,7 +391,10 @@ def test_bind_kv_cache_encoder_decoder():


def test_bind_kv_cache_pp():
-    cfg = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=2))
+    with patch("vllm.utils.cuda_device_count_stateless", lambda: 2):
+        # this test runs with 1 GPU, but we simulate 2 GPUs
+        cfg = VllmConfig(
+            parallel_config=ParallelConfig(pipeline_parallel_size=2))
    with set_current_vllm_config(cfg):
        from vllm.attention import Attention

12 changes: 8 additions & 4 deletions vllm/config.py
@@ -1294,8 +1294,11 @@ def __post_init__(self) -> None:
        from vllm.executor import ray_utils
        backend = "mp"
        ray_found = ray_utils.ray_is_available()
-        if (current_platform.is_cuda()
-                and cuda_device_count_stateless() < self.world_size):
+        if current_platform.is_neuron():
+            # neuron uses single process to control multiple devices
+            backend = "uni"
+        elif (current_platform.is_cuda()
+              and cuda_device_count_stateless() < self.world_size):
            if not ray_found:
                raise ValueError("Unable to load Ray which is "
                                 "required for multi-node inference, "
@@ -1328,13 +1331,14 @@ def _verify_args(self) -> None:
        from vllm.executor.executor_base import ExecutorBase
        from vllm.platforms import current_platform
        if self.distributed_executor_backend not in (
-                "ray", "mp", None) and not (isinstance(
+                "ray", "mp", "uni", None) and not (isinstance(
                    self.distributed_executor_backend, type) and issubclass(
                        self.distributed_executor_backend, ExecutorBase)):
            raise ValueError(
                "Unrecognized distributed executor backend "
                f"{self.distributed_executor_backend}. Supported "
-                "values are 'ray', 'mp' or custom ExecutorBase subclass.")
+                "values are 'ray', 'mp' 'uni', or custom ExecutorBase"
+                " subclass.")
        if self.use_ray:
            from vllm.executor import ray_utils
            ray_utils.assert_ray_available()
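Putting the two config.py hunks together, backend selection and validation now roughly behave as in the following sketch (a simplified paraphrase under stated assumptions, not the verbatim vLLM code path): Neuron picks the new single-process "uni" backend, CUDA falls back to Ray only when the locally visible device count cannot cover the world size, and validation accepts "ray", "mp", "uni", or any ExecutorBase subclass.

# Simplified paraphrase of the ParallelConfig logic above (illustrative only).
def pick_backend(platform: str, world_size: int, local_device_count: int,
                 ray_available: bool) -> str:
    if platform == "neuron":
        # neuron uses a single process to control multiple devices
        return "uni"
    if platform == "cuda" and local_device_count < world_size:
        if not ray_available:
            raise ValueError("Ray is required for multi-node inference")
        return "ray"
    return "mp"
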
6 changes: 4 additions & 2 deletions vllm/distributed/parallel_state.py
@@ -862,12 +862,14 @@ def init_model_parallel_group(
) -> GroupCoordinator:
    if use_custom_allreduce is None:
        use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
+    from vllm.platforms import current_platform
    return GroupCoordinator(
        group_ranks=group_ranks,
        local_rank=local_rank,
        torch_distributed_backend=backend,
-        use_pynccl=True,
-        use_custom_allreduce=use_custom_allreduce,
+        use_pynccl=current_platform.is_cuda_alike(),
+        use_custom_allreduce=current_platform.is_cuda_alike()
+        and use_custom_allreduce,
        use_tpu_communicator=True,
        use_hpu_communicator=True,
        use_xpu_communicator=True,
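The pynccl and custom all-reduce flags are now gated on current_platform.is_cuda_alike() instead of being hard-coded to True, so non-CUDA platforms construct a GroupCoordinator without CUDA-only communicators. A rough sketch of what "CUDA-alike" is assumed to mean here (an assumption for illustration; the real definition lives in vllm/platforms):

from dataclasses import dataclass
from enum import Enum, auto


class PlatformEnum(Enum):
    # Hypothetical, trimmed-down stand-in for vLLM's platform enum.
    CUDA = auto()
    ROCM = auto()
    TPU = auto()
    CPU = auto()


@dataclass
class Platform:
    _enum: PlatformEnum

    def is_cuda_alike(self) -> bool:
        # CUDA and ROCm are the platforms that can run pynccl and the
        # custom all-reduce kernels; everything else skips them.
        return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)


# e.g. Platform(PlatformEnum.ROCM).is_cuda_alike() -> True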
88 changes: 6 additions & 82 deletions vllm/engine/async_llm_engine.py
@@ -18,9 +18,7 @@
from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState
from vllm.engine.metrics_types import StatLoggerBase
from vllm.engine.protocol import EngineClient
-from vllm.executor.executor_base import ExecutorAsyncBase
-from vllm.executor.gpu_executor import GPUExecutorAsync
-from vllm.executor.ray_utils import initialize_ray_cluster
+from vllm.executor.executor_base import ExecutorBase
from vllm.inputs import PromptType
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
@@ -620,69 +618,9 @@ def __del__(self):
            rt.new_requests_event.set()

    @classmethod
-    def _get_executor_cls(
-            cls, engine_config: VllmConfig) -> Type[ExecutorAsyncBase]:
-        distributed_executor_backend = (
-            engine_config.parallel_config.distributed_executor_backend)
-        if isinstance(distributed_executor_backend, type):
-            if not issubclass(distributed_executor_backend, ExecutorAsyncBase):
-                raise TypeError(
-                    "distributed_executor_backend must be a subclass of "
-                    f"ExecutorAsyncBase. Got {distributed_executor_backend}.")
-            executor_class = distributed_executor_backend
-        elif engine_config.device_config.device_type == "neuron":
-            from vllm.executor.neuron_executor import NeuronExecutorAsync
-            executor_class = NeuronExecutorAsync
-        elif engine_config.device_config.device_type == "tpu":
-            if distributed_executor_backend == "ray":
-                from vllm.executor.ray_tpu_executor import RayTPUExecutorAsync
-                executor_class = RayTPUExecutorAsync
-            else:
-                assert distributed_executor_backend is None
-                from vllm.executor.tpu_executor import TPUExecutorAsync
-                executor_class = TPUExecutorAsync
-        elif engine_config.device_config.device_type == "cpu":
-            from vllm.executor.cpu_executor import CPUExecutorAsync
-            executor_class = CPUExecutorAsync
-        elif engine_config.device_config.device_type == "hpu":
-            if distributed_executor_backend == "ray":
-                initialize_ray_cluster(engine_config.parallel_config)
-                from vllm.executor.ray_hpu_executor import RayHPUExecutorAsync
-                executor_class = RayHPUExecutorAsync
-            else:
-                from vllm.executor.hpu_executor import HPUExecutorAsync
-                executor_class = HPUExecutorAsync
-        elif engine_config.device_config.device_type == "openvino":
-            assert distributed_executor_backend is None, (
-                "Distributed execution is not supported with "
-                "the OpenVINO backend.")
-            from vllm.executor.openvino_executor import OpenVINOExecutorAsync
-            executor_class = OpenVINOExecutorAsync
-        elif engine_config.device_config.device_type == "xpu":
-            if distributed_executor_backend is None:
-                from vllm.executor.xpu_executor import XPUExecutorAsync
-                executor_class = XPUExecutorAsync
-            elif distributed_executor_backend == "ray":
-                from vllm.executor.ray_xpu_executor import RayXPUExecutorAsync
-                executor_class = RayXPUExecutorAsync
-            elif distributed_executor_backend == "mp":
-                from vllm.executor.multiproc_xpu_executor import (
-                    MultiprocessingXPUExecutorAsync)
-                executor_class = MultiprocessingXPUExecutorAsync
-            else:
-                raise RuntimeError(
-                    "Not supported distributed execution model on XPU device.")
-        elif distributed_executor_backend == "ray":
-            from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
-            executor_class = RayGPUExecutorAsync
-        elif distributed_executor_backend == "mp":
-            from vllm.executor.multiproc_gpu_executor import (
-                MultiprocessingGPUExecutorAsync)
-            executor_class = MultiprocessingGPUExecutorAsync
-        else:
-            from vllm.executor.gpu_executor import GPUExecutorAsync
-            executor_class = GPUExecutorAsync
-        return executor_class
+    def _get_executor_cls(cls,
+                          engine_config: VllmConfig) -> Type[ExecutorBase]:
+        return LLMEngine._get_executor_cls(engine_config)

    @classmethod
    def from_engine_args(
@@ -700,9 +638,6 @@ def from_engine_args(

        executor_class = cls._get_executor_cls(engine_config)

-        if executor_class.uses_ray:
-            initialize_ray_cluster(engine_config.parallel_config)
-
        # Create the async LLM engine.
        engine = cls(
            vllm_config=engine_config,
@@ -1242,23 +1177,12 @@ def remove_logger(self, logger_name: str) -> None:
        self.engine.remove_logger(logger_name=logger_name)

    async def start_profile(self) -> None:
-        # using type instead of isinstance to check to avoid capturing
-        # inherited classes
-        if type(self.engine.model_executor) == GPUExecutorAsync:  # noqa: E721
-            self.engine.model_executor.start_profile()
-        else:
-            self.engine.model_executor._run_workers("start_profile")
+        self.engine.start_profile()

    async def stop_profile(self) -> None:
-        # using type instead of isinstance to check to avoid capturing
-        # inherited classes
-        if type(self.engine.model_executor) == GPUExecutorAsync:  # noqa: E721
-            self.engine.model_executor.stop_profile()
-        else:
-            self.engine.model_executor._run_workers("stop_profile")
+        self.engine.stop_profile()

    async def add_lora(self, lora_request: LoRARequest) -> None:
        """Load a new LoRA adapter into the engine for future requests."""
        self.engine.add_lora(lora_request)
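Since _get_executor_cls now delegates to LLMEngine and profiling is routed through the engine, the async path no longer needs a parallel *Async executor hierarchy; the test file above makes this explicit with CustomUniExecutorAsync = CustomUniExecutor. A short usage sketch under that assumption (the executor subclass and prompt are illustrative; the AsyncEngineArgs and from_engine_args usage mirrors the test):

import asyncio

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.executor.uniproc_executor import UniProcExecutor
from vllm.sampling_params import SamplingParams


class MyExecutor(UniProcExecutor):
    """Illustrative custom executor; the same ExecutorBase subclass is now
    accepted by both the sync and the async engine."""


async def main() -> None:
    engine_args = AsyncEngineArgs(model="facebook/opt-125m",
                                  distributed_executor_backend=MyExecutor)
    engine = AsyncLLMEngine.from_engine_args(engine_args)

    params = SamplingParams(max_tokens=1)
    async for output in engine.generate("Hello", params, request_id="0"):
        print(output.outputs[0].text)


if __name__ == "__main__":
    asyncio.run(main())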


(Diff truncated: only the files above are shown; the commit touches 43 files in total.)
