diff --git a/tests/engine/test_custom_executor.py b/tests/engine/test_custom_executor.py index bbabb936e92ba..2a057ca488a50 100644 --- a/tests/engine/test_custom_executor.py +++ b/tests/engine/test_custom_executor.py @@ -1,12 +1,13 @@ import asyncio import os +from typing import Any, Dict, List, Optional, Tuple import pytest from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.llm_engine import LLMEngine -from vllm.executor.gpu_executor import GPUExecutor, GPUExecutorAsync +from vllm.executor.uniproc_executor import UniProcExecutor from vllm.sampling_params import SamplingParams @@ -14,21 +15,20 @@ class Mock: ... -class CustomGPUExecutor(GPUExecutor): +class CustomUniExecutor(UniProcExecutor): - def execute_model(self, *args, **kwargs): + def collective_rpc(self, + method: str, + timeout: Optional[float] = None, + args: Tuple = (), + kwargs: Optional[Dict] = None) -> List[Any]: # Drop marker to show that this was ran with open(".marker", "w"): ... - return super().execute_model(*args, **kwargs) + return super().collective_rpc(method, timeout, args, kwargs) -class CustomGPUExecutorAsync(GPUExecutorAsync): - - async def execute_model_async(self, *args, **kwargs): - with open(".marker", "w"): - ... - return await super().execute_model_async(*args, **kwargs) +CustomUniExecutorAsync = CustomUniExecutor @pytest.mark.parametrize("model", ["facebook/opt-125m"]) @@ -41,10 +41,6 @@ def test_custom_executor_type_checking(model): engine_args = AsyncEngineArgs(model=model, distributed_executor_backend=Mock) AsyncLLMEngine.from_engine_args(engine_args) - with pytest.raises(TypeError): - engine_args = AsyncEngineArgs( - model=model, distributed_executor_backend=CustomGPUExecutor) - AsyncLLMEngine.from_engine_args(engine_args) @pytest.mark.parametrize("model", ["facebook/opt-125m"]) @@ -55,7 +51,7 @@ def test_custom_executor(model, tmp_path): assert not os.path.exists(".marker") engine_args = EngineArgs( - model=model, distributed_executor_backend=CustomGPUExecutor) + model=model, distributed_executor_backend=CustomUniExecutor) engine = LLMEngine.from_engine_args(engine_args) sampling_params = SamplingParams(max_tokens=1) @@ -75,7 +71,7 @@ def test_custom_executor_async(model, tmp_path): assert not os.path.exists(".marker") engine_args = AsyncEngineArgs( - model=model, distributed_executor_backend=CustomGPUExecutorAsync) + model=model, distributed_executor_backend=CustomUniExecutorAsync) engine = AsyncLLMEngine.from_engine_args(engine_args) sampling_params = SamplingParams(max_tokens=1) diff --git a/tests/engine/test_multiproc_workers.py b/tests/engine/test_multiproc_workers.py index e07dd6deef5bf..db70a808c008b 100644 --- a/tests/engine/test_multiproc_workers.py +++ b/tests/engine/test_multiproc_workers.py @@ -6,16 +6,15 @@ import pytest +from vllm.config import VllmConfig from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper, ResultHandler, WorkerMonitor) +from vllm.worker.worker_base import WorkerWrapperBase -class DummyWorker: +class DummyWorkerWrapper(WorkerWrapperBase): """Dummy version of vllm.worker.worker.Worker""" - def __init__(self, rank: int): - self.rank = rank - def worker_method(self, worker_input: Any) -> Tuple[int, Any]: sleep(0.05) @@ -28,9 +27,10 @@ def worker_method(self, worker_input: Any) -> Tuple[int, Any]: def _start_workers() -> Tuple[List[ProcessWorkerWrapper], WorkerMonitor]: result_handler = ResultHandler() + vllm_config = VllmConfig() workers = [ - 
ProcessWorkerWrapper(result_handler, partial(DummyWorker, rank=rank)) - for rank in range(8) + ProcessWorkerWrapper(result_handler, DummyWorkerWrapper, vllm_config, + rank) for rank in range(8) ] worker_monitor = WorkerMonitor(workers, result_handler) diff --git a/tests/test_utils.py b/tests/test_utils.py index 6810e0302f897..c68d730af7f8a 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -2,6 +2,7 @@ import os import socket from typing import AsyncIterator, Tuple +from unittest.mock import patch import pytest import torch @@ -390,7 +391,10 @@ def test_bind_kv_cache_encoder_decoder(): def test_bind_kv_cache_pp(): - cfg = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=2)) + with patch("vllm.utils.cuda_device_count_stateless", lambda: 2): + # this test runs with 1 GPU, but we simulate 2 GPUs + cfg = VllmConfig( + parallel_config=ParallelConfig(pipeline_parallel_size=2)) with set_current_vllm_config(cfg): from vllm.attention import Attention diff --git a/vllm/config.py b/vllm/config.py index 59b509d5a961e..4a42aefb75026 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1294,8 +1294,11 @@ def __post_init__(self) -> None: from vllm.executor import ray_utils backend = "mp" ray_found = ray_utils.ray_is_available() - if (current_platform.is_cuda() - and cuda_device_count_stateless() < self.world_size): + if current_platform.is_neuron(): + # neuron uses single process to control multiple devices + backend = "uni" + elif (current_platform.is_cuda() + and cuda_device_count_stateless() < self.world_size): if not ray_found: raise ValueError("Unable to load Ray which is " "required for multi-node inference, " @@ -1328,13 +1331,14 @@ def _verify_args(self) -> None: from vllm.executor.executor_base import ExecutorBase from vllm.platforms import current_platform if self.distributed_executor_backend not in ( - "ray", "mp", None) and not (isinstance( + "ray", "mp", "uni", None) and not (isinstance( self.distributed_executor_backend, type) and issubclass( self.distributed_executor_backend, ExecutorBase)): raise ValueError( "Unrecognized distributed executor backend " f"{self.distributed_executor_backend}. 
Supported " - "values are 'ray', 'mp' or custom ExecutorBase subclass.") + "values are 'ray', 'mp' 'uni', or custom ExecutorBase" + " subclass.") if self.use_ray: from vllm.executor import ray_utils ray_utils.assert_ray_available() diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index be7f16ef52a47..bf8b30cccd5f6 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -862,12 +862,14 @@ def init_model_parallel_group( ) -> GroupCoordinator: if use_custom_allreduce is None: use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE + from vllm.platforms import current_platform return GroupCoordinator( group_ranks=group_ranks, local_rank=local_rank, torch_distributed_backend=backend, - use_pynccl=True, - use_custom_allreduce=use_custom_allreduce, + use_pynccl=current_platform.is_cuda_alike(), + use_custom_allreduce=current_platform.is_cuda_alike() + and use_custom_allreduce, use_tpu_communicator=True, use_hpu_communicator=True, use_xpu_communicator=True, diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index da23ed19ef7be..08fef8250d483 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -18,9 +18,7 @@ from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState from vllm.engine.metrics_types import StatLoggerBase from vllm.engine.protocol import EngineClient -from vllm.executor.executor_base import ExecutorAsyncBase -from vllm.executor.gpu_executor import GPUExecutorAsync -from vllm.executor.ray_utils import initialize_ray_cluster +from vllm.executor.executor_base import ExecutorBase from vllm.inputs import PromptType from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger @@ -620,69 +618,9 @@ def __del__(self): rt.new_requests_event.set() @classmethod - def _get_executor_cls( - cls, engine_config: VllmConfig) -> Type[ExecutorAsyncBase]: - distributed_executor_backend = ( - engine_config.parallel_config.distributed_executor_backend) - if isinstance(distributed_executor_backend, type): - if not issubclass(distributed_executor_backend, ExecutorAsyncBase): - raise TypeError( - "distributed_executor_backend must be a subclass of " - f"ExecutorAsyncBase. 
Got {distributed_executor_backend}.") - executor_class = distributed_executor_backend - elif engine_config.device_config.device_type == "neuron": - from vllm.executor.neuron_executor import NeuronExecutorAsync - executor_class = NeuronExecutorAsync - elif engine_config.device_config.device_type == "tpu": - if distributed_executor_backend == "ray": - from vllm.executor.ray_tpu_executor import RayTPUExecutorAsync - executor_class = RayTPUExecutorAsync - else: - assert distributed_executor_backend is None - from vllm.executor.tpu_executor import TPUExecutorAsync - executor_class = TPUExecutorAsync - elif engine_config.device_config.device_type == "cpu": - from vllm.executor.cpu_executor import CPUExecutorAsync - executor_class = CPUExecutorAsync - elif engine_config.device_config.device_type == "hpu": - if distributed_executor_backend == "ray": - initialize_ray_cluster(engine_config.parallel_config) - from vllm.executor.ray_hpu_executor import RayHPUExecutorAsync - executor_class = RayHPUExecutorAsync - else: - from vllm.executor.hpu_executor import HPUExecutorAsync - executor_class = HPUExecutorAsync - elif engine_config.device_config.device_type == "openvino": - assert distributed_executor_backend is None, ( - "Distributed execution is not supported with " - "the OpenVINO backend.") - from vllm.executor.openvino_executor import OpenVINOExecutorAsync - executor_class = OpenVINOExecutorAsync - elif engine_config.device_config.device_type == "xpu": - if distributed_executor_backend is None: - from vllm.executor.xpu_executor import XPUExecutorAsync - executor_class = XPUExecutorAsync - elif distributed_executor_backend == "ray": - from vllm.executor.ray_xpu_executor import RayXPUExecutorAsync - executor_class = RayXPUExecutorAsync - elif distributed_executor_backend == "mp": - from vllm.executor.multiproc_xpu_executor import ( - MultiprocessingXPUExecutorAsync) - executor_class = MultiprocessingXPUExecutorAsync - else: - raise RuntimeError( - "Not supported distributed execution model on XPU device.") - elif distributed_executor_backend == "ray": - from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync - executor_class = RayGPUExecutorAsync - elif distributed_executor_backend == "mp": - from vllm.executor.multiproc_gpu_executor import ( - MultiprocessingGPUExecutorAsync) - executor_class = MultiprocessingGPUExecutorAsync - else: - from vllm.executor.gpu_executor import GPUExecutorAsync - executor_class = GPUExecutorAsync - return executor_class + def _get_executor_cls(cls, + engine_config: VllmConfig) -> Type[ExecutorBase]: + return LLMEngine._get_executor_cls(engine_config) @classmethod def from_engine_args( @@ -700,9 +638,6 @@ def from_engine_args( executor_class = cls._get_executor_cls(engine_config) - if executor_class.uses_ray: - initialize_ray_cluster(engine_config.parallel_config) - # Create the async LLM engine. 
engine = cls( vllm_config=engine_config, @@ -1242,23 +1177,12 @@ def remove_logger(self, logger_name: str) -> None: self.engine.remove_logger(logger_name=logger_name) async def start_profile(self) -> None: - # using type instead of isinstance to check to avoid capturing - # inherited classes - if type(self.engine.model_executor) == GPUExecutorAsync: # noqa: E721 - self.engine.model_executor.start_profile() - else: - self.engine.model_executor._run_workers("start_profile") + self.engine.start_profile() async def stop_profile(self) -> None: - # using type instead of isinstance to check to avoid capturing - # inherited classes - if type(self.engine.model_executor) == GPUExecutorAsync: # noqa: E721 - self.engine.model_executor.stop_profile() - else: - self.engine.model_executor._run_workers("stop_profile") + self.engine.stop_profile() async def add_lora(self, lora_request: LoRARequest) -> None: - """Load a new LoRA adapter into the engine for future requests.""" self.engine.add_lora(lora_request) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 1db3e59ff3bae..49a1e9f505d9f 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -28,8 +28,6 @@ from vllm.entrypoints.openai.logits_processors import ( get_logits_processors as get_openai_logits_processors) from vllm.executor.executor_base import ExecutorBase -from vllm.executor.gpu_executor import GPUExecutor -from vllm.executor.ray_utils import initialize_ray_cluster from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs, PromptType, SingletonInputsAdapter) from vllm.inputs.parse import is_encoder_decoder_inputs, is_token_prompt @@ -442,64 +440,26 @@ def _get_executor_cls(cls, raise TypeError( "distributed_executor_backend must be a subclass of " f"ExecutorBase. 
Got {distributed_executor_backend}.") - if distributed_executor_backend.uses_ray: # type: ignore - initialize_ray_cluster(engine_config.parallel_config) executor_class = distributed_executor_backend - elif engine_config.device_config.device_type == "neuron": - from vllm.executor.neuron_executor import NeuronExecutor - executor_class = NeuronExecutor - elif engine_config.device_config.device_type == "tpu": + elif engine_config.parallel_config.world_size > 1: if distributed_executor_backend == "ray": - initialize_ray_cluster(engine_config.parallel_config) - from vllm.executor.ray_tpu_executor import RayTPUExecutor - executor_class = RayTPUExecutor - else: - assert distributed_executor_backend is None - from vllm.executor.tpu_executor import TPUExecutor - executor_class = TPUExecutor - elif engine_config.device_config.device_type == "cpu": - from vllm.executor.cpu_executor import CPUExecutor - executor_class = CPUExecutor - elif engine_config.device_config.device_type == "hpu": - if distributed_executor_backend == "ray": - initialize_ray_cluster(engine_config.parallel_config) - from vllm.executor.ray_hpu_executor import RayHPUExecutor - executor_class = RayHPUExecutor - else: - from vllm.executor.hpu_executor import HPUExecutor - executor_class = HPUExecutor - elif engine_config.device_config.device_type == "openvino": - from vllm.executor.openvino_executor import OpenVINOExecutor - executor_class = OpenVINOExecutor - elif engine_config.device_config.device_type == "xpu": - if distributed_executor_backend == "ray": - initialize_ray_cluster(engine_config.parallel_config) - from vllm.executor.ray_xpu_executor import RayXPUExecutor - executor_class = RayXPUExecutor + from vllm.executor.ray_distributed_executor import ( + RayDistributedExecutor) + executor_class = RayDistributedExecutor elif distributed_executor_backend == "mp": - # FIXME(kunshang): - # spawn needs calling `if __name__ == '__main__':`` - # fork is not supported for xpu start new process. - logger.error( - "Both start methods (spawn and fork) have issue " - "on XPU if you use mp backend, Please try ray instead.") - else: - from vllm.executor.xpu_executor import XPUExecutor - executor_class = XPUExecutor - elif distributed_executor_backend == "ray": - initialize_ray_cluster(engine_config.parallel_config) - from vllm.executor.ray_gpu_executor import RayGPUExecutor - executor_class = RayGPUExecutor - elif distributed_executor_backend == "mp": - from vllm.executor.multiproc_gpu_executor import ( - MultiprocessingGPUExecutor) - assert not envs.VLLM_USE_RAY_SPMD_WORKER, ( - "multiprocessing distributed executor backend does not " - "support VLLM_USE_RAY_SPMD_WORKER=1") - executor_class = MultiprocessingGPUExecutor + from vllm.executor.mp_distributed_executor import ( + MultiprocessingDistributedExecutor) + assert not envs.VLLM_USE_RAY_SPMD_WORKER, ( + "multiprocessing distributed executor backend does not " + "support VLLM_USE_RAY_SPMD_WORKER=1") + executor_class = MultiprocessingDistributedExecutor + elif distributed_executor_backend == "uni": + # JAX-style, single-process, multi-device executor. 
+ from vllm.executor.uniproc_executor import UniProcExecutor + executor_class = UniProcExecutor else: - from vllm.executor.gpu_executor import GPUExecutor - executor_class = GPUExecutor + from vllm.executor.uniproc_executor import UniProcExecutor + executor_class = UniProcExecutor return executor_class @classmethod @@ -1845,27 +1805,17 @@ def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: def list_prompt_adapters(self) -> List[int]: return self.model_executor.list_prompt_adapters() + def start_profile(self) -> None: + self.model_executor.start_profile() + + def stop_profile(self) -> None: + self.model_executor.stop_profile() + def check_health(self) -> None: if self.tokenizer: self.tokenizer.check_health() self.model_executor.check_health() - def start_profile(self) -> None: - # using type instead of isinstance to check to avoid capturing - # inherited classes (MultiprocessingGPUExecutor) - if type(self.model_executor) == GPUExecutor: # noqa: E721 - self.model_executor.start_profile() - else: - self.model_executor._run_workers("start_profile") - - def stop_profile(self) -> None: - # using type instead of isinstance to check to avoid capturing - # inherited classes (MultiprocessingGPUExecutor) - if type(self.model_executor) == GPUExecutor: # noqa: E721 - self.model_executor.stop_profile() - else: - self.model_executor._run_workers("stop_profile") - def is_tracing_enabled(self) -> bool: return self.tracer is not None diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index 36f4df4b02731..8f231de912c95 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -20,7 +20,6 @@ RPCStartupResponse, RPCUProfileRequest) # yapf: enable -from vllm.executor.gpu_executor import GPUExecutor from vllm.logger import init_logger from vllm.outputs import RequestOutput from vllm.usage.usage_lib import UsageContext @@ -356,16 +355,10 @@ def _set_errored(self, e: BaseException): self._errored_with = e def start_profile(self) -> None: - if type(self.engine.model_executor) is GPUExecutor: - self.engine.model_executor.start_profile() - else: - self.engine.model_executor._run_workers("start_profile") + self.engine.start_profile() def stop_profile(self) -> None: - if type(self.engine.model_executor) is GPUExecutor: - self.engine.model_executor.stop_profile() - else: - self.engine.model_executor._run_workers("stop_profile") + self.engine.stop_profile() def signal_handler(*_) -> None: diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py deleted file mode 100644 index b9a6bee5720fd..0000000000000 --- a/vllm/executor/cpu_executor.py +++ /dev/null @@ -1,299 +0,0 @@ -import os -from functools import partial -from typing import Any, Awaitable, List, Optional, Set, Tuple, Union - -from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase -from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper, - ResultHandler, WorkerMonitor) -from vllm.logger import init_logger -from vllm.lora.request import LoRARequest -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sequence import ExecuteModelRequest -from vllm.utils import get_distributed_init_method, get_open_port, make_async -from vllm.worker.worker_base import WorkerWrapperBase - -logger = init_logger(__name__) - - -class CPUExecutor(ExecutorBase): - - uses_ray: bool = False - - def _init_executor(self) -> None: - assert self.device_config.device_type == 
"cpu" - - # - # Environment variables for CPU executor - # - - # Disable torch async compiling which won't work with daemonic processes - os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1" - - # Intel OpenMP setting - ld_prealod_str = os.getenv("LD_PRELOAD", "") - if "libiomp5.so" in ld_prealod_str: - # The time(milliseconds) that a thread should wait after - # completing the execution of a parallel region, before sleeping. - os.environ['KMP_BLOCKTIME'] = "1" - # Prevents the CPU to run into low performance state - os.environ['KMP_TPAUSE'] = "0" - # Provides fine granularity parallelism - os.environ['KMP_FORKJOIN_BARRIER_PATTERN'] = "dist,dist" - os.environ['KMP_PLAIN_BARRIER_PATTERN'] = "dist,dist" - os.environ['KMP_REDUCTION_BARRIER_PATTERN'] = "dist,dist" - - # To hint IPEX uses shared memory based AllReduce - os.environ["LOCAL_WORLD_SIZE"] = str( - self.parallel_config.tensor_parallel_size) - - # Multiprocessing-based executor does not support multi-node setting. - # Since it only works for single node, we can use the loopback address - # 127.0.0.1 for communication. - ip = "127.0.0.1" - port = get_open_port() - self.distributed_init_method = get_distributed_init_method(ip, port) - - is_async = isinstance(self, CPUExecutorAsync) - - world_size = self.parallel_config.tensor_parallel_size - result_handler = ResultHandler() - self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None - self.workers = [] - - if is_async: - self.workers = [ - ProcessWorkerWrapper( - result_handler, - partial( - self._create_worker, - rank=rank, - local_rank=rank, - )) for rank in range(0, world_size) - ] - self.driver_worker = self.workers[0] - self.workers = self.workers[1:] - self.driver_method_invoker = _async_driver_method_invoker - else: - self.driver_worker = self._create_worker() - self.driver_method_invoker = _driver_method_invoker - - if world_size != 1: - self.workers = [ - ProcessWorkerWrapper( - result_handler, - partial( - self._create_worker, - rank=rank, - local_rank=rank, - )) for rank in range(1, world_size) - ] - - self.worker_monitor = None - if world_size != 1 or is_async: - if is_async: - async_worker_list = self.workers + [self.driver_worker] - else: - async_worker_list = self.workers - self.worker_monitor = WorkerMonitor(async_worker_list, - result_handler) - result_handler.start() - self.worker_monitor.start() - - self._run_workers("init_device") - self._run_workers("load_model") - - def _create_worker( - self, - local_rank: int = 0, - rank: int = 0, - ): - - wrapper = WorkerWrapperBase(vllm_config=self.vllm_config) - - assert self.distributed_init_method is not None - - kwargs = dict( - vllm_config=self.vllm_config, - local_rank=local_rank, - rank=rank, - distributed_init_method=self.distributed_init_method, - kv_cache_dtype=self.cache_config.cache_dtype, - is_driver_worker=rank == 0, - ) - wrapper.init_worker(**kwargs) - - return wrapper.worker - - def _run_workers( - self, - method: str, - *args, - async_run_remote_workers_only: bool = False, - max_concurrent_workers: Optional[int] = None, - **kwargs, - ) -> Any: - """Runs the given method on all workers. - - Args: - async_run_remote_workers_only: If True the method will be run only - in the remote workers, not the driver worker. It will also be - run asynchronously and return a list of futures rather than - blocking on the results. - """ - - if max_concurrent_workers: - raise NotImplementedError( - "max_concurrent_workers is not supported yet.") - - # Start the workers first. 
- worker_outputs = [ - worker.execute_method(method, *args, **kwargs) - for worker in self.workers - ] - - if async_run_remote_workers_only: - # Just return futures - return worker_outputs - - driver_worker_output = self.driver_method_invoker( - self.driver_worker, method, *args, **kwargs) - - # Get the results of the workers. - return [driver_worker_output - ] + [output.get() for output in worker_outputs] - - def determine_num_available_blocks(self) -> Tuple[int, int]: - """Determine the number of available KV blocks by invoking the - underlying worker. - """ - return self.driver_method_invoker(self.driver_worker, - "determine_num_available_blocks") - - def initialize_cache(self, num_gpu_blocks: int, - num_cpu_blocks: int) -> None: - """Initialize the KV cache by invoking the underlying worker. - """ - # NOTE: We log here to avoid multiple logs when number of workers is - # greater than one. We could log in the engine, but not all executors - # have GPUs. - # NOTE: `cpu block` for CPU backend is located on CPU memory but is - # referred as `gpu block`. Because we want to reuse the existing block - # management procedure. - logger.info("# CPU blocks: %d", num_gpu_blocks) - - self._run_workers("initialize_cache", - num_gpu_blocks=num_gpu_blocks, - num_cpu_blocks=num_cpu_blocks) - - def execute_model( - self, - execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: - if (self.parallel_config.tensor_parallel_size > 1 - and self.parallel_worker_tasks is None): - self.parallel_worker_tasks = self._run_workers( - "start_worker_execution_loop", - async_run_remote_workers_only=True, - ) - output = self.driver_method_invoker(self.driver_worker, - "execute_model", execute_model_req) - return output - - def stop_remote_worker_execution_loop(self) -> None: - if self.parallel_worker_tasks is None: - return - """ - Passing None will cause the driver to stop the model execution - loop running in each of the remote workers. - """ - self.driver_method_invoker(self.driver_worker, "execute_model", None) - parallel_worker_tasks = self.parallel_worker_tasks - self.parallel_worker_tasks = None - # Ensure that workers exit model loop cleanly - # (this will raise otherwise) - self._wait_for_tasks_completion(parallel_worker_tasks) - - def add_lora(self, lora_request: LoRARequest) -> bool: - return all(self._run_workers("add_lora", lora_request)) - - def remove_lora(self, lora_id: int) -> bool: - return all(self._run_workers("remove_lora", lora_id)) - - def pin_lora(self, lora_id: int) -> bool: - assert lora_id > 0, "lora_id must be greater than 0." 
- return all(self._run_workers( - "pin_lora", - lora_id=lora_id, - )) - - def list_loras(self) -> Set[int]: - return self.driver_method_invoker(self.driver_worker, "list_loras") - - def add_prompt_adapter( - self, prompt_adapter_request: PromptAdapterRequest) -> bool: - return all( - self._run_workers( - "add_prompt_adapter", - prompt_adapter_request, - )) - - def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: - return all( - self._run_workers( - "remove_prompt_adapter", - prompt_adapter_id, - )) - - def list_prompt_adapters(self) -> Set[int]: - return self.driver_method_invoker(self.driver_worker, - "list_prompt_adapters") - - def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: - return all(self._run_workers( - "pin_prompt_adapter", - prompt_adapter_id, - )) - - def check_health(self) -> None: - """Raises an error if engine is unhealthy.""" - if self.worker_monitor is not None and not self.worker_monitor.is_alive( - ): - raise RuntimeError("Worker processes are not running") - - def shutdown(self): - if (worker_monitor := getattr(self, "worker_monitor", - None)) is not None: - worker_monitor.close() - - def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None: - """Wait for futures returned from _run_workers() with - async_run_remote_workers_only to complete.""" - for result in parallel_worker_tasks: - result.get() - - def start_profile(self) -> None: - self.driver_method_invoker(self.driver_worker, "start_profile") - - def stop_profile(self) -> None: - self.driver_method_invoker(self.driver_worker, "stop_profile") - - -class CPUExecutorAsync(CPUExecutor, ExecutorAsyncBase): - - async def execute_model_async( - self, - execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: - output = await make_async(self.execute_model - )(execute_model_req=execute_model_req, ) - return output - - async def check_health_async(self) -> None: - self.check_health() - - -def _driver_method_invoker(driver, method: str, *args, **kwargs): - return getattr(driver, method)(*args, **kwargs) - - -def _async_driver_method_invoker(driver, method: str, *args, **kwargs): - return driver.execute_method(method, *args, **kwargs).get() diff --git a/vllm/executor/distributed_gpu_executor.py b/vllm/executor/distributed_gpu_executor.py deleted file mode 100644 index deb7cb1c97ef5..0000000000000 --- a/vllm/executor/distributed_gpu_executor.py +++ /dev/null @@ -1,212 +0,0 @@ -import asyncio -from abc import abstractmethod -from typing import Any, Awaitable, Dict, List, Optional, Set, Tuple, Union - -from vllm.executor.executor_base import ExecutorAsyncBase -from vllm.executor.gpu_executor import GPUExecutor -from vllm.logger import init_logger -from vllm.lora.request import LoRARequest -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.sequence import ExecuteModelRequest - -logger = init_logger(__name__) - - -class DistributedGPUExecutor(GPUExecutor): - """Abstract superclass of multi-GPU executor implementations.""" - - def __init__(self, *args, **kwargs): - # This is non-None when the execute model loop is running - # in the parallel workers. It's a coroutine in the AsyncLLMEngine case. 
- self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None - # Updated by implementations that require additional args to be passed - # to the _run_workers execute_model call - self.extra_execute_model_run_workers_kwargs: Dict[str, Any] = {} - - super().__init__(*args, **kwargs) - - def determine_num_available_blocks(self) -> Tuple[int, int]: - """Determine the number of available KV blocks. - - This invokes `determine_num_available_blocks` on each worker and takes - the min of the results, guaranteeing that the selected cache sizes are - compatible with all workers. - - Returns: - - tuple[num_gpu_blocks, num_cpu_blocks] - """ - # Get the maximum number of blocks that can be allocated on GPU and CPU. - num_blocks = self._run_workers("determine_num_available_blocks", ) - - # Since we use a shared centralized controller, we take the minimum - # number of blocks across all workers to make sure all the memory - # operators can be applied to all workers. - num_gpu_blocks = min(b[0] for b in num_blocks) - num_cpu_blocks = min(b[1] for b in num_blocks) - - return num_gpu_blocks, num_cpu_blocks - - def initialize_cache(self, num_gpu_blocks: int, - num_cpu_blocks: int) -> None: - """Initialize the KV cache in all workers. - """ - - # NOTE: We log here to avoid multiple logs when number of workers is - # greater than one. We could log in the engine, but not all executors - # have GPUs. - logger.info("# GPU blocks: %d, # CPU blocks: %d", num_gpu_blocks, - num_cpu_blocks) - max_concurrency = (num_gpu_blocks * self.cache_config.block_size / - self.model_config.max_model_len) - logger.info("Maximum concurrency for %s tokens per request: %.2fx", - self.model_config.max_model_len, max_concurrency) - - self.cache_config.num_gpu_blocks = num_gpu_blocks - self.cache_config.num_cpu_blocks = num_cpu_blocks - - self._run_workers("initialize_cache", - num_gpu_blocks=num_gpu_blocks, - num_cpu_blocks=num_cpu_blocks) - - def execute_model( - self, - execute_model_req: ExecuteModelRequest, - ) -> List[SamplerOutput]: - if self.parallel_worker_tasks is None: - self.parallel_worker_tasks = self._run_workers( - "start_worker_execution_loop", - async_run_tensor_parallel_workers_only=True, - **self.extra_execute_model_run_workers_kwargs) - - # Only the driver worker returns the sampling results. - driver_outputs = self._driver_execute_model(execute_model_req) - assert driver_outputs is not None - return driver_outputs - - def stop_remote_worker_execution_loop(self) -> None: - if self.parallel_worker_tasks is None: - return - - self._driver_execute_model(execute_model_req=None) - parallel_worker_tasks = self.parallel_worker_tasks - self.parallel_worker_tasks = None - # Ensure that workers exit model loop cleanly - # (this will raise otherwise) - self._wait_for_tasks_completion(parallel_worker_tasks) - - def add_lora(self, lora_request: LoRARequest) -> bool: - assert lora_request.lora_int_id > 0, "lora_id must be greater than 0." - return self._run_workers( - "add_lora", - lora_request=lora_request, - ) - - def remove_lora(self, lora_id: int) -> bool: - assert lora_id > 0, "lora_id must be greater than 0." - return self._run_workers( - "remove_lora", - lora_id=lora_id, - ) - - def pin_lora(self, lora_id: int) -> bool: - assert lora_id > 0, "lora_id must be greater than 0." 
- return self._run_workers( - "pin_lora", - lora_id=lora_id, - ) - - def list_loras(self) -> Set[int]: - return self._run_workers("list_loras") - - def save_sharded_state( - self, - path: str, - pattern: Optional[str] = None, - max_size: Optional[int] = None, - ) -> None: - self._run_workers("save_sharded_state", - path=path, - pattern=pattern, - max_size=max_size) - - @abstractmethod - def _driver_execute_model( - self, execute_model_req: Optional[ExecuteModelRequest] - ) -> Optional[List[SamplerOutput]]: - """Run execute_model in the driver worker. - - Passing None will cause the driver to stop the model execution loop - running in each of the remote workers. In this case, this method - returns None. Otherwise, this method returns the model output. - """ - raise NotImplementedError - - @abstractmethod - def _run_workers( - self, - method: str, - *args, - async_run_tensor_parallel_workers_only: bool = False, - max_concurrent_workers: Optional[int] = None, - **kwargs, - ) -> Any: - """Runs the given method on all workers. - - Args: - async_run_tensor_parallel_workers_only: If True the method will be - run only in the remote TP workers, not the driver worker. - It will also be run asynchronously and return a list of futures - rather than blocking on the results. - """ - raise NotImplementedError - - @abstractmethod - def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None: - """Wait for futures returned from _run_workers() with - async_run_remote_workers_only to complete.""" - raise NotImplementedError - - -class DistributedGPUExecutorAsync(DistributedGPUExecutor, ExecutorAsyncBase): - - async def execute_model_async( - self, - execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: - if self.parallel_worker_tasks is None: - # Start model execution loop running in the parallel workers - self.parallel_worker_tasks = asyncio.create_task( - self._start_worker_execution_loop()) - - # Only the driver worker returns the sampling results. - return await self._driver_execute_model_async(execute_model_req) - - async def stop_remote_worker_execution_loop_async(self) -> None: - if self.parallel_worker_tasks is None: - return - - await self._driver_execute_model_async() - parallel_worker_tasks = self.parallel_worker_tasks - self.parallel_worker_tasks = None - # Ensure that workers exit model loop cleanly - # (this will raise otherwise) - await parallel_worker_tasks - - @abstractmethod - async def _driver_execute_model_async( - self, - execute_model_req: Optional[ExecuteModelRequest] = None, - ) -> List[SamplerOutput]: - """Execute the model asynchronously in the driver worker. - - Passing None will cause the driver to stop the model execution - loop running in each of the remote workers. - """ - raise NotImplementedError - - @abstractmethod - async def _start_worker_execution_loop(self): - """Run execution loop on all workers. It guarantees all workers run - the loop or None of them is running the loop. Loop can be stopped by - `stop_remote_worker_execution_loop`. 
- The API is idempotent (guarantee only 1 loop run at any moment).""" - raise NotImplementedError diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index 9cba189dd57f9..00ecadcf92667 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -1,18 +1,24 @@ +import asyncio from abc import ABC, abstractmethod -from typing import List, Optional, Set, Tuple +from typing import Any, Awaitable, Dict, List, Optional, Set, Tuple, Union from vllm.config import VllmConfig +from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.platforms import current_platform from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sequence import ExecuteModelRequest +from vllm.sequence import ExecuteModelRequest, PoolerOutput +from vllm.utils import make_async + +logger = init_logger(__name__) class ExecutorBase(ABC): """Base class for all executors. - An executor is responsible for executing the model on a specific device - type (e.g., CPU, GPU, Neuron, etc.). Or it can be a distributed executor + An executor is responsible for executing the model on one device, + or it can be a distributed executor that can execute the model on multiple devices. """ @@ -40,6 +46,20 @@ def _init_executor(self) -> None: pass @abstractmethod + def collective_rpc(self, + method: str, + timeout: Optional[float] = None, + args: Tuple = (), + kwargs: Optional[Dict] = None) -> List[Any]: + """ + The main interface of the executor to run a method on all workers, + with homogeneous arguments. + If the args are heterogeneous, then we can pack them into a list, + and unpack them in the method of every worker, because every worker + knows their own rank. + """ + pass + def determine_num_available_blocks(self) -> Tuple[int, int]: """Determine the number of available blocks for the GPU KV cache and swappable CPU KV cache. @@ -53,58 +73,113 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be appended to. """ - raise NotImplementedError + results = self.collective_rpc("determine_num_available_blocks") + a = min([r[0] for r in results]) + b = min([r[1] for r in results]) + return a, b - @abstractmethod - def initialize_cache(self, num_gpu_blocks: int, - num_cpu_blocks: int) -> None: - """Initialize the KV cache with the given size in blocks. + def initialize(self, num_gpu_blocks: int) -> None: """ - raise NotImplementedError + Initialize the KV caches and begin the model execution loop of the + underlying workers. + For V1 compatibility. + """ + logger.info("# GPU blocks: %d", num_gpu_blocks) + self.collective_rpc("initialize_cache", args=(num_gpu_blocks, )) + self.collective_rpc("compile_or_warm_up_model") + + def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None: + """Initialize the KV cache by invoking the underlying worker. + """ + # NOTE: This is logged in the executor because there can be >1 workers. 
+ logger.info("# %s blocks: %d, # CPU blocks: %d", + current_platform.dispatch_key, num_gpu_blocks, + num_cpu_blocks) + max_concurrency = (num_gpu_blocks * self.cache_config.block_size / + self.model_config.max_model_len) + logger.info("Maximum concurrency for %s tokens per request: %.2fx", + self.model_config.max_model_len, max_concurrency) + + self.cache_config.num_gpu_blocks = num_gpu_blocks + self.cache_config.num_cpu_blocks = num_cpu_blocks + + self.collective_rpc("initialize_cache", + args=(num_gpu_blocks, num_cpu_blocks)) - @abstractmethod def execute_model( self, execute_model_req: ExecuteModelRequest - ) -> Optional[List[SamplerOutput]]: - """Executes at least one model step on the given sequences.""" - raise NotImplementedError + ) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]: + output = self.collective_rpc("execute_model", + args=(execute_model_req, )) + return output[0] def stop_remote_worker_execution_loop(self) -> None: """Releases parallel workers from model loop.""" return - @abstractmethod def add_lora(self, lora_request: LoRARequest) -> bool: - raise NotImplementedError + assert lora_request.lora_int_id > 0, "lora_id must be greater than 0." + return all(self.collective_rpc("add_lora", args=(lora_request, ))) - @abstractmethod def remove_lora(self, lora_id: int) -> bool: - raise NotImplementedError + assert lora_id > 0, "lora_id must be greater than 0." + return all(self.collective_rpc("remove_lora", args=(lora_id, ))) - @abstractmethod def pin_lora(self, lora_id: int) -> bool: - raise NotImplementedError # type: ignore + assert lora_id > 0, "lora_id must be greater than 0." + return all(self.collective_rpc("pin_lora", args=(lora_id, ))) - @abstractmethod def list_loras(self) -> Set[int]: - raise NotImplementedError + sets = self.collective_rpc("list_loras") + for s in sets: + assert s == sets[0], "All workers should have the same LORAs." + return sets[0] - @abstractmethod def add_prompt_adapter( self, prompt_adapter_request: PromptAdapterRequest) -> bool: - raise NotImplementedError + assert prompt_adapter_request.prompt_adapter_id > 0, \ + "prompt_adapter_id must be greater than 0." + return all( + self.collective_rpc("add_prompt_adapter", + args=(prompt_adapter_request, ))) - @abstractmethod def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: - raise NotImplementedError + assert prompt_adapter_id > 0, \ + "prompt_adapter_id must be greater than 0." + return all( + self.collective_rpc("remove_prompt_adapter", + args=(prompt_adapter_id, ))) - @abstractmethod def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: - raise NotImplementedError # type: ignore + assert prompt_adapter_id > 0, \ + "prompt_adapter_id must be greater than 0." + return all( + self.collective_rpc("pin_prompt_adapter", + args=(prompt_adapter_id, ))) - @abstractmethod def list_prompt_adapters(self) -> Set[int]: - raise NotImplementedError + sets = self.collective_rpc("list_prompt_adapters") + for s in sets: + assert (s == sets[0] + ), "All workers should have the same prompt adapters." 
+ return sets[0] + + def start_profile(self) -> None: + self.collective_rpc("start_profile") + + def stop_profile(self) -> None: + self.collective_rpc("stop_profile") + + def save_sharded_state( + self, + path: str, + pattern: Optional[str] = None, + max_size: Optional[int] = None, + ) -> None: + self.collective_rpc("save_sharded_state", + kwargs=dict(path=path, + pattern=pattern, + max_size=max_size)) @abstractmethod def check_health(self) -> None: @@ -119,15 +194,12 @@ def shutdown(self) -> None: def __del__(self): self.shutdown() - -class ExecutorAsyncBase(ExecutorBase): - - @abstractmethod async def execute_model_async( self, execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: """Executes one model step on the given sequences.""" - raise NotImplementedError + output = await make_async(self.execute_model)(execute_model_req) + return output async def stop_remote_worker_execution_loop_async(self) -> None: """Releases parallel workers from model loop.""" @@ -137,3 +209,128 @@ async def check_health_async(self) -> None: """Checks if the executor is healthy. If not, it should raise an exception.""" self.check_health() + + +class DistributedExecutorBase(ExecutorBase): + """Abstract superclass of distributed executor implementations.""" + + def __init__(self, *args, **kwargs): + # This is non-None when the execute model loop is running + # in the parallel workers. It's a coroutine in the AsyncLLMEngine case. + self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None + + super().__init__(*args, **kwargs) + + def execute_model( + self, + execute_model_req: ExecuteModelRequest, + ) -> List[SamplerOutput]: + # TODO: unify into collective_rpc + if self.parallel_worker_tasks is None: + self.parallel_worker_tasks = self._run_workers( + "start_worker_execution_loop", + async_run_tensor_parallel_workers_only=True) + + # Only the driver worker returns the sampling results. + driver_outputs = self._driver_execute_model(execute_model_req) + assert driver_outputs is not None + return driver_outputs + + def stop_remote_worker_execution_loop(self) -> None: + if self.parallel_worker_tasks is None: + return + + self._driver_execute_model(execute_model_req=None) + parallel_worker_tasks = self.parallel_worker_tasks + self.parallel_worker_tasks = None + # Ensure that workers exit model loop cleanly + # (this will raise otherwise) + self._wait_for_tasks_completion(parallel_worker_tasks) + + @abstractmethod + def _driver_execute_model( + self, execute_model_req: Optional[ExecuteModelRequest] + ) -> Optional[List[SamplerOutput]]: + """Run execute_model in the driver worker. + + Passing None will cause the driver to stop the model execution loop + running in each of the remote workers. In this case, this method + returns None. Otherwise, this method returns the model output. + """ + raise NotImplementedError + + def collective_rpc(self, + method: str, + timeout: Optional[float] = None, + args: Tuple = (), + kwargs: Optional[Dict] = None) -> List[Any]: + return self._run_workers(method, *args, **(kwargs or {})) + + @abstractmethod + def _run_workers( + self, + method: str, + *args, + async_run_tensor_parallel_workers_only: bool = False, + max_concurrent_workers: Optional[int] = None, + **kwargs, + ) -> Any: + """Runs the given method on all workers. + + Args: + async_run_tensor_parallel_workers_only: If True the method will be + run only in the remote TP workers, not the driver worker. + It will also be run asynchronously and return a list of futures + rather than blocking on the results. 
+ + # TODO: simplify and merge with collective_rpc + """ + raise NotImplementedError + + @abstractmethod + def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None: + """Wait for futures returned from _run_workers() with + async_run_remote_workers_only to complete.""" + raise NotImplementedError + + async def execute_model_async( + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + if self.parallel_worker_tasks is None: + # Start model execution loop running in the parallel workers + self.parallel_worker_tasks = asyncio.create_task( + self._start_worker_execution_loop()) + + # Only the driver worker returns the sampling results. + return await self._driver_execute_model_async(execute_model_req) + + async def stop_remote_worker_execution_loop_async(self) -> None: + if self.parallel_worker_tasks is None: + return + + await self._driver_execute_model_async() + parallel_worker_tasks = self.parallel_worker_tasks + self.parallel_worker_tasks = None + # Ensure that workers exit model loop cleanly + # (this will raise otherwise) + await parallel_worker_tasks + + @abstractmethod + async def _driver_execute_model_async( + self, + execute_model_req: Optional[ExecuteModelRequest] = None, + ) -> List[SamplerOutput]: + """Execute the model asynchronously in the driver worker. + + Passing None will cause the driver to stop the model execution + loop running in each of the remote workers. + """ + raise NotImplementedError + + @abstractmethod + async def _start_worker_execution_loop(self): + """Run execution loop on all workers. It guarantees all workers run + the loop or None of them is running the loop. Loop can be stopped by + `stop_remote_worker_execution_loop`. + The API is idempotent (guarantee only 1 loop run at any moment).""" + raise NotImplementedError diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py deleted file mode 100644 index 7fa34456028dd..0000000000000 --- a/vllm/executor/gpu_executor.py +++ /dev/null @@ -1,145 +0,0 @@ -from typing import Any, Dict, List, Optional, Set, Tuple, Union - -from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase -from vllm.logger import init_logger -from vllm.lora.request import LoRARequest -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sequence import ExecuteModelRequest, PoolerOutput -from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, - make_async) -from vllm.worker.worker_base import WorkerWrapperBase - -logger = init_logger(__name__) - - -def create_worker(**kwargs): - vllm_config = kwargs.get("vllm_config") - wrapper = WorkerWrapperBase(vllm_config=vllm_config) - wrapper.init_worker(**kwargs) - return wrapper.worker - - -class GPUExecutor(ExecutorBase): - - uses_ray: bool = False - - def _init_executor(self) -> None: - """Initialize the worker and load the model. 
- """ - assert self.parallel_config.world_size == 1, ( - "GPUExecutor only supports single GPU.") - - self.driver_worker = self._create_worker() - self.driver_worker.init_device() - self.driver_worker.load_model() - - def _get_worker_kwargs( - self, - local_rank: int = 0, - rank: int = 0, - distributed_init_method: Optional[str] = None) -> Dict[str, Any]: - """Return worker init args for a given rank.""" - if distributed_init_method is None: - distributed_init_method = get_distributed_init_method( - get_ip(), get_open_port()) - return dict( - vllm_config=self.vllm_config, - local_rank=local_rank, - rank=rank, - distributed_init_method=distributed_init_method, - is_driver_worker=(not self.parallel_config) - or (rank % self.parallel_config.tensor_parallel_size == 0), - ) - - def _create_worker(self, - local_rank: int = 0, - rank: int = 0, - distributed_init_method: Optional[str] = None): - return create_worker(**self._get_worker_kwargs( - local_rank=local_rank, - rank=rank, - distributed_init_method=distributed_init_method)) - - def determine_num_available_blocks(self) -> Tuple[int, int]: - """Determine the number of available KV blocks by invoking the - underlying worker. - """ - return self.driver_worker.determine_num_available_blocks() - - def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None: - """Initialize the KV cache by invoking the underlying worker. - """ - # NOTE: This is logged in the executor because there can be >1 worker - # with other executors. We could log in the engine level, but work - # remains to abstract away the device for non-GPU configurations. - logger.info("# GPU blocks: %d, # CPU blocks: %d", num_gpu_blocks, - num_cpu_blocks) - max_concurrency = (num_gpu_blocks * self.cache_config.block_size / - self.model_config.max_model_len) - logger.info("Maximum concurrency for %s tokens per request: %.2fx", - self.model_config.max_model_len, max_concurrency) - - self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) - - def execute_model( - self, execute_model_req: ExecuteModelRequest - ) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]: - output = self.driver_worker.execute_model(execute_model_req) - return output - - def add_lora(self, lora_request: LoRARequest) -> bool: - assert lora_request.lora_int_id > 0, "lora_id must be greater than 0." - return self.driver_worker.add_lora(lora_request) - - def remove_lora(self, lora_id: int) -> bool: - assert lora_id > 0, "lora_id must be greater than 0." - return self.driver_worker.remove_lora(lora_id) - - def pin_lora(self, lora_id: int) -> bool: - assert lora_id > 0, "lora_id must be greater than 0." - return self.driver_worker.pin_lora(lora_id) - - def list_loras(self) -> Set[int]: - return self.driver_worker.list_loras() - - def add_prompt_adapter( - self, prompt_adapter_request: PromptAdapterRequest) -> bool: - assert prompt_adapter_request.prompt_adapter_id > 0, \ - "prompt_adapter_id must be greater than 0." - return self.driver_worker.add_prompt_adapter(prompt_adapter_request) - - def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: - assert prompt_adapter_id > 0, \ - "prompt_adapter_id must be greater than 0." - return self.driver_worker.remove_prompt_adapter(prompt_adapter_id) - - def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: - assert prompt_adapter_id > 0, \ - "prompt_adapter_id must be greater than 0." 
- return self.driver_worker.pin_prompt_adapter(prompt_adapter_id) - - def list_prompt_adapters(self) -> Set[int]: - return self.driver_worker.list_prompt_adapters() - - def check_health(self) -> None: - # GPUExecutor will always be healthy as long as - # it's running. - return - - def start_profile(self) -> None: - self.driver_worker.start_profile() - - def stop_profile(self) -> None: - self.driver_worker.stop_profile() - - -class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase): - - async def execute_model_async( - self, - execute_model_req: ExecuteModelRequest, - ) -> List[Union[SamplerOutput, PoolerOutput]]: - output = await make_async(self.driver_worker.execute_model - )(execute_model_req=execute_model_req) - return output diff --git a/vllm/executor/hpu_executor.py b/vllm/executor/hpu_executor.py deleted file mode 100644 index c9b7bfa71edfa..0000000000000 --- a/vllm/executor/hpu_executor.py +++ /dev/null @@ -1,202 +0,0 @@ -############################################################################### -# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company -############################################################################### - -import contextlib -import os -from typing import Any, Dict, List, Optional, Set, Tuple - -from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase -from vllm.logger import init_logger -from vllm.lora.request import LoRARequest -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sequence import ExecuteModelRequest -from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, - make_async) -from vllm.worker.worker_base import WorkerWrapperBase - -logger = init_logger(__name__) - - -class HPUExecutor(ExecutorBase): - - uses_ray: bool = False - - def _init_executor(self) -> None: - """Initialize the worker and load the model.""" - self._init_worker() - - def _get_worker_kwargs( - self, - local_rank: int = 0, - rank: int = 0, - distributed_init_method: Optional[str] = None) -> Dict[str, Any]: - """Return worker init args for a given rank.""" - if distributed_init_method is None: - distributed_init_method = get_distributed_init_method( - get_ip(), get_open_port()) - return dict( - vllm_config=self.vllm_config, - local_rank=local_rank, - rank=rank, - distributed_init_method=distributed_init_method, - is_driver_worker=rank == 0, - ) - - def _create_worker(self, - local_rank: int = 0, - rank: int = 0, - distributed_init_method: Optional[str] = None): - wrapper = WorkerWrapperBase(vllm_config=self.vllm_config) - wrapper.init_worker(**self._get_worker_kwargs(local_rank, rank, - distributed_init_method)) - return wrapper.worker - - def _init_worker(self): - assert self.parallel_config.world_size == 1, ( - "GPUExecutor only supports single GPU.") - - self.driver_worker = self._create_worker() - self.driver_worker.init_device() - self.driver_worker.load_model() - - def determine_num_available_blocks(self) -> Tuple[int, int]: - """Determine the number of available KV blocks by invoking the - underlying worker. - """ - return self.driver_worker.determine_num_available_blocks() - - def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None: - """Initialize the KV cache by invoking the underlying worker. - """ - # NOTE: This is logged in the executor because there can be >1 worker - # with other executors. We could log in the engine level, but work - # remains to abstract away the device for non-GPU configurations. 
- logger.info("# HPU blocks: %d, # CPU blocks: %d", num_gpu_blocks, - num_cpu_blocks) - from vllm_hpu_extension.profiler import HabanaMemoryProfiler - with HabanaMemoryProfiler() as cache_init_m: - self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) - msg = f"init_cache_engine took {cache_init_m.get_summary_string()}" - logger.info(msg) - - def finish_measurements(self): - self.driver_worker.finish_measurements() - - def execute_model( - self, - execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: - # VLLM_HPU_LOG_STEP_GRAPH_COMPILATION - will log graph compilations per engine step, only when there was any - highly recommended to use alongside PT_HPU_METRICS_GC_DETAILS! # noqa:E501 - # VLLM_HPU_LOG_STEP_GRAPH_COMPILATION_ALL - will log graph compilations per engine step, always, even if there were none # noqa:E501 - # VLLM_HPU_LOG_STEP_CPU_FALLBACKS - will log cpu fallbacks per engine step, only when there was any # noqa:E501 - # VLLM_HPU_LOG_STEP_CPU_FALLBACKS_ALL - will log cpu fallbacks per engine step, always, even if there were none # noqa:E501 - log_graph_compilation_all = os.environ.get( - 'VLLM_HPU_LOG_STEP_GRAPH_COMPILATION_ALL', '0') != '0' - log_graph_compilation = os.environ.get( - 'VLLM_HPU_LOG_STEP_GRAPH_COMPILATION', - '0') != '0' or log_graph_compilation_all - log_cpu_fallbacks_all = os.environ.get( - 'VLLM_HPU_LOG_STEP_CPU_FALLBACKS_ALL', '0') != '0' - log_cpu_fallbacks = os.environ.get('VLLM_HPU_LOG_STEP_CPU_FALLBACKS', - '0') != '0' or log_cpu_fallbacks_all - if log_graph_compilation or log_cpu_fallbacks: - from habana_frameworks.torch.hpu.metrics import metric_localcontext - seq_group_metadata_list = execute_model_req.seq_group_metadata_list - is_prompt = any([ - seq_group_metadata.is_prompt - for seq_group_metadata in seq_group_metadata_list - ]) - max_context_len = max([ - max([ - len(v.prompt_token_ids) + len(v.output_token_ids) - for v in seq_group_metadata.seq_data.values() - ]) for seq_group_metadata in seq_group_metadata_list - ]) # whoa, that's some spicy stuff right here - max_num_blocks = ( - (max_context_len - 1) // self.cache_config.block_size) + 1 - input_stats = (f'is_prompt: {is_prompt}, ' - f'num_seqs: {len(seq_group_metadata_list)}, ' - f'max_context_len: {max_context_len}, ' - f'max_num_blocks {max_num_blocks}') - gc_ctx = metric_localcontext( - "graph_compilation" - ) if log_graph_compilation else contextlib.nullcontext() - cpu_fallback_ctx = metric_localcontext( - "cpu_fallback" - ) if log_cpu_fallbacks else contextlib.nullcontext() - with gc_ctx as gc_local_metric, \ - cpu_fallback_ctx as cpu_fallback_local_metric: - output = self.driver_worker.execute_model(execute_model_req) - if (log_graph_compilation and gc_local_metric.stats()[0][1] > 0 - ) or log_graph_compilation_all: - msg = ("VLLM_HPU_STEP_GRAPH_COMPILATION: " - f"{gc_local_metric.stats()}, {input_stats}") - logger.warning(msg) - if (log_cpu_fallbacks and cpu_fallback_local_metric.stats()[0][1] > - 0) or log_cpu_fallbacks_all: - msg = ("VLLM_HPU_STEP_CPU_FALLBACK: " - f"{cpu_fallback_local_metric.stats()}, {input_stats}") - logger.warning(msg) - - return output - - output = self.driver_worker.execute_model(execute_model_req) - return output - - def add_lora(self, lora_request: LoRARequest) -> bool: - assert lora_request.lora_int_id > 0, "lora_id must be greater than 0." - return self.driver_worker.add_lora(lora_request) - - def remove_lora(self, lora_id: int) -> bool: - assert lora_id > 0, "lora_id must be greater than 0." 
- return self.driver_worker.remove_lora(lora_id) - - def pin_lora(self, lora_id: int) -> bool: - assert lora_id > 0, "lora_id must be greater than 0." - return self.driver_worker.pin_lora(lora_id) - - def list_loras(self) -> Set[int]: - return self.driver_worker.list_loras() - - def add_prompt_adapter( - self, prompt_adapter_request: PromptAdapterRequest) -> bool: - raise NotImplementedError( - "Prompt Adapter is not implemented for HPU backend.") - - def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: - raise NotImplementedError( - "Prompt Adapter is not implemented for HPU backend.") - - def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: - raise NotImplementedError( - "Prompt Adapter is not implemented for HPU backend.") - - def list_prompt_adapters(self) -> Set[int]: - raise NotImplementedError( - "Prompt Adapter is not implemented for HPU backend.") - - def check_health(self) -> None: - # GPUExecutor will always be healthy as long as - # it's running. - return - - def start_profile(self) -> None: - self.driver_worker.start_profile() - - def stop_profile(self) -> None: - self.driver_worker.stop_profile() - - def shutdown(self) -> None: - self.driver_worker.shutdown_inc() - - -class HPUExecutorAsync(HPUExecutor, ExecutorAsyncBase): - - async def execute_model_async( - self, - execute_model_req: ExecuteModelRequest, - ) -> List[SamplerOutput]: - output = await make_async(self.driver_worker.execute_model - )(execute_model_req=execute_model_req, ) - return output diff --git a/vllm/executor/multiproc_gpu_executor.py b/vllm/executor/mp_distributed_executor.py similarity index 75% rename from vllm/executor/multiproc_gpu_executor.py rename to vllm/executor/mp_distributed_executor.py index fc58163cade64..d9dde949b844a 100644 --- a/vllm/executor/multiproc_gpu_executor.py +++ b/vllm/executor/mp_distributed_executor.py @@ -1,32 +1,26 @@ import asyncio -import os -from functools import partial from typing import Any, List, Optional -from vllm.executor.distributed_gpu_executor import ( # yapf: disable - DistributedGPUExecutor, DistributedGPUExecutorAsync) -from vllm.executor.gpu_executor import create_worker +from vllm.executor.executor_base import DistributedExecutorBase from vllm.executor.multiproc_worker_utils import ( ProcessWorkerWrapper, ResultHandler, WorkerMonitor, set_multiprocessing_worker_envs) from vllm.logger import init_logger from vllm.model_executor.layers.sampler import SamplerOutput from vllm.sequence import ExecuteModelRequest -from vllm.utils import (_run_task_with_lock, cuda_device_count_stateless, - get_distributed_init_method, get_open_port, make_async, - update_environment_variables) +from vllm.utils import (_run_task_with_lock, get_distributed_init_method, + get_ip, get_open_port, make_async) +from vllm.worker.worker_base import WorkerWrapperBase logger = init_logger(__name__) -class MultiprocessingGPUExecutor(DistributedGPUExecutor): - """Python multiprocessing-based multi-GPU executor""" +class MultiprocessingDistributedExecutor(DistributedExecutorBase): + """Python multiprocessing-based distributed executor""" uses_ray: bool = False def _init_executor(self) -> None: - self._check_executor_parameters() - # Create the parallel GPU workers. 
world_size = self.parallel_config.world_size tensor_parallel_size = self.parallel_config.tensor_parallel_size @@ -55,15 +49,9 @@ def _init_executor(self) -> None: else: result_handler = ResultHandler() for rank in range(1, world_size): - worker = ProcessWorkerWrapper( - result_handler, - partial( - create_worker, - **self._get_worker_kwargs( - rank=rank, - local_rank=rank, - distributed_init_method=distributed_init_method, - ))) + worker = ProcessWorkerWrapper(result_handler, + WorkerWrapperBase, + self.vllm_config, rank) self.workers.append(worker) if rank % tensor_parallel_size == 0: self.tp_driver_workers.append(worker) @@ -77,32 +65,30 @@ def _init_executor(self) -> None: # Set up signal handlers to shutdown the executor cleanly # sometimes gc does not work well - self.driver_worker = self._create_worker( - distributed_init_method=distributed_init_method) + self.driver_worker = WorkerWrapperBase(self.vllm_config, 0) + + all_kwargs = [] + distributed_init_method = get_distributed_init_method( + get_ip(), get_open_port()) + for i in range(world_size): + local_rank = i + rank = i + kwargs = dict( + vllm_config=self.vllm_config, + local_rank=local_rank, + rank=rank, + distributed_init_method=distributed_init_method, + is_driver_worker=(not self.parallel_config) + or (rank % self.parallel_config.tensor_parallel_size == 0), + ) + all_kwargs.append(kwargs) + self._run_workers("init_worker", all_kwargs) self._run_workers("init_device") self._run_workers("load_model", max_concurrent_workers=self.parallel_config. max_parallel_loading_workers) - - def _check_executor_parameters(self): - world_size = self.parallel_config.world_size - tensor_parallel_size = self.parallel_config.tensor_parallel_size - - # Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers - if "CUDA_VISIBLE_DEVICES" not in os.environ: - update_environment_variables({ - "CUDA_VISIBLE_DEVICES": (",".join(map(str, range(world_size)))) - }) - - cuda_device_count = cuda_device_count_stateless() - # Use confusing message for more common TP-only case. 
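In the reworked _init_executor above, every rank whose index is a multiple of tensor_parallel_size becomes a TP driver, and the same rank % tensor_parallel_size == 0 test later feeds the is_driver_worker kwarg. A tiny sketch with assumed sizes (8 workers, TP=4) showing which ranks land in each group:

    # Assumed sizes for illustration only.
    world_size = 8
    tensor_parallel_size = 4

    tp_driver_ranks = [r for r in range(world_size)
                       if r % tensor_parallel_size == 0]
    non_driver_ranks = [r for r in range(world_size)
                        if r % tensor_parallel_size != 0]

    print(tp_driver_ranks)   # [0, 4]  -> one driver per tensor-parallel group
    print(non_driver_ranks)  # [1, 2, 3, 5, 6, 7]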
- assert tensor_parallel_size <= cuda_device_count, ( - f"please set tensor_parallel_size ({tensor_parallel_size}) " - f"to less than max local gpu count ({cuda_device_count})") - - assert world_size <= cuda_device_count, ( - f"please ensure that world_size ({world_size}) " - f"is less than than max local gpu count ({cuda_device_count})") + self.driver_exec_model = make_async(self.driver_worker.execute_model) + self.pp_locks: Optional[List[asyncio.Lock]] = None def shutdown(self): if (worker_monitor := getattr(self, "worker_monitor", @@ -172,15 +158,6 @@ def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None: for result in parallel_worker_tasks: result.get() - -class MultiprocessingGPUExecutorAsync(MultiprocessingGPUExecutor, - DistributedGPUExecutorAsync): - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.driver_exec_model = make_async(self.driver_worker.execute_model) - self.pp_locks: Optional[List[asyncio.Lock]] = None - async def _driver_execute_model_async( self, execute_model_req: Optional[ExecuteModelRequest] = None diff --git a/vllm/executor/multiproc_worker_utils.py b/vllm/executor/multiproc_worker_utils.py index bc32826529eef..c9fb3c664c575 100644 --- a/vllm/executor/multiproc_worker_utils.py +++ b/vllm/executor/multiproc_worker_utils.py @@ -12,6 +12,7 @@ import torch +from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.triton_utils.importing import HAS_TRITON from vllm.utils import _check_multiproc_method, get_mp_context @@ -147,7 +148,8 @@ class ProcessWorkerWrapper: for handling single-node multi-GPU tensor parallel.""" def __init__(self, result_handler: ResultHandler, - worker_factory: Callable[[], Any]) -> None: + worker_factory: Callable[[VllmConfig, int], Any], + vllm_config: VllmConfig, rank: int) -> None: self.mp = get_mp_context() self._task_queue = self.mp.Queue() self.result_queue = result_handler.result_queue @@ -159,6 +161,8 @@ def __init__(self, result_handler: ResultHandler, worker_factory=worker_factory, task_queue=self._task_queue, result_queue=self.result_queue, + vllm_config=vllm_config, + rank=rank, ), daemon=True) @@ -199,9 +203,11 @@ def kill_worker(self): def _run_worker_process( - worker_factory: Callable[[], Any], + worker_factory: Callable[[VllmConfig, int], Any], task_queue: Queue, result_queue: Queue, + vllm_config: VllmConfig, + rank: int, ) -> None: """Worker process event loop""" @@ -212,7 +218,7 @@ def _run_worker_process( _add_prefix(sys.stderr, process_name, pid) # Initialize worker - worker = worker_factory() + worker = worker_factory(vllm_config, rank) del worker_factory # Accept tasks from the engine in task_queue diff --git a/vllm/executor/multiproc_xpu_executor.py b/vllm/executor/multiproc_xpu_executor.py deleted file mode 100644 index a66afbf939ef0..0000000000000 --- a/vllm/executor/multiproc_xpu_executor.py +++ /dev/null @@ -1,26 +0,0 @@ -import vllm.envs as envs -from vllm.executor.multiproc_gpu_executor import ( - MultiprocessingGPUExecutor, MultiprocessingGPUExecutorAsync) -from vllm.executor.xpu_executor import XPUExecutor -from vllm.logger import init_logger -from vllm.utils import make_async - -logger = init_logger(__name__) - - -class MultiprocessingXPUExecutor(MultiprocessingGPUExecutor, XPUExecutor): - """Python multiprocessing-based multi-XPU executor""" - - def _check_executor_parameters(self): - mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD - if mp_method != "spawn": - raise RuntimeError( - "XPU multiprocess executor only support spawn as mp 
method") - - -class MultiprocessingXPUExecutorAsync(MultiprocessingXPUExecutor, - MultiprocessingGPUExecutorAsync): - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.driver_exec_model = make_async(self.driver_worker.execute_model) diff --git a/vllm/executor/neuron_executor.py b/vllm/executor/neuron_executor.py deleted file mode 100644 index a9efc4f9a801c..0000000000000 --- a/vllm/executor/neuron_executor.py +++ /dev/null @@ -1,114 +0,0 @@ -from typing import List, Set, Tuple - -from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase -from vllm.logger import init_logger -from vllm.lora.request import LoRARequest -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.sequence import ExecuteModelRequest -from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, - make_async) -from vllm.worker.worker_base import WorkerWrapperBase - -logger = init_logger(__name__) - - -class NeuronExecutor(ExecutorBase): - - uses_ray: bool = False - - def _init_executor(self) -> None: - assert (self.lora_config is - None), "LoRA is not supported for Neuron backend." - assert (not self.speculative_config - ), "Speculative decoding not yet supported for Neuron backend." - - # Instantiate the worker and load the model to the device. - self._init_worker() - - def _init_worker(self): - wrapper = WorkerWrapperBase(vllm_config=self.vllm_config) - distributed_init_method = get_distributed_init_method( - get_ip(), get_open_port()) - wrapper.init_worker( - vllm_config=self.vllm_config, - local_rank=0, - rank=0, - distributed_init_method=distributed_init_method, - ) - self.driver_worker = wrapper.worker - self.driver_worker.init_device() - self.driver_worker.load_model() - - def determine_num_available_blocks(self) -> Tuple[int, int]: - """Determine the number of available KV blocks by invoking the - underlying worker. - """ - return self.driver_worker.determine_num_available_blocks() - - def initialize_cache(self, num_gpu_blocks: int, - num_cpu_blocks: int) -> None: - """Initialize the KV cache by invoking the underlying worker. 
- """ - self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) - - def execute_model( - self, - execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: - assert (not execute_model_req.blocks_to_swap_in - and not execute_model_req.blocks_to_swap_out - and not execute_model_req.blocks_to_copy), ( - "Cache operations are not supported for Neuron backend.") - assert execute_model_req.num_lookahead_slots == 0, ( - "lookahead not supported for Neuron backend.") - - output = self.driver_worker.execute_model(execute_model_req) - return output - - def add_lora(self, lora_request: LoRARequest) -> bool: - return self.driver_worker.add_lora(lora_request) - - def remove_lora(self, lora_id: int) -> bool: - return self.driver_worker.remove_lora(lora_id) - - def pin_lora(self, lora_id: int) -> bool: - return self.driver_worker.pin_lora(lora_id) - - def list_loras(self) -> Set[int]: - return self.driver_worker.list_loras() - - def add_prompt_adapter(self, prompt_adapter_request) -> bool: - raise NotImplementedError( - "Soft prompt is currently not supported by the Neuron backend.") - - def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: - raise NotImplementedError( - "Soft prompt is currently not supported by the Neuron backend.") - - def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: - raise NotImplementedError( - "Soft prompt is currently not supported by the Neuron backend.") - - def list_prompt_adapters(self) -> Set[int]: - raise NotImplementedError( - "Soft prompt is currently not supported by the Neuron backend.") - - def check_health(self) -> None: - # NeuronExecutor will always be healthy as long as - # it's running. - return - - -class NeuronExecutorAsync(NeuronExecutor, ExecutorAsyncBase): - - async def execute_model_async( - self, - execute_model_req: ExecuteModelRequest, - ) -> List[SamplerOutput]: - output = await make_async(self.driver_worker.execute_model - )(execute_model_req=execute_model_req, ) - return output - - async def check_health_async(self) -> None: - # NeuronExecutor will always be healthy as long as - # it's running. - return diff --git a/vllm/executor/openvino_executor.py b/vllm/executor/openvino_executor.py deleted file mode 100644 index 057a32364e512..0000000000000 --- a/vllm/executor/openvino_executor.py +++ /dev/null @@ -1,125 +0,0 @@ -from typing import List, Set, Tuple - -import openvino as ov - -import vllm.envs as envs -from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase -from vllm.logger import init_logger -from vllm.lora.request import LoRARequest -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.platforms import current_platform -from vllm.sequence import ExecuteModelRequest -from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, - make_async) -from vllm.worker.worker_base import WorkerWrapperBase - -logger = init_logger(__name__) - - -class OpenVINOExecutor(ExecutorBase): - - uses_ray: bool = False - - def _init_executor(self) -> None: - assert self.device_config.device_type == "openvino" - assert self.lora_config is None, "OpenVINO backend doesn't support LoRA" - assert current_platform.is_openvino_cpu() or \ - current_platform.is_openvino_gpu(), \ - "OpenVINO backend supports only CPU and GPU devices" - - # Instantiate the worker and load the model to CPU. 
- self._init_worker() - - def _init_worker(self): - - wrapper = WorkerWrapperBase(vllm_config=self.vllm_config) - - distributed_init_method = get_distributed_init_method( - get_ip(), get_open_port()) - wrapper.init_worker( - ov_core=ov.Core(), - vllm_config=self.vllm_config, - local_rank=0, - rank=0, - distributed_init_method=distributed_init_method, - kv_cache_dtype=self.cache_config.cache_dtype, - is_driver_worker=True, - ) - self.driver_worker = wrapper.worker - self.driver_worker.init_device() - self.driver_worker.load_model() - - def determine_num_available_blocks(self) -> Tuple[int, int]: - """Determine the number of available KV blocks by invoking the - underlying worker. - """ - return self.driver_worker.determine_num_available_blocks() - - def initialize_cache(self, num_gpu_blocks: int, - num_cpu_blocks: int) -> None: - """Initialize the KV cache by invoking the underlying worker.""" - # NOTE: We log here to avoid multiple logs when number of workers is - # greater than one. We could log in the engine, but not all executors - # have GPUs. - # NOTE: In case of a CPU device, `cpu block` for OpenVINO backend - # is located on CPU memory but is referred as `gpu block`. - # Because we want to reuse the existing block management procedure. - device_blocks = num_gpu_blocks - swap_blocks = num_cpu_blocks - logger.info("OpenVINO %s: # device blocks: %d; # swap blocks: %d", - envs.VLLM_OPENVINO_DEVICE, device_blocks, swap_blocks) - self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) - - def execute_model( - self, - execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: - output = self.driver_worker.execute_model(execute_model_req) - return output - - def add_lora(self, lora_request: LoRARequest) -> bool: - return self.driver_worker.add_lora(lora_request) - - def remove_lora(self, lora_id: int) -> bool: - return self.driver_worker.remove_lora(lora_id) - - def pin_lora(self, lora_id: int) -> bool: - return self.driver_worker.pin_lora(lora_id) - - def list_loras(self) -> Set[int]: - return self.driver_worker.list_loras() - - def add_prompt_adapter(self, prompt_adapter_request) -> bool: - raise NotImplementedError( - "Soft prompt is currently not supported by the OPENVINO backend.") - - def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: - raise NotImplementedError( - "Soft prompt is currently not supported by the OPENVINO backend.") - - def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: - raise NotImplementedError( - "Soft prompt is currently not supported by the OPENVINO backend.") - - def list_prompt_adapters(self) -> Set[int]: - raise NotImplementedError( - "Soft prompt is currently not supported by the OPENVINO backend.") - - def check_health(self) -> None: - # OpenVINOExecutor will always be healthy as long as - # it's running. - return - - -class OpenVINOExecutorAsync(OpenVINOExecutor, ExecutorAsyncBase): - - async def execute_model_async( - self, - execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: - output = await make_async(self.driver_worker.execute_model - )(execute_model_req=execute_model_req, ) - return output - - async def check_health_async(self) -> None: - # OpenVINOExecutor will always be healthy as long as - # it's running. 
- return diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_distributed_executor.py similarity index 78% rename from vllm/executor/ray_gpu_executor.py rename to vllm/executor/ray_distributed_executor.py index e2c549cbd5331..edceece4b68dc 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_distributed_executor.py @@ -1,24 +1,29 @@ import asyncio import os from collections import defaultdict -from itertools import islice, repeat -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Dict, List, Optional import msgspec import vllm.envs as envs -from vllm.executor.distributed_gpu_executor import ( # yapf: disable - DistributedGPUExecutor, DistributedGPUExecutorAsync) +from vllm.executor.executor_base import ( + DistributedExecutorBase) # yapf: disable from vllm.executor.msgspec_utils import encode_hook -from vllm.executor.ray_utils import RayWorkerWrapper, ray +from vllm.executor.ray_utils import (RayWorkerWrapper, initialize_ray_cluster, + ray) from vllm.logger import init_logger from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.platforms import current_platform from vllm.sequence import ExecuteModelRequest from vllm.utils import (_run_task_with_lock, get_distributed_init_method, get_ip, get_open_port, make_async) if ray is not None: + from ray.actor import ActorHandle from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy +else: + ActorHandle = None if TYPE_CHECKING: from ray.util.placement_group import PlacementGroup @@ -26,12 +31,29 @@ logger = init_logger(__name__) -class RayGPUExecutor(DistributedGPUExecutor): +@dataclass +class RayWorkerMetaData: + """ + Metadata for a Ray worker. + The order of ray worker creation can be random, + and we need to reset the rank after creating all workers. + """ + worker: ActorHandle + created_rank: int + adjusted_rank: int = -1 + ip: str = "" + + +class RayDistributedExecutor(DistributedExecutorBase): uses_ray: bool = True def _init_executor(self) -> None: self.forward_dag: Optional[ray.dag.CompiledDAG] = None + if envs.VLLM_USE_V1: + # v1 always uses the compiled DAG and SPMD worker. + os.environ["VLLM_USE_RAY_SPMD_WORKER"] = "1" + os.environ["VLLM_USE_RAY_COMPILED_DAG"] = "1" # If the env var is set, it uses the Ray's compiled DAG API # which optimizes the control plane overhead. # Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it. @@ -53,6 +75,7 @@ def _init_executor(self) -> None: "VLLM_USE_RAY_COMPILED_DAG=1") assert self.uses_ray + initialize_ray_cluster(self.parallel_config) placement_group = self.parallel_config.placement_group # Disable Ray usage stats collection. @@ -66,6 +89,13 @@ def _init_executor(self) -> None: self.input_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook) self.output_decoder = msgspec.msgpack.Decoder( Optional[List[SamplerOutput]]) + self.use_v1 = envs.VLLM_USE_V1 + + self.pp_locks: Optional[List[asyncio.Lock]] = None + self.use_ray_spmd_worker = envs.VLLM_USE_RAY_SPMD_WORKER + if not self.use_ray_compiled_dag: + self.driver_exec_method = make_async( + self.driver_worker.execute_method) def shutdown(self) -> None: if hasattr(self, "forward_dag") and self.forward_dag is not None: @@ -123,9 +153,10 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", # Create the workers. 
driver_ip = get_ip() - workers = [] + rank = 0 + worker_metadata: List[RayWorkerMetaData] = [] for bundle_id, bundle in enumerate(placement_group.bundle_specs): - if not bundle.get("GPU", 0): + if not bundle.get(current_platform.ray_device_key, 0): continue scheduling_strategy = PlacementGroupSchedulingStrategy( placement_group=placement_group, @@ -133,38 +164,51 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", placement_group_bundle_index=bundle_id, ) - worker = ray.remote( - num_cpus=0, - num_gpus=num_gpus, - scheduling_strategy=scheduling_strategy, - **ray_remote_kwargs, - )(RayWorkerWrapper).remote(vllm_config=self.vllm_config) - workers.append(worker) - - worker_ip_refs = [ - worker.get_node_ip.remote() # type: ignore[attr-defined] - for worker in workers - ] - worker_ips = ray.get(worker_ip_refs) + if current_platform.ray_device_key == "GPU": + # NV+AMD GPUs, and Intel XPUs + worker = ray.remote( + num_cpus=0, + num_gpus=num_gpus, + scheduling_strategy=scheduling_strategy, + **ray_remote_kwargs, + )(RayWorkerWrapper).remote(vllm_config=self.vllm_config, + rank=rank) + else: + worker = ray.remote( + num_cpus=0, + num_gpus=0, + resources={current_platform.ray_device_key: num_gpus}, + scheduling_strategy=scheduling_strategy, + **ray_remote_kwargs, + )(RayWorkerWrapper).remote(vllm_config=self.vllm_config, + rank=rank) + worker_metadata.append( + RayWorkerMetaData(worker=worker, created_rank=rank)) + rank += 1 + + worker_ips = ray.get([ + each.worker.get_node_ip.remote() # type: ignore[attr-defined] + for each in worker_metadata + ]) + + for each, ip in zip(worker_metadata, worker_ips): + each.ip = ip if not self.use_ray_spmd_worker: - for i in range(len(workers)): - worker = workers[i] - worker_ip = worker_ips[i] + for i, each in enumerate(worker_metadata): + # find and remove the dummy worker from the list + worker = each.worker + worker_ip = each.ip if self.driver_dummy_worker is None and worker_ip == driver_ip: # If the worker is on the same node as the driver, we use it # as the resource holder for the driver process. self.driver_dummy_worker = worker self.driver_worker = RayWorkerWrapper( - vllm_config=self.vllm_config) - workers.pop(i) - worker_ips.pop(i) - self.workers = workers + vllm_config=self.vllm_config, rank=0) + worker_metadata.pop(i) break - else: - self.workers = workers - logger.debug("workers: %s", self.workers) + logger.debug("workers: %s", worker_metadata) logger.debug("driver_dummy_worker: %s", self.driver_dummy_worker) if not self.use_ray_spmd_worker and self.driver_dummy_worker is None: raise ValueError( @@ -176,9 +220,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", for ip in worker_ips: ip_counts[ip] = ip_counts.get(ip, 0) + 1 - worker_to_ip = dict(zip(self.workers, worker_ips)) - - def sort_by_driver_then_worker_ip(worker): + def sort_by_driver_then_worker_ip(item: RayWorkerMetaData): """ Sort the workers based on 3 properties: 1. If the worker is on the same node as the driver (vllm engine), @@ -188,13 +230,23 @@ def sort_by_driver_then_worker_ip(worker): 3. Finally, if the work is on a node with smaller IP address, it should be placed first. """ - ip = worker_to_ip[worker] - return (ip != driver_ip, ip_counts[ip], ip) + ip = item.ip + return (0 if ip == driver_ip else 1, ip_counts[ip], ip) # After sorting, the workers on the same node will be # close to each other, and the workers on the driver # node will be placed first. 
- self.workers = sorted(self.workers, key=sort_by_driver_then_worker_ip) + sorted_worker_metadata = sorted(worker_metadata, + key=sort_by_driver_then_worker_ip) + start_rank = 0 if self.use_ray_spmd_worker else 1 + for i, item in enumerate(sorted_worker_metadata): + item.adjusted_rank = i + start_rank + self.workers = [item.worker for item in sorted_worker_metadata] + rerank_mapping = { + item.created_rank: item.adjusted_rank + for item in sorted_worker_metadata + } + self._run_workers("adjust_rank", rerank_mapping) # Get the set of GPU IDs used on each node. worker_node_and_gpu_ids = [] @@ -235,21 +287,29 @@ def sort_by_driver_then_worker_ip(worker): " each node.") # Set environment variables for the driver and workers. - all_args_to_update_environment_variables = [({ - "CUDA_VISIBLE_DEVICES": + all_args_to_update_environment_variables = [{ + current_platform.device_control_env_var: ",".join(map(str, node_gpus[node_id])), - "VLLM_TRACE_FUNCTION": - str(envs.VLLM_TRACE_FUNCTION), - **({ - "VLLM_ATTENTION_BACKEND": envs.VLLM_ATTENTION_BACKEND - } if envs.VLLM_ATTENTION_BACKEND is not None else {}) - }, ) for (node_id, _) in worker_node_and_gpu_ids] + } for (node_id, _) in worker_node_and_gpu_ids] + + for args in all_args_to_update_environment_variables: + # some carry-over env vars from the driver + # TODO: refactor platform-specific env vars + for name in [ + "VLLM_ATTENTION_BACKEND", + "TPU_CHIPS_PER_HOST_BOUNDS", + "TPU_HOST_BOUNDS", + "VLLM_USE_V1", + "VLLM_TRACE_FUNCTION", + ]: + if name in os.environ: + args[name] = os.environ[name] self._env_vars_for_all_workers = ( all_args_to_update_environment_variables) self._run_workers("update_environment_variables", - all_args=self._get_env_vars_to_be_updated()) + self._get_env_vars_to_be_updated()) if len(node_gpus) == 1: # in single node case, we don't need to get the IP address. @@ -265,14 +325,19 @@ def sort_by_driver_then_worker_ip(worker): driver_ip, get_open_port()) # Initialize the actual workers inside worker wrapper. 
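Because Ray may create the actors in any order, the renamed executor records a created_rank per RayWorkerMetaData, sorts the metadata with the key shown above (driver-node workers first, then less crowded nodes, then IP), assigns adjusted ranks, and sends the created_rank -> adjusted_rank mapping to the workers via adjust_rank. A condensed sketch with made-up IPs (SPMD mode assumed, so ranks start at 0):

    from dataclasses import dataclass

    @dataclass
    class WorkerMeta:  # trimmed stand-in for RayWorkerMetaData
        created_rank: int
        ip: str
        adjusted_rank: int = -1

    driver_ip = "10.0.0.1"
    metas = [WorkerMeta(0, "10.0.0.2"), WorkerMeta(1, "10.0.0.1"),
             WorkerMeta(2, "10.0.0.2"), WorkerMeta(3, "10.0.0.1")]

    ip_counts: dict = {}
    for m in metas:
        ip_counts[m.ip] = ip_counts.get(m.ip, 0) + 1

    def sort_key(m: WorkerMeta):
        # driver-node workers first, then nodes with fewer workers, then by IP
        return (0 if m.ip == driver_ip else 1, ip_counts[m.ip], m.ip)

    use_ray_spmd_worker = True          # assumed; otherwise ranks start at 1
    start_rank = 0 if use_ray_spmd_worker else 1
    sorted_metas = sorted(metas, key=sort_key)
    for i, m in enumerate(sorted_metas):
        m.adjusted_rank = i + start_rank

    rerank_mapping = {m.created_rank: m.adjusted_rank for m in sorted_metas}
    print(rerank_mapping)  # {1: 0, 3: 1, 0: 2, 2: 3}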
- init_worker_all_kwargs = [ - self._get_worker_kwargs( - local_rank=node_workers[node_id].index(rank), + all_kwargs = [] + for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids): + local_rank = node_workers[node_id].index(rank) + kwargs = dict( + vllm_config=self.vllm_config, + local_rank=local_rank, rank=rank, distributed_init_method=distributed_init_method, - ) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids) - ] - self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs) + is_driver_worker=(not self.parallel_config) + or (rank % self.parallel_config.tensor_parallel_size == 0), + ) + all_kwargs.append(kwargs) + self._run_workers("init_worker", all_kwargs) self._run_workers("init_device") self._run_workers("load_model", @@ -332,9 +397,15 @@ def execute_model( if self.forward_dag is None: self.forward_dag = self._compiled_ray_dag(enable_asyncio=False) - serialized_data = self.input_encoder.encode(execute_model_req) + if self.use_v1: + serialized_data = execute_model_req + else: + serialized_data = self.input_encoder.encode(execute_model_req) outputs = ray.get(self.forward_dag.execute(serialized_data)) - output = self.output_decoder.decode(outputs[0]) + if self.use_v1: + output = outputs[0] + else: + output = self.output_decoder.decode(outputs[0]) return output def _run_workers( @@ -342,8 +413,6 @@ def _run_workers( method: str, *args, async_run_tensor_parallel_workers_only: bool = False, - all_args: Optional[List[Tuple[Any, ...]]] = None, - all_kwargs: Optional[List[Dict[str, Any]]] = None, max_concurrent_workers: Optional[int] = None, **kwargs, ) -> Any: @@ -356,8 +425,6 @@ def _run_workers( It will also be run asynchronously and return a list of futures rather than blocking on the results. - args/kwargs: All workers share the same args/kwargs - - all_args/all_kwargs: args/kwargs for each worker are specified - individually """ if self.use_ray_spmd_worker: assert not async_run_tensor_parallel_workers_only, ( @@ -368,26 +435,13 @@ def _run_workers( raise NotImplementedError( "max_concurrent_workers is not supported yet.") - count = len(self.workers) if not \ - async_run_tensor_parallel_workers_only \ - else len(self.non_driver_workers) - # If using SPMD worker, all workers are the same, so we should execute - # the args on all workers. Otherwise, we skip the first worker's args - # because those args will go to the driver worker. - first_worker_args_index: int = 0 if self.use_ray_spmd_worker else 1 - all_worker_args = repeat(args, count) if all_args is None \ - else islice(all_args, first_worker_args_index, None) - all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \ - else islice(all_kwargs, first_worker_args_index, None) - # Start the ray workers first. ray_workers = self.workers if async_run_tensor_parallel_workers_only: ray_workers = self.non_driver_workers ray_worker_outputs = [ - worker.execute_method.remote(method, *worker_args, **worker_kwargs) - for (worker, worker_args, worker_kwargs - ) in zip(ray_workers, all_worker_args, all_worker_kwargs) + worker.execute_method.remote(method, *args, **kwargs) + for worker in ray_workers ] if async_run_tensor_parallel_workers_only: @@ -399,13 +453,9 @@ def _run_workers( # so we only explicitly execute on the driver worker if using a # non-SPMD worker class. if not self.use_ray_spmd_worker: - driver_args = args if all_args is None else all_args[0] - driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0] - # Start the driver worker after all the ray workers. 
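The replacement for _get_worker_kwargs shown above builds one kwargs dict per global rank: local_rank is the rank's position within its node's worker list, and is_driver_worker is again the rank % tensor_parallel_size == 0 test. A small sketch with an assumed two-node placement:

    from collections import defaultdict

    # Assumed placement: global ranks 0-3 spread over two nodes.
    worker_node_ids = ["node-a", "node-a", "node-b", "node-b"]
    tensor_parallel_size = 2

    node_workers = defaultdict(list)  # node id -> global ranks on that node
    for rank, node_id in enumerate(worker_node_ids):
        node_workers[node_id].append(rank)

    all_kwargs = []
    for rank, node_id in enumerate(worker_node_ids):
        all_kwargs.append(dict(
            rank=rank,
            local_rank=node_workers[node_id].index(rank),
            is_driver_worker=(rank % tensor_parallel_size == 0),
        ))

    print(all_kwargs[2])  # {'rank': 2, 'local_rank': 0, 'is_driver_worker': True}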
driver_worker_output = [ - self.driver_worker.execute_method(method, *driver_args, - **driver_kwargs) + self.driver_worker.execute_method(method, *args, **kwargs) ] # Get the results of the ray workers. @@ -467,11 +517,18 @@ def _compiled_ray_dag(self, enable_asyncio: bool): for pp_rank, tp_group in enumerate(self.pp_tp_workers): # Each PP worker takes in the output of the previous PP worker, # and the TP group executes in SPMD fashion. - outputs = [ - worker.execute_model_spmd. - bind( # type: ignore[attr-defined] - outputs[i]) for i, worker in enumerate(tp_group) - ] + if self.use_v1: + outputs = [ + worker.execute_model. + bind( # type: ignore[attr-defined] + outputs[i]) for i, worker in enumerate(tp_group) + ] + else: + outputs = [ + worker.execute_model_spmd. + bind( # type: ignore[attr-defined] + outputs[i]) for i, worker in enumerate(tp_group) + ] last_pp_rank = len(self.pp_tp_workers) - 1 if pp_rank < last_pp_rank: @@ -497,17 +554,6 @@ def _compiled_ray_dag(self, enable_asyncio: bool): def __del__(self): self.shutdown() - -class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync): - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.pp_locks: Optional[List[asyncio.Lock]] = None - self.use_ray_spmd_worker = envs.VLLM_USE_RAY_SPMD_WORKER - if not self.use_ray_compiled_dag: - self.driver_exec_method = make_async( - self.driver_worker.execute_method) - async def execute_model_async( self, execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: @@ -568,5 +614,7 @@ async def _start_worker_execution_loop(self): ] return await asyncio.gather(*coros) - def __del__(self): - self.shutdown() + def check_health(self) -> None: + # Assume that the Ray workers are healthy. + # TODO: check the health of the Ray workers + return diff --git a/vllm/executor/ray_hpu_executor.py b/vllm/executor/ray_hpu_executor.py deleted file mode 100644 index f3025cb537ab8..0000000000000 --- a/vllm/executor/ray_hpu_executor.py +++ /dev/null @@ -1,515 +0,0 @@ -import asyncio -import os -from collections import defaultdict -from itertools import islice, repeat -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple - -import msgspec - -import vllm.envs as envs -from vllm.executor.distributed_gpu_executor import ( # yapf: disable - DistributedGPUExecutor, DistributedGPUExecutorAsync) -from vllm.executor.msgspec_utils import encode_hook -from vllm.executor.ray_utils import RayWorkerWrapper, ray -from vllm.logger import init_logger -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.sequence import ExecuteModelRequest -from vllm.utils import (_run_task_with_lock, get_distributed_init_method, - get_ip, get_open_port, make_async) - -if ray is not None: - from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy - -if TYPE_CHECKING: - from ray.util.placement_group import PlacementGroup - -logger = init_logger(__name__) - - -class RayHPUExecutor(DistributedGPUExecutor): - - uses_ray: bool = True - - def _init_executor(self) -> None: - self.forward_dag: Optional[ray.dag.CompiledDAG] = None - # If the env var is set, it uses the Ray's compiled DAG API - # which optimizes the control plane overhead. - # Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it. - # Currently, this requires USE_RAY_SPMD_WORKER=True. - self.use_ray_compiled_dag = envs.VLLM_USE_RAY_COMPILED_DAG - # If the env var is set, then we do not distinguish between the - # "driver worker" vs other workers. 
Also, the rank 0 worker will - # be executed in a remote Ray worker. Currently this requires - # USE_RAY_COMPILED_DAG=True. - self.use_ray_spmd_worker = envs.VLLM_USE_RAY_SPMD_WORKER - if self.use_ray_compiled_dag: - assert self.use_ray_spmd_worker, ( - "VLLM_USE_RAY_COMPILED_DAG=1 requires " - "VLLM_USE_RAY_SPMD_WORKER=1") - if self.use_ray_spmd_worker: - # TODO: Support SPMD worker for non-DAG Ray executor. - assert self.use_ray_compiled_dag, ( - "VLLM_USE_RAY_SPMD_WORKER=1 requires " - "VLLM_USE_RAY_COMPILED_DAG=1") - - assert self.uses_ray - placement_group = self.parallel_config.placement_group - - # Disable Ray usage stats collection. - ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0") - if ray_usage != "1": - os.environ["RAY_USAGE_STATS_ENABLED"] = "0" - - # Create the parallel GPU workers. - self._init_workers_ray(placement_group) - - self.input_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook) - self.output_decoder = msgspec.msgpack.Decoder( - Optional[List[SamplerOutput]]) - - def shutdown(self) -> None: - if hasattr(self, "forward_dag") and self.forward_dag is not None: - self.forward_dag.teardown() - import ray - for worker in self.workers: - ray.kill(worker) - self.forward_dag = None - - def finish_measurements(self): - self._run_workers("finish_measurements") - - def _init_workers_ray(self, placement_group: "PlacementGroup", - **ray_remote_kwargs): - # Otherwise, the ray workers are allocated with a full GPU. - num_gpus = 1 - - # The driver dummy worker does not actually use any resources. - # It holds the resource for the driver worker. - self.driver_dummy_worker: Optional[RayWorkerWrapper] = None - # The remaining workers are the actual ray actors. - self.workers: List[RayWorkerWrapper] = [] - - # Used in ray compiled DAG: indexed first by PP rank, - # and then TP rank. In other words, the inner list is - # the TP group of workers for a PP rank. - self.pp_tp_workers: List[List[RayWorkerWrapper]] = [] - - logger.info("use_ray_spmd_worker: %s", self.use_ray_spmd_worker) - - # Create the workers. - driver_ip = get_ip() - for bundle_id, bundle in enumerate(placement_group.bundle_specs): - if not bundle.get("HPU", 0): - continue - scheduling_strategy = PlacementGroupSchedulingStrategy( - placement_group=placement_group, - placement_group_capture_child_tasks=True, - placement_group_bundle_index=bundle_id, - ) - - worker = ray.remote( - num_cpus=0, - num_gpus=0, - resources={'HPU': num_gpus}, - scheduling_strategy=scheduling_strategy, - **ray_remote_kwargs, - )(RayWorkerWrapper).remote(vllm_config=self.vllm_config) - - if self.use_ray_spmd_worker: - self.workers.append(worker) - else: - worker_ip = ray.get(worker.get_node_ip.remote()) - if worker_ip == driver_ip and self.driver_dummy_worker is None: - # If the worker is on the same node as the driver, we use it - # as the resource holder for the driver process. - self.driver_dummy_worker = worker - self.driver_worker = RayWorkerWrapper( - vllm_config=self.vllm_config) - else: - # Else, added to the list of workers. - self.workers.append(worker) - - logger.debug("workers: %s", self.workers) - logger.debug("driver_dummy_worker: %s", self.driver_dummy_worker) - if not self.use_ray_spmd_worker and self.driver_dummy_worker is None: - raise ValueError( - "Ray does not allocate any GPUs on the driver node. 
Consider " - "adjusting the Ray placement group or running the driver on a " - "GPU node.") - - worker_ips = [ - ray.get(worker.get_node_ip.remote()) # type: ignore[attr-defined] - for worker in self.workers - ] - ip_counts: Dict[str, int] = {} - for ip in worker_ips: - ip_counts[ip] = ip_counts.get(ip, 0) + 1 - - def sort_by_driver_then_worker_ip(worker): - """ - Sort the workers based on 3 properties: - 1. If the worker is on the same node as the driver (vllm engine), - it should be placed first. - 2. Then, if the worker is on a node with fewer workers, it should - be placed first. - 3. Finally, if the work is on a node with smaller IP address, it - should be placed first. - """ - ip = ray.get(worker.get_node_ip.remote()) - return (ip != driver_ip, ip_counts[ip], ip) - - # After sorting, the workers on the same node will be - # close to each other, and the workers on the driver - # node will be placed first. - self.workers = sorted(self.workers, key=sort_by_driver_then_worker_ip) - - worker_node_and_gpu_ids = [] - for worker in [self.driver_dummy_worker] + self.workers: - if worker is None: - # driver_dummy_worker can be None when using ray spmd worker. - continue - worker_node_and_gpu_ids.append( - ray.get(worker.get_node_and_gpu_ids.remote()) \ - ) # type: ignore - - node_workers = defaultdict(list) # node id -> list of worker ranks - node_gpus = defaultdict(list) # node id -> list of gpu ids - - for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids): - node_workers[node_id].append(i) - # `gpu_ids` can be a list of strings or integers. - # convert them to integers for consistency. - # NOTE: gpu_ids can be larger than 9 (e.g. 16 GPUs), - # string sorting is not sufficient. - # see https://github.com/vllm-project/vllm/issues/5590 - gpu_ids = [int(x) for x in gpu_ids] - node_gpus[node_id].extend(gpu_ids) - for node_id, gpu_ids in node_gpus.items(): - node_gpus[node_id] = sorted(gpu_ids) - - all_ips = set(worker_ips + [driver_ip]) - n_ips = len(all_ips) - n_nodes = len(node_workers) - - if n_nodes != n_ips: - raise RuntimeError( - f"Every node should have a unique IP address. Got {n_nodes}" - f" nodes with node ids {list(node_workers.keys())} and " - f"{n_ips} unique IP addresses {all_ips}. Please check your" - " network configuration. If you set `VLLM_HOST_IP` " - "environment variable, make sure it is unique for" - " each node.") - - # Set environment variables for the driver and workers. - all_args_to_update_environment_variables = [({ - "VLLM_TRACE_FUNCTION": - str(envs.VLLM_TRACE_FUNCTION), - }, ) for (node_id, _) in worker_node_and_gpu_ids] - self._run_workers("update_environment_variables", - all_args=all_args_to_update_environment_variables) - - if len(node_gpus) == 1: - # in single node case, we don't need to get the IP address. - # the loopback address is sufficient - # NOTE: a node may have several IP addresses, one for each - # network interface. `get_ip()` might return any of them, - # while they might not work for communication inside the node - # if the network setup is complicated. Using the loopback address - # solves this issue, as it always works for communication inside - # the node. - driver_ip = "127.0.0.1" - distributed_init_method = get_distributed_init_method( - driver_ip, get_open_port()) - - # Initialize the actual workers inside worker wrapper. 
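The comment above about GPU ids larger than 9 is worth spelling out: Ray can report device ids as strings, and lexicographic sorting puts "10" before "2", so the ids are converted to integers before sorting. For example:

    gpu_ids = ["0", "10", "2", "1"]            # ids as reported, possibly strings
    print(sorted(gpu_ids))                     # ['0', '1', '10', '2'] -- wrong order
    print(sorted(int(x) for x in gpu_ids))     # [0, 1, 2, 10]         -- intended order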
- init_worker_all_kwargs = [ - self._get_worker_kwargs( - local_rank=node_workers[node_id].index(rank), - rank=rank, - distributed_init_method=distributed_init_method, - ) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids) - ] - self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs) - - self._run_workers("init_device") - self._run_workers("load_model", - max_concurrent_workers=self.parallel_config. - max_parallel_loading_workers) - - if self.use_ray_spmd_worker: - for pp_rank in range(self.parallel_config.pipeline_parallel_size): - self.pp_tp_workers.append([]) - for tp_rank in range( - self.parallel_config.tensor_parallel_size): - # PP=2, TP=4 - # pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]] - rank = (pp_rank * self.parallel_config.tensor_parallel_size - ) + tp_rank - assert len(self.pp_tp_workers[pp_rank]) == tp_rank - assert pp_rank < len(self.pp_tp_workers) - self.pp_tp_workers[pp_rank].append(self.workers[rank]) - - # This is the list of workers that are rank 0 of each TP group EXCEPT - # global rank 0. These are the workers that will broadcast to the - # rest of the workers. - self.tp_driver_workers: List[RayWorkerWrapper] = [] - # This is the list of workers that are not drivers and not the first - # worker in a TP group. These are the workers that will be - # broadcasted to. - self.non_driver_workers: List[RayWorkerWrapper] = [] - - # Enforce rank order for correct rank to return final output. - for index, worker in enumerate(self.workers): - # The driver worker is rank 0 and not in self.workers. - rank = index + 1 - if rank % self.parallel_config.tensor_parallel_size == 0: - self.tp_driver_workers.append(worker) - else: - self.non_driver_workers.append(worker) - - def _driver_execute_model( - self, execute_model_req: Optional[ExecuteModelRequest] - ) -> Optional[List[SamplerOutput]]: - """Run execute_model in the driver worker. - - Passing None will cause the driver to stop the model execution - loop running in each of the remote workers. - """ - assert not self.use_ray_spmd_worker, ( - "driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1") - return self.driver_worker.execute_method("execute_model", - execute_model_req) - - def execute_model( - self, - execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: - if not self.use_ray_spmd_worker: - return super().execute_model(execute_model_req) - - if self.forward_dag is None: - self.forward_dag = self._compiled_ray_dag(enable_asyncio=False) - - serialized_data = self.input_encoder.encode(execute_model_req) - outputs = ray.get(self.forward_dag.execute(serialized_data)) - output = self.output_decoder.decode(outputs[0]) - return output - - def _run_workers( - self, - method: str, - *args, - async_run_tensor_parallel_workers_only: bool = False, - all_args: Optional[List[Tuple[Any, ...]]] = None, - all_kwargs: Optional[List[Dict[str, Any]]] = None, - max_concurrent_workers: Optional[int] = None, - **kwargs, - ) -> Any: - """Runs the given method on all workers. Can be used in the following - ways: - - Args: - - async_run_tensor_parallel_workers_only: If True the method will be - run only in the remote TP workers, not the driver worker. - It will also be run asynchronously and return a list of futures - rather than blocking on the results. 
- - args/kwargs: All workers share the same args/kwargs - - all_args/all_kwargs: args/kwargs for each worker are specified - individually - """ - if self.use_ray_spmd_worker: - assert not async_run_tensor_parallel_workers_only, ( - "async_run_tensor_parallel_workers_only is not supported for " - "spmd mode.") - - if max_concurrent_workers: - raise NotImplementedError( - "max_concurrent_workers is not supported yet.") - - count = len(self.workers) if not \ - async_run_tensor_parallel_workers_only \ - else len(self.non_driver_workers) - # If using SPMD worker, all workers are the same, so we should execute - # the args on all workers. Otherwise, we skip the first worker's args - # because those args will go to the driver worker. - first_worker_args_index: int = 0 if self.use_ray_spmd_worker else 1 - all_worker_args = repeat(args, count) if all_args is None \ - else islice(all_args, first_worker_args_index, None) - all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \ - else islice(all_kwargs, first_worker_args_index, None) - - # Start the ray workers first. - ray_workers = self.workers - if async_run_tensor_parallel_workers_only: - ray_workers = self.non_driver_workers - ray_worker_outputs = [ - worker.execute_method.remote(method, *worker_args, **worker_kwargs) - for (worker, worker_args, worker_kwargs - ) in zip(ray_workers, all_worker_args, all_worker_kwargs) - ] - - if async_run_tensor_parallel_workers_only: - # Just return futures - return ray_worker_outputs - - driver_worker_output = [] - # In SPMD mode, the driver worker is the same as any other worker, - # so we only explicitly execute on the driver worker if using a - # non-SPMD worker class. - if not self.use_ray_spmd_worker: - driver_args = args if all_args is None else all_args[0] - driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0] - - # Start the driver worker after all the ray workers. - driver_worker_output = [ - self.driver_worker.execute_method(method, *driver_args, - **driver_kwargs) - ] - - # Get the results of the ray workers. - if self.workers: - ray_worker_outputs = ray.get(ray_worker_outputs) - - return driver_worker_output + ray_worker_outputs - - def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None: - """Wait for futures returned from _run_workers() with - async_run_remote_workers_only to complete.""" - ray.get(parallel_worker_tasks) - - def _check_ray_adag_installation(self): - import pkg_resources - from packaging import version - - required_version = version.parse("2.35") - current_version = version.parse( - pkg_resources.get_distribution("ray").version) - # TODO: update the constraint once we adapt to the backward - # incompatible API change from ray 2.36 - if current_version != required_version: - raise ValueError(f"Ray version {required_version} is " - f"required, but found {current_version}") - - import importlib.util - adag_spec = importlib.util.find_spec( - "ray.experimental.compiled_dag_ref") - if adag_spec is None: - raise ValueError("Ray accelerated DAG is not installed. 
" - "Run `pip install ray[adag]` to install it.") - - def _compiled_ray_dag(self, enable_asyncio: bool): - assert self.parallel_config.use_ray - self._check_ray_adag_installation() - from ray.dag import InputNode, MultiOutputNode - from ray.experimental.channel.torch_tensor_type import TorchTensorType - - with InputNode() as input_data: - # Example DAG: PP=2, TP=4 - # (ExecuteModelReq, None) -> 0 -> (ExecuteModelReq, IntermediateOutput) -> 4 -> SamplerOutput # noqa: E501 - # -> 1 -> (ExecuteModelReq, IntermediateOutput) -> 5 -> SamplerOutput # noqa: E501 - # -> 2 -> (ExecuteModelReq, IntermediateOutput) -> 6 -> SamplerOutput # noqa: E501 - # -> 3 -> (ExecuteModelReq, IntermediateOutput) -> 7 -> SamplerOutput # noqa: E501 - - # All workers in the first TP group will take in the - # ExecuteModelRequest as input. - outputs = [input_data for _ in self.pp_tp_workers[0]] - for pp_rank, tp_group in enumerate(self.pp_tp_workers): - # Each PP worker takes in the output of the previous PP worker, - # and the TP group executes in SPMD fashion. - outputs = [ - worker.execute_model_spmd. - bind( # type: ignore[attr-defined] - outputs[i]) for i, worker in enumerate(tp_group) - ] - - last_pp_rank = len(self.pp_tp_workers) - 1 - if pp_rank < last_pp_rank: - # Specify how intermediate tensors should be passed - # between pp stages, no need to specify for the last - # pp stage. - transport = "auto" - outputs = [ - output.with_type_hint( - TorchTensorType(transport=transport)) - for output in outputs - ] - - forward_dag = MultiOutputNode(outputs) - - return forward_dag.experimental_compile(enable_asyncio=enable_asyncio) - - def __del__(self): - self.shutdown() - - -class RayHPUExecutorAsync(RayHPUExecutor, DistributedGPUExecutorAsync): - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.pp_locks: Optional[List[asyncio.Lock]] = None - self.use_ray_spmd_worker = envs.VLLM_USE_RAY_SPMD_WORKER - if not self.use_ray_compiled_dag: - self.driver_exec_method = make_async( - self.driver_worker.execute_method) - - async def execute_model_async( - self, - execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: - if not self.use_ray_spmd_worker: - return await super().execute_model_async(execute_model_req) - - if self.forward_dag is None: - self.forward_dag = self._compiled_ray_dag(enable_asyncio=True) - - serialized_data = self.input_encoder.encode(execute_model_req) - dag_future = await self.forward_dag.execute_async(serialized_data) - outputs = await dag_future - return self.output_decoder.decode(outputs[0]) - - async def _driver_execute_model_async( - self, - execute_model_req: Optional[ExecuteModelRequest] = None - ) -> List[SamplerOutput]: - assert not self.use_ray_spmd_worker, ( - "driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1") - if not self.tp_driver_workers: - return await self.driver_exec_method("execute_model", - execute_model_req) - if self.pp_locks is None: - # This locks each pipeline parallel stage so multiple virtual - # engines can't execute on the same stage at the same time - # We create the locks here to avoid creating them in the constructor - # which uses a different asyncio loop. 
- self.pp_locks = [ - asyncio.Lock() - for _ in range(self.parallel_config.pipeline_parallel_size) - ] - - tasks = [ - asyncio.create_task( - _run_task_with_lock(self.driver_exec_method, self.pp_locks[0], - "execute_model", execute_model_req)) - ] - for pp_rank, driver_worker in enumerate(self.tp_driver_workers, - start=1): - tasks.append( - asyncio.create_task( - _run_task_with_lock(driver_worker.execute_method.remote, - self.pp_locks[pp_rank], - "execute_model", execute_model_req))) - - results = await asyncio.gather(*tasks) - - # Only the last PP stage has the final results. - return results[-1] - - async def _start_worker_execution_loop(self): - assert not self.use_ray_spmd_worker, ( - "worker loop is disabled for VLLM_USE_RAY_SPMD_WORKER=1") - coros = [ - worker.execute_method.remote("start_worker_execution_loop") - for worker in self.non_driver_workers - ] - return await asyncio.gather(*coros) - - def __del__(self): - self.shutdown() diff --git a/vllm/executor/ray_tpu_executor.py b/vllm/executor/ray_tpu_executor.py deleted file mode 100644 index 5118c13934f0d..0000000000000 --- a/vllm/executor/ray_tpu_executor.py +++ /dev/null @@ -1,343 +0,0 @@ -import asyncio -import os -from collections import defaultdict -from itertools import islice, repeat -from typing import (TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple, - Union) - -import vllm.envs as envs -from vllm.executor.executor_base import ExecutorAsyncBase -from vllm.executor.ray_utils import RayWorkerWrapper, ray -from vllm.executor.tpu_executor import TPUExecutor -from vllm.logger import init_logger -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.sequence import ExecuteModelRequest -from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, - make_async) - -if ray is not None: - from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy - -if TYPE_CHECKING: - from ray.util.placement_group import PlacementGroup - -logger = init_logger(__name__) - - -class RayTPUExecutor(TPUExecutor): - - uses_ray: bool = True - - def __init__(self, *args, **kwargs): - # This is non-None when the execute model loop is running - # in the parallel workers. It's a coroutine in the AsyncLLMEngine case. - self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None - # Updated by implementations that require additional args to be passed - # to the _run_workers execute_model call - self.extra_execute_model_run_workers_kwargs: Dict[str, Any] = {} - - super().__init__(*args, **kwargs) - - def _init_executor(self) -> None: - assert self.parallel_config.distributed_executor_backend == "ray" - placement_group = self.parallel_config.placement_group - - # Disable Ray usage stats collection. - ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0") - if ray_usage != "1": - os.environ["RAY_USAGE_STATS_ENABLED"] = "0" - - # Create the parallel TPU workers. - self._init_workers_ray(placement_group) - - def _init_workers_ray(self, placement_group: "PlacementGroup", - **ray_remote_kwargs): - # The driver dummy worker does not actually use any resources. - # It holds the resource for the driver worker. - self.driver_dummy_worker: Optional[RayWorkerWrapper] = None - # The remaining workers are the actual ray actors. - self.workers: List[RayWorkerWrapper] = [] - - # Create the workers. 
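The pp_locks logic deleted above (and kept in the renamed Ray executor) serializes requests per pipeline stage: one asyncio.Lock per stage, each execute_model call wrapped so it runs under its stage's lock, and only the last stage's result is returned. A self-contained sketch; run_with_lock here is an assumed stand-in for vllm.utils._run_task_with_lock:

    import asyncio

    async def run_with_lock(task, lock, *args):
        # assumed stand-in for vllm.utils._run_task_with_lock
        async with lock:
            return await task(*args)

    async def main():
        pipeline_parallel_size = 2
        # One lock per pipeline stage, created inside the running event loop.
        pp_locks = [asyncio.Lock() for _ in range(pipeline_parallel_size)]

        async def stage(pp_rank, req):
            await asyncio.sleep(0.01)
            return f"stage {pp_rank} handled {req}"

        tasks = [
            asyncio.create_task(
                run_with_lock(stage, pp_locks[pp_rank], pp_rank, "req-0"))
            for pp_rank in range(pipeline_parallel_size)
        ]
        results = await asyncio.gather(*tasks)
        return results[-1]  # only the last pipeline stage holds the final output

    print(asyncio.run(main()))  # stage 1 handled req-0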
- driver_ip = get_ip() - for bundle_id, bundle in enumerate(placement_group.bundle_specs): - if not bundle.get("TPU", 0): - continue - scheduling_strategy = PlacementGroupSchedulingStrategy( - placement_group=placement_group, - placement_group_capture_child_tasks=True, - placement_group_bundle_index=bundle_id, - ) - - # GKE does not fetch environment information from metadata server - # and instead sets these from within the Ray process. Therefore we - # need to override the Ray environment variables manually. - override_env = {} - if "TPU_CHIPS_PER_HOST_BOUNDS" in os.environ: - override_env.update({ - "TPU_CHIPS_PER_HOST_BOUNDS": - os.environ["TPU_CHIPS_PER_HOST_BOUNDS"] - }) - if "TPU_HOST_BOUNDS" in os.environ: - override_env.update( - {"TPU_HOST_BOUNDS": os.environ["TPU_HOST_BOUNDS"]}) - - worker = ray.remote( - num_cpus=0, - resources={"TPU": 1}, - scheduling_strategy=scheduling_strategy, - **ray_remote_kwargs, - )(RayWorkerWrapper).remote(vllm_config=self.vllm_config) - if override_env: - worker.override_env_vars.remote(override_env) - - worker_ip = ray.get(worker.get_node_ip.remote()) - if worker_ip == driver_ip and self.driver_dummy_worker is None: - # If the worker is on the same node as the driver, we use it - # as the resource holder for the driver process. - self.driver_dummy_worker = worker - self.driver_worker = RayWorkerWrapper( - vllm_config=self.vllm_config) - else: - # Else, added to the list of workers. - self.workers.append(worker) - - logger.debug("workers: %s", self.workers) - logger.debug("driver_dummy_worker: %s", self.driver_dummy_worker) - if self.driver_dummy_worker is None: - raise ValueError( - "Ray does not allocate any TPUs on the driver node. Consider " - "adjusting the Ray placement group or running the driver on a " - "TPU node.") - - worker_ips = [ - ray.get(worker.get_node_ip.remote()) # type: ignore[attr-defined] - for worker in self.workers - ] - ip_counts: Dict[str, int] = {} - for ip in worker_ips: - ip_counts[ip] = ip_counts.get(ip, 0) + 1 - - def sort_by_driver_then_worker_ip(worker): - """ - Sort the workers based on 3 properties: - 1. If the worker is on the same node as the driver (vllm engine), - it should be placed first. - 2. Then, if the worker is on a node with fewer workers, it should - be placed first. - 3. Finally, if the work is on a node with smaller IP address, it - should be placed first. - """ - ip = ray.get(worker.get_node_ip.remote()) - return (ip != driver_ip, ip_counts[ip], ip) - - # After sorting, the workers on the same node will be - # close to each other, and the workers on the driver - # node will be placed first. - self.workers = sorted(self.workers, key=sort_by_driver_then_worker_ip) - - # Get the set of TPU IDs used on each node. - worker_node_and_gpu_ids = [] - for worker in [self.driver_dummy_worker] + self.workers: - if worker is None: - # driver_dummy_worker can be None when using ray spmd worker. - continue - worker_node_and_gpu_ids.append( - ray.get(worker.get_node_and_gpu_ids.remote()) \ - ) # type: ignore - - node_workers = defaultdict(list) - for i, (node_id, _) in enumerate(worker_node_and_gpu_ids): - node_workers[node_id].append(i) - - # Set environment variables for the driver and workers. 
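Both the deleted TPU-specific override above and the retained carry-over loop earlier in this diff follow the same idea: copy a short allow-list of environment variables from the driver process into each worker's environment, alongside the per-worker device assignment. A rough sketch (variable names taken from the diff; the values are made up):

    import os

    # Made-up driver-side setting, only so the sketch prints something.
    os.environ.setdefault("VLLM_TRACE_FUNCTION", "0")

    carry_over = [
        "VLLM_ATTENTION_BACKEND",
        "TPU_CHIPS_PER_HOST_BOUNDS",
        "TPU_HOST_BOUNDS",
        "VLLM_USE_V1",
        "VLLM_TRACE_FUNCTION",
    ]

    worker_env = {"CUDA_VISIBLE_DEVICES": "0,1"}  # per-worker device assignment
    for name in carry_over:
        if name in os.environ:
            worker_env[name] = os.environ[name]

    print(worker_env)  # device vars plus whichever carry-over vars are set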
- all_args_to_update_environment_variables = [({ - "VLLM_TRACE_FUNCTION": - str(envs.VLLM_TRACE_FUNCTION), - }, ) for _ in worker_node_and_gpu_ids] - self._run_workers("update_environment_variables", - all_args=all_args_to_update_environment_variables) - - if len(node_workers) == 1: - # in single node case, we don't need to get the IP address. - # the loopback address is sufficient - # NOTE: a node may have several IP addresses, one for each - # network interface. `get_ip()` might return any of them, - # while they might not work for communication inside the node - # if the network setup is complicated. Using the loopback address - # solves this issue, as it always works for communication inside - # the node. - driver_ip = "127.0.0.1" - distributed_init_method = get_distributed_init_method( - driver_ip, get_open_port()) - - # Initialize the actual workers inside worker wrapper. - init_worker_all_kwargs = [ - self._get_worker_kwargs( - local_rank=node_workers[node_id].index(rank), - rank=rank, - distributed_init_method=distributed_init_method, - ) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids) - ] - self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs) - - self._run_workers("init_device") - self._run_workers("load_model", - max_concurrent_workers=self.parallel_config. - max_parallel_loading_workers) - - def _driver_execute_model( - self, - execute_model_req: Optional[ExecuteModelRequest] = None - ) -> List[SamplerOutput]: - """Run execute_model in the driver worker. - - Passing None will cause the driver to stop the model execution - loop running in each of the remote workers. - """ - return self.driver_worker.execute_method("execute_model", - execute_model_req) - - def _run_workers( - self, - method: str, - *args, - async_run_remote_workers_only: bool = False, - all_args: Optional[List[Tuple[Any, ...]]] = None, - all_kwargs: Optional[List[Dict[str, Any]]] = None, - max_concurrent_workers: Optional[int] = None, - use_ray_compiled_dag: bool = False, - **kwargs, - ) -> Any: - """Runs the given method on all workers. Can be used in the following - ways: - - - async_run_remote_workers_only: If True the method will be run only - in the remote workers, not the driver worker. It will also be - run asynchronously and return a list of futures rather than blocking - on the results. - - args/kwargs: All workers share the same args/kwargs - - all_args/all_kwargs: args/kwargs for each worker are specified - individually - """ - - if max_concurrent_workers: - raise NotImplementedError( - "max_concurrent_workers is not supported yet.") - - count = len(self.workers) - all_worker_args = repeat(args, count) if all_args is None \ - else islice(all_args, 1, None) - all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \ - else islice(all_kwargs, 1, None) - - # Start the ray workers first. - ray_worker_outputs = [ - worker.execute_method.remote(method, *worker_args, **worker_kwargs) - for (worker, worker_args, worker_kwargs - ) in zip(self.workers, all_worker_args, all_worker_kwargs) - ] - - if async_run_remote_workers_only: - # Just return futures - return ray_worker_outputs - - driver_args = args if all_args is None else all_args[0] - driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0] - - # Start the driver worker after all the ray workers. - driver_worker_output = self.driver_worker.execute_method( - method, *driver_args, **driver_kwargs) - # Get the results of the ray workers. 
- if self.workers: - ray_worker_outputs = ray.get(ray_worker_outputs) - - return [driver_worker_output] + ray_worker_outputs - - def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None: - """Wait for futures returned from _run_workers() with - async_run_remote_workers_only to complete.""" - ray.get(parallel_worker_tasks) - - def determine_num_available_blocks(self) -> Tuple[int, int]: - num_blocks = self._run_workers("determine_num_available_blocks", ) - num_tpu_blocks = min(b[0] for b in num_blocks) - num_cpu_blocks = min(b[1] for b in num_blocks) - return num_tpu_blocks, num_cpu_blocks - - def initialize_cache(self, num_gpu_blocks: int, - num_cpu_blocks: int) -> None: - logger.info("# TPU blocks: %d, # CPU blocks: %d", num_gpu_blocks, - num_cpu_blocks) - self.cache_config.num_gpu_blocks = num_gpu_blocks - self.cache_config.num_cpu_blocks = num_cpu_blocks - self._run_workers("initialize_cache", - num_gpu_blocks=num_gpu_blocks, - num_cpu_blocks=num_cpu_blocks) - - def execute_model( - self, - execute_model_req: ExecuteModelRequest, - ) -> List[SamplerOutput]: - if self.parallel_worker_tasks is None: - self.parallel_worker_tasks = self._run_workers( - "start_worker_execution_loop", - async_run_remote_workers_only=True, - **self.extra_execute_model_run_workers_kwargs) - - # Only the driver worker returns the sampling results. - return self._driver_execute_model(execute_model_req) - - def stop_remote_worker_execution_loop(self) -> None: - if self.parallel_worker_tasks is None: - return - - self._driver_execute_model() - parallel_worker_tasks = self.parallel_worker_tasks - self.parallel_worker_tasks = None - # Ensure that workers exit model loop cleanly - # (this will raise otherwise) - self._wait_for_tasks_completion(parallel_worker_tasks) - - -class RayTPUExecutorAsync(RayTPUExecutor, ExecutorAsyncBase): - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.driver_exec_method = make_async(self.driver_worker.execute_method) - - async def execute_model_async( - self, - execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: - if self.parallel_worker_tasks is None: - # Start model execution loop running in the parallel workers - self.parallel_worker_tasks = asyncio.create_task( - self._start_worker_execution_loop()) - - # Only the driver worker returns the sampling results. 
- return await self._driver_execute_model_async(execute_model_req) - - async def stop_remote_worker_execution_loop_async(self) -> None: - if self.parallel_worker_tasks is None: - return - - await self._driver_execute_model_async() - parallel_worker_tasks = self.parallel_worker_tasks - self.parallel_worker_tasks = None - # Ensure that workers exit model loop cleanly - # (this will raise otherwise) - await parallel_worker_tasks - - async def _driver_execute_model_async( - self, - execute_model_req: Optional[ExecuteModelRequest] = None - ) -> List[SamplerOutput]: - return await self.driver_exec_method("execute_model", - execute_model_req) - - async def _start_worker_execution_loop(self): - coros = [ - worker.execute_method.remote("start_worker_execution_loop") - for worker in self.workers - ] - return await asyncio.gather(*coros) diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py index 9f40f6a65dcd7..e55155ea06225 100644 --- a/vllm/executor/ray_utils.py +++ b/vllm/executor/ray_utils.py @@ -1,7 +1,7 @@ import os import time from collections import defaultdict -from typing import Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import msgspec @@ -13,6 +13,10 @@ from vllm.utils import get_ip from vllm.worker.worker_base import WorkerWrapperBase +if TYPE_CHECKING: + from vllm.v1.core.scheduler import SchedulerOutput + from vllm.v1.outputs import ModelRunnerOutput + logger = init_logger(__name__) PG_WAIT_TIMEOUT = 1800 @@ -95,6 +99,26 @@ def execute_model_spmd( return output + def setup_device_if_necessary(self): + # TODO(swang): This is needed right now because Ray CG executes + # on a background thread, so we need to reset torch's current + # device. + # We can remove this API after it is fixed in compiled graph. + import torch + assert self.worker is not None, "Worker is not initialized" + if not self.compiled_dag_cuda_device_set: + torch.cuda.set_device(self.worker.device) + self.compiled_dag_cuda_device_set = True + + def execute_model( + self, + scheduler_output: "SchedulerOutput", + ) -> "ModelRunnerOutput": + self.setup_device_if_necessary() + assert self.worker is not None, "Worker is not initialized" + output = self.worker.model_runner.execute_model(scheduler_output) + return output + def override_env_vars(self, vars: Dict[str, str]): os.environ.update(vars) diff --git a/vllm/executor/ray_xpu_executor.py b/vllm/executor/ray_xpu_executor.py deleted file mode 100644 index d2086f5fef26c..0000000000000 --- a/vllm/executor/ray_xpu_executor.py +++ /dev/null @@ -1,40 +0,0 @@ -import asyncio -from typing import List, Optional - -import ray - -import vllm.envs as envs -from vllm.executor.ray_gpu_executor import RayGPUExecutor, RayGPUExecutorAsync -from vllm.executor.xpu_executor import XPUExecutor -from vllm.logger import init_logger -from vllm.utils import make_async - -logger = init_logger(__name__) - - -class RayXPUExecutor(RayGPUExecutor, XPUExecutor): - - def _get_env_vars_to_be_updated(self): - # Get the set of GPU IDs used on each node. - worker_node_and_gpu_ids = [] - for worker in [self.driver_dummy_worker] + self.workers: - if worker is None: - # driver_dummy_worker can be None when using ray spmd worker. - continue - worker_node_and_gpu_ids.append( - ray.get(worker.get_node_and_gpu_ids.remote())) # type: ignore - - # Set environment variables for the driver and workers. 
- all_args_to_update_environment_variables = [({ - "VLLM_TRACE_FUNCTION": - str(envs.VLLM_TRACE_FUNCTION), - }, ) for (_, _) in worker_node_and_gpu_ids] - return all_args_to_update_environment_variables - - -class RayXPUExecutorAsync(RayXPUExecutor, RayGPUExecutorAsync): - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.driver_exec_method = make_async(self.driver_worker.execute_method) - self.pp_locks: Optional[List[asyncio.Lock]] = None diff --git a/vllm/executor/tpu_executor.py b/vllm/executor/tpu_executor.py deleted file mode 100644 index e37e8973790db..0000000000000 --- a/vllm/executor/tpu_executor.py +++ /dev/null @@ -1,142 +0,0 @@ -from typing import Any, Dict, List, Optional, Set, Tuple - -import torch - -from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase -from vllm.logger import init_logger -from vllm.lora.request import LoRARequest -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.sequence import ExecuteModelRequest -from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, - make_async) - -logger = init_logger(__name__) - - -class TPUExecutor(ExecutorBase): - - uses_ray: bool = False - - def _init_executor(self) -> None: - assert not self.scheduler_config.chunked_prefill_enabled, ( - "Chunked prefill is not yet supported for TPU backend") - assert not self.speculative_config, ( - "Speculative decoding is not yet supported for TPU backend") - if self.model_config.dtype in (torch.float16, torch.float32): - logger.warning( - "The TPU backend currently does not support %s. " - "Using bfloat16 instead.", self.model_config.dtype) - self.model_config.dtype = torch.bfloat16 - - # Instantiate the worker and load the model to the device. - self.driver_worker = self._create_worker() - self.driver_worker.init_device() - self.driver_worker.load_model() - - def _get_worker_kwargs( - self, - local_rank: int = 0, - rank: int = 0, - distributed_init_method: Optional[str] = None, - ) -> Dict[str, Any]: - """Return worker init args for a given rank.""" - if distributed_init_method is None: - distributed_init_method = get_distributed_init_method( - get_ip(), get_open_port()) - return dict( - vllm_config=self.vllm_config, - local_rank=local_rank, - rank=rank, - distributed_init_method=distributed_init_method, - is_driver_worker=rank == 0, - ) - - def _create_worker( - self, - local_rank: int = 0, - rank: int = 0, - distributed_init_method: Optional[str] = None, - ): - if self.scheduler_config.is_multi_step: - from vllm.worker.multi_step_tpu_worker import MultiStepTPUWorker - worker = MultiStepTPUWorker(**self._get_worker_kwargs( - local_rank, rank, distributed_init_method)) - return worker - else: - from vllm.worker.tpu_worker import TPUWorker - - worker = TPUWorker(**self._get_worker_kwargs( - local_rank, rank, distributed_init_method)) - return worker - - def initialize_cache( - self, - num_gpu_blocks: int, - num_cpu_blocks: int, - ) -> None: - """Initialize the KV cache by invoking the underlying worker.""" - # NOTE: This is logged in the executor because there can be >1 worker - # with other executors. We could log in the engine level, but work - # remains to abstract away the device for non-GPU configurations. 
- logger.info("# TPU blocks: %d, # CPU blocks: %d", num_gpu_blocks, - num_cpu_blocks) - self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) - - def determine_num_available_blocks(self) -> Tuple[int, int]: - """Determine the number of available KV blocks by invoking the - underlying worker.""" - return self.driver_worker.determine_num_available_blocks() - - def execute_model( - self, - execute_model_req: ExecuteModelRequest, - ) -> List[SamplerOutput]: - output = self.driver_worker.execute_model(execute_model_req) - return output - - def add_lora(self, lora_request: LoRARequest) -> bool: - raise NotImplementedError( - "LoRA is currently not supported by the TPU backend.") - - def remove_lora(self, lora_id: int) -> bool: - raise NotImplementedError( - "LoRA is currently not supported by the TPU backend.") - - def pin_lora(self, lora_id: int) -> bool: - raise NotImplementedError( - "LoRA is currently not supported by the TPU backend.") - - def list_loras(self) -> Set[int]: - raise NotImplementedError( - "LoRA is currently not supported by the TPU backend.") - - def add_prompt_adapter(self, prompt_adapter_request) -> bool: - raise NotImplementedError( - "Soft prompt is currently not supported by the TPU backend.") - - def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: - raise NotImplementedError( - "Soft prompt is currently not supported by the TPU backend.") - - def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: - raise NotImplementedError( - "Soft prompt is currently not supported by the TPU backend.") - - def list_prompt_adapters(self) -> Set[int]: - raise NotImplementedError( - "Soft prompt is currently not supported by the TPU backend.") - - def check_health(self) -> None: - # TPUExecutor will always be healthy as long as it's running. - return - - -class TPUExecutorAsync(TPUExecutor, ExecutorAsyncBase): - - async def execute_model_async( - self, - sexecute_model_req: ExecuteModelRequest, - ) -> SamplerOutput: - output = await make_async(self.driver_worker.execute_model - )(sexecute_model_req) - return output diff --git a/vllm/executor/uniproc_executor.py b/vllm/executor/uniproc_executor.py new file mode 100644 index 0000000000000..da1d77343cf3b --- /dev/null +++ b/vllm/executor/uniproc_executor.py @@ -0,0 +1,57 @@ +from typing import Any, Dict, List, Optional, Tuple + +from vllm.executor.executor_base import ExecutorBase +from vllm.logger import init_logger +from vllm.utils import get_distributed_init_method, get_ip, get_open_port +from vllm.worker.worker_base import WorkerWrapperBase + +logger = init_logger(__name__) + + +class UniProcExecutor(ExecutorBase): + + uses_ray: bool = False + + def _init_executor(self) -> None: + """Initialize the worker and load the model. 
+ """ + self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config, + rank=0) + distributed_init_method = get_distributed_init_method( + get_ip(), get_open_port()) + local_rank = 0 + rank = 0 + kwargs = dict( + vllm_config=self.vllm_config, + local_rank=local_rank, + rank=rank, + distributed_init_method=distributed_init_method, + is_driver_worker=(not self.parallel_config) + or (rank % self.parallel_config.tensor_parallel_size == 0), + ) + self.collective_rpc("init_worker", args=([kwargs], )) + self.collective_rpc("init_device") + self.collective_rpc("load_model") + + def collective_rpc(self, + method: str, + timeout: Optional[float] = None, + args: Tuple = (), + kwargs: Optional[Dict] = None) -> List[Any]: + if kwargs is None: + kwargs = {} + try: + func = getattr(self.driver_worker, method) + except AttributeError: + raise NotImplementedError(f"Method {method} is not implemented.") \ + from None + answer = func(*args, **kwargs) + return [answer] + + def check_health(self) -> None: + # UniProcExecutor will always be healthy as long as + # it's running. + return + + +UniProcExecutorAsync = UniProcExecutor diff --git a/vllm/executor/xpu_executor.py b/vllm/executor/xpu_executor.py deleted file mode 100644 index 722b86a95ff8a..0000000000000 --- a/vllm/executor/xpu_executor.py +++ /dev/null @@ -1,39 +0,0 @@ -from typing import List, Optional, Union - -from vllm.executor.executor_base import ExecutorAsyncBase -from vllm.executor.gpu_executor import GPUExecutor -from vllm.logger import init_logger -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.sequence import ExecuteModelRequest, PoolerOutput -from vllm.utils import make_async - -logger = init_logger(__name__) - - -class XPUExecutor(GPUExecutor): - - uses_ray: bool = False - - def _init_executor(self) -> None: - assert self.device_config.device_type == "xpu" - assert self.speculative_config is None, ( - "Speculative decoding not yet supported for XPU backend") - - GPUExecutor._init_executor(self) - - def execute_model( - self, execute_model_req: ExecuteModelRequest - ) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]: - output = self.driver_worker.execute_model(execute_model_req) - return output - - -class XPUExecutorAsync(XPUExecutor, ExecutorAsyncBase): - - async def execute_model_async( - self, - execute_model_req: ExecuteModelRequest, - ) -> List[SamplerOutput]: - output = await make_async(self.driver_worker.execute_model - )(execute_model_req=execute_model_req) - return output diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index 4d3b84fea887f..74948202cbe48 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -1,3 +1,4 @@ +import os from typing import TYPE_CHECKING, Optional import psutil @@ -105,6 +106,32 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: else: parallel_config.worker_cls = "vllm.worker.cpu_worker.CPUWorker" + assert vllm_config.device_config.device_type == "cpu" + + # + # Environment variables for CPU executor + # + + # Disable torch async compiling which won't work with daemonic processes + os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1" + + # Intel OpenMP setting + ld_prealod_str = os.getenv("LD_PRELOAD", "") + if "libiomp5.so" in ld_prealod_str: + # The time(milliseconds) that a thread should wait after + # completing the execution of a parallel region, before sleeping. 
+            os.environ['KMP_BLOCKTIME'] = "1"
+            # Prevents the CPU from dropping into a low-performance state
+            os.environ['KMP_TPAUSE'] = "0"
+            # Provides fine-grained parallelism
+            os.environ['KMP_FORKJOIN_BARRIER_PATTERN'] = "dist,dist"
+            os.environ['KMP_PLAIN_BARRIER_PATTERN'] = "dist,dist"
+            os.environ['KMP_REDUCTION_BARRIER_PATTERN'] = "dist,dist"
+
+        # Hint IPEX to use its shared-memory-based AllReduce
+        os.environ["LOCAL_WORLD_SIZE"] = str(
+            vllm_config.parallel_config.tensor_parallel_size)
+
     @classmethod
     def is_pin_memory_available(cls) -> bool:
         logger.warning("Pin memory is not supported on CPU.")
diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index 2587e3a11dde3..8350177b68ade 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -139,6 +139,28 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         else:
             parallel_config.worker_cls = "vllm.worker.worker.Worker"
 
+        world_size = parallel_config.world_size
+        tensor_parallel_size = parallel_config.tensor_parallel_size
+
+        from vllm.utils import (cuda_device_count_stateless,
+                                update_environment_variables)
+
+        # Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers
+        if "CUDA_VISIBLE_DEVICES" not in os.environ:
+            update_environment_variables({
+                "CUDA_VISIBLE_DEVICES": (",".join(map(str, range(world_size))))
+            })
+
+        cuda_device_count = cuda_device_count_stateless()
+        # Check TP first so the common TP-only case gets a clearer message.
+        assert tensor_parallel_size <= cuda_device_count, (
+            f"please set tensor_parallel_size ({tensor_parallel_size}) "
+            f"to at most the local gpu count ({cuda_device_count})")
+
+        assert world_size <= cuda_device_count, (
+            f"please ensure that world_size ({world_size}) "
+            f"does not exceed the local gpu count ({cuda_device_count})")
+
         cache_config = vllm_config.cache_config
         if cache_config and cache_config.block_size is None:
             cache_config.block_size = 16
diff --git a/vllm/platforms/neuron.py b/vllm/platforms/neuron.py
index 0696f73cc17b4..ead3dab05a6b1 100644
--- a/vllm/platforms/neuron.py
+++ b/vllm/platforms/neuron.py
@@ -35,6 +35,14 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
             parallel_config.worker_cls = \
                 "vllm.worker.neuron_worker.NeuronWorker"
 
+        if parallel_config.world_size > 1:
+            parallel_config.distributed_executor_backend = "uni"
+
+        assert (vllm_config.lora_config is
+                None), "LoRA is not supported for Neuron backend."
+        assert (not vllm_config.speculative_config
+                ), "Speculative decoding not yet supported for Neuron backend."
+
         cache_config = vllm_config.cache_config
         if cache_config:
             # neuron needs block_size = max_model_len
diff --git a/vllm/platforms/openvino.py b/vllm/platforms/openvino.py
index 9390eda535c8f..7d414165a8188 100644
--- a/vllm/platforms/openvino.py
+++ b/vllm/platforms/openvino.py
@@ -66,9 +66,8 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         from vllm.utils import GiB_bytes
 
         parallel_config = vllm_config.parallel_config
-        assert (
-            parallel_config.world_size == 1
-        ), "OpenVINOExecutor only supports single CPU socket currently."
+        assert (parallel_config.world_size == 1
+                ), "OpenVINO only supports single CPU socket currently."
 
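# A minimal standalone sketch (not part of the patch) of the local-GPU guard
# that cuda.py's check_and_update_config now performs in the hunk above:
# default CUDA_VISIBLE_DEVICES to the first `world_size` devices, then verify
# that the tensor-parallel size and the world size both fit on this host.
# The function name and ValueError messages are illustrative; only the
# comparisons mirror the asserts in the hunk.
import os


def ensure_local_gpu_budget(world_size: int, tensor_parallel_size: int,
                            cuda_device_count: int) -> None:
    # Workers spawned later inherit a consistent device set from the driver.
    if "CUDA_VISIBLE_DEVICES" not in os.environ:
        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
            str(i) for i in range(world_size))
    # TP is checked first so the common TP-only misconfiguration is reported
    # in terms of tensor_parallel_size rather than world_size.
    if tensor_parallel_size > cuda_device_count:
        raise ValueError(f"tensor_parallel_size ({tensor_parallel_size}) "
                         f"exceeds the local GPU count ({cuda_device_count})")
    if world_size > cuda_device_count:
        raise ValueError(f"world_size ({world_size}) exceeds the local "
                         f"GPU count ({cuda_device_count})")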
if parallel_config.worker_cls == "auto": parallel_config.worker_cls = \ @@ -141,3 +140,10 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: raise RuntimeError( "Invalid environment variable VLLM_OPENVINO_KVCACHE_SPACE" f" {kv_cache_space}, expect a positive integer value.") + + assert vllm_config.device_config.device_type == "openvino" + assert vllm_config.lora_config is None, \ + "OpenVINO backend doesn't support LoRA" + assert cls.is_openvino_cpu() or \ + cls.is_openvino_gpu(), \ + "OpenVINO backend supports only CPU and GPU devices" diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index ff9487daac7a7..05a3aa4305cfa 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -72,6 +72,16 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: assert vllm_config.speculative_config is None, \ "TPU does not support speculative decoding" + assert not vllm_config.scheduler_config.chunked_prefill_enabled, ( + "Chunked prefill is not yet supported for TPU backend") + assert not vllm_config.speculative_config, ( + "Speculative decoding is not yet supported for TPU backend") + if vllm_config.model_config.dtype in (torch.float16, torch.float32): + logger.warning( + "The TPU backend currently does not support %s. " + "Using bfloat16 instead.", vllm_config.model_config.dtype) + vllm_config.model_config.dtype = torch.bfloat16 + parallel_config = vllm_config.parallel_config scheduler_config = vllm_config.scheduler_config if parallel_config.worker_cls == "auto": diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index 031abdc05d517..c34b5b58672e7 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -78,17 +78,31 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: raise NotImplementedError( "XPU does not support speculative decoding") + if vllm_config.device_config is not None: + assert vllm_config.device_config.device_type == "xpu" + # check and update parallel config parallel_config = vllm_config.parallel_config - if (parallel_config.distributed_executor_backend is not None - and parallel_config.distributed_executor_backend != "ray"): + if parallel_config.worker_cls == "auto": + parallel_config.worker_cls = "vllm.worker.xpu_worker.XPUWorker" + + if parallel_config.distributed_executor_backend is None: + parallel_config.distributed_executor_backend = "ray" + elif parallel_config.distributed_executor_backend == "mp": + # FIXME(kunshang): + # spawn needs calling `if __name__ == '__main__':`` + # fork is not supported for xpu start new process. 
+ logger.error( + "Both start methods (spawn and fork) have issue " + "on XPU if you use mp backend, setting it to ray instead.") + parallel_config.distributed_executor_backend = "ray" + + elif parallel_config.distributed_executor_backend != "ray": logger.warning( "%s is not supported on XPU, fallback to ray distributed" " executor backend.", parallel_config.distributed_executor_backend) parallel_config.distributed_executor_backend = "ray" - if parallel_config.worker_cls == "auto": - parallel_config.worker_cls = "vllm.worker.xpu_worker.XPUWorker" @classmethod def is_pin_memory_available(cls): diff --git a/vllm/spec_decode/medusa_worker.py b/vllm/spec_decode/medusa_worker.py index 1ab691a7ef047..21a58fc426275 100644 --- a/vllm/spec_decode/medusa_worker.py +++ b/vllm/spec_decode/medusa_worker.py @@ -9,17 +9,15 @@ from vllm.spec_decode.interfaces import SpeculativeProposals from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase from vllm.spec_decode.top1_proposer import Top1Proposer -from vllm.worker.worker_base import WorkerWrapperBase +from vllm.worker.worker_base import DelegateWorkerBase -class MedusaWorker(NonLLMProposerWorkerBase, WorkerWrapperBase): +class MedusaWorker(NonLLMProposerWorkerBase, DelegateWorkerBase): """Worker for Medusa. """ def __init__(self, *args, **kwargs): - super().__init__(kwargs.get("vllm_config")) - self.init_worker(*args, **kwargs) - + DelegateWorkerBase.__init__(self, *args, **kwargs) # Lazy initialization list. self._proposer: Top1Proposer diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py index 676ac5eb3609d..32197f8cc8f2f 100644 --- a/vllm/spec_decode/multi_step_worker.py +++ b/vllm/spec_decode/multi_step_worker.py @@ -16,10 +16,10 @@ SpeculativeProposer) from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase from vllm.spec_decode.top1_proposer import Top1Proposer -from vllm.worker.worker_base import WorkerWrapperBase +from vllm.worker.worker_base import DelegateWorkerBase -class MultiStepWorker(ProposerWorkerBase, WorkerWrapperBase): +class MultiStepWorker(ProposerWorkerBase, DelegateWorkerBase): """The MultiStepWorker is equivalent to a Worker except that it allows multiple forward passes in a single call, assuming the scheduler has allocated enough space to store the additional KV. This reduces overhead @@ -32,15 +32,12 @@ class MultiStepWorker(ProposerWorkerBase, WorkerWrapperBase): """ def __init__(self, *args, **kwargs): - super().__init__(kwargs.get("vllm_config")) - self.init_worker(*args, **kwargs) - + DelegateWorkerBase.__init__(self, *args, **kwargs) # Lazy initialization list. 
self._proposer: SpeculativeProposer def init_device(self) -> None: self.worker.init_device() - self._proposer = Top1Proposer( weakref.proxy(self), # type: ignore[arg-type] self.device, @@ -56,18 +53,6 @@ def set_should_modify_greedy_probs_inplace(self) -> None: self.model_runner.model.sampler.should_modify_greedy_probs_inplace = ( True) - def determine_num_available_blocks(self) -> Tuple[int, int]: - return self.worker.determine_num_available_blocks() - - def get_cache_block_size_bytes(self) -> int: - return self.worker.get_cache_block_size_bytes() - - def initialize_cache(self, *args, **kwargs) -> None: - self.worker.initialize_cache(*args, **kwargs) - - def execute_model(self, *args, **kwargs) -> List[SamplerOutput]: - return self.worker.execute_model(*args, **kwargs) - @torch.inference_mode() def sampler_output( self, diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index e369da1a70c23..540d118d65ecb 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -40,8 +40,8 @@ get_all_num_logprobs, get_sampled_token_logprobs, nvtx_range, split_batch_by_proposal_len) -from vllm.worker.worker_base import (LoraNotSupportedWorkerBase, WorkerBase, - WorkerWrapperBase) +from vllm.utils import resolve_obj_by_qualname +from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase logger = init_logger(__name__) @@ -64,8 +64,9 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker": target_worker_config = copy.deepcopy(vllm_config) target_worker_config.parallel_config.worker_cls =\ target_worker_config.parallel_config.sd_worker_cls - target_worker = WorkerWrapperBase(vllm_config=target_worker_config) - target_worker.init_worker(*args, **kwargs) + cls = resolve_obj_by_qualname( + target_worker_config.parallel_config.worker_cls) + target_worker = cls(*args, **kwargs) # Set the disable_logprobs variable in the TargetModelRunner instance # as per its value specified in the SpeculativeConfig. 
target_worker.model_runner.disable_logprobs =\ diff --git a/vllm/v1/executor/abstract.py b/vllm/v1/executor/abstract.py index 5d74d4b01f500..7c17f60510ae1 100644 --- a/vllm/v1/executor/abstract.py +++ b/vllm/v1/executor/abstract.py @@ -14,8 +14,9 @@ def get_class(vllm_config: VllmConfig) -> Type["Executor"]: distributed_executor_backend = ( vllm_config.parallel_config.distributed_executor_backend) if distributed_executor_backend == "ray": - from vllm.v1.executor.ray_executor import RayExecutor - executor_class = RayExecutor + from vllm.executor.ray_distributed_executor import ( # noqa + RayDistributedExecutor) + executor_class = RayDistributedExecutor elif distributed_executor_backend == "mp": from vllm.v1.executor.multiproc_executor import MultiprocExecutor executor_class = MultiprocExecutor diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index 41e6abbd67956..cee0fcc0bad68 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -246,9 +246,18 @@ def __init__( ready_path: str, ): self.rank = rank - wrapper = WorkerWrapperBase(vllm_config=vllm_config) - wrapper.init_worker(vllm_config, local_rank, rank, - distributed_init_method) + wrapper = WorkerWrapperBase(vllm_config=vllm_config, rank=rank) + # TODO: move `init_worker` to executor level as a collective rpc call + all_kwargs: List[Dict] = [ + {} for _ in range(vllm_config.parallel_config.world_size) + ] + all_kwargs[rank] = { + "vllm_config": vllm_config, + "local_rank": local_rank, + "rank": rank, + "distributed_init_method": distributed_init_method, + } + wrapper.init_worker(all_kwargs) self.worker = wrapper.worker pid = os.getpid() @@ -270,7 +279,7 @@ def __init__( ready_socket.send_string(WorkerProc.READY_STR) ready_socket.send(payload) - self.worker.initialize() + self.worker.init_device() self.worker.load_model() @staticmethod diff --git a/vllm/v1/executor/uniproc_executor.py b/vllm/v1/executor/uniproc_executor.py index be058318de58b..a9adc0114f76d 100644 --- a/vllm/v1/executor/uniproc_executor.py +++ b/vllm/v1/executor/uniproc_executor.py @@ -27,7 +27,7 @@ def __init__(self, vllm_config: VllmConfig) -> None: self.observability_config = vllm_config.observability_config self.worker: Worker = self._create_worker() - self.worker.initialize() + self.worker.init_device() self.worker.load_model() def _create_worker( diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index e83bce4283555..e6feaee972a35 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -33,6 +33,7 @@ def __init__( local_rank: int, rank: int, distributed_init_method: str, + is_driver_worker: bool = False, ): # TODO: use WorkerBase.__init__(self, vllm_config=vllm_config) @@ -75,7 +76,7 @@ def __init__( else: self.profiler = None - def initialize(self): + def init_device(self): if self.device_config.device.type == "cuda": # torch.distributed.all_reduce does not free the input tensor until # the synchronization point. This causes the memory usage to grow diff --git a/vllm/worker/hpu_worker.py b/vllm/worker/hpu_worker.py index 8b2d8aaed2803..9401241073c7d 100644 --- a/vllm/worker/hpu_worker.py +++ b/vllm/worker/hpu_worker.py @@ -2,6 +2,7 @@ # Copyright (C) 2024 Habana Labs, Ltd. 
an Intel Company ############################################################################### +import contextlib import gc import os from typing import List, Optional, Set, Tuple, Type @@ -18,6 +19,7 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import ExecuteModelRequest from vllm.utils import bind_kv_cache @@ -124,6 +126,70 @@ def init_device(self) -> None: def load_model(self): self.model_runner.load_model() + def execute_model( + self, + execute_model_req: Optional[ExecuteModelRequest] = None, + ) -> Optional[List[SamplerOutput]]: + assert execute_model_req is not None + # VLLM_HPU_LOG_STEP_GRAPH_COMPILATION - will log graph compilations per engine step, only when there was any - highly recommended to use alongside PT_HPU_METRICS_GC_DETAILS! # noqa:E501 + # VLLM_HPU_LOG_STEP_GRAPH_COMPILATION_ALL - will log graph compilations per engine step, always, even if there were none # noqa:E501 + # VLLM_HPU_LOG_STEP_CPU_FALLBACKS - will log cpu fallbacks per engine step, only when there was any # noqa:E501 + # VLLM_HPU_LOG_STEP_CPU_FALLBACKS_ALL - will log cpu fallbacks per engine step, always, even if there were none # noqa:E501 + log_graph_compilation_all = os.environ.get( + 'VLLM_HPU_LOG_STEP_GRAPH_COMPILATION_ALL', '0') != '0' + log_graph_compilation = os.environ.get( + 'VLLM_HPU_LOG_STEP_GRAPH_COMPILATION', + '0') != '0' or log_graph_compilation_all + log_cpu_fallbacks_all = os.environ.get( + 'VLLM_HPU_LOG_STEP_CPU_FALLBACKS_ALL', '0') != '0' + log_cpu_fallbacks = os.environ.get('VLLM_HPU_LOG_STEP_CPU_FALLBACKS', + '0') != '0' or log_cpu_fallbacks_all + if log_graph_compilation or log_cpu_fallbacks: + from habana_frameworks.torch.hpu.metrics import metric_localcontext + seq_group_metadata_list = execute_model_req.seq_group_metadata_list + is_prompt = any([ + seq_group_metadata.is_prompt + for seq_group_metadata in seq_group_metadata_list + ]) + max_context_len = max([ + max([ + len(v.prompt_token_ids) + len(v.output_token_ids) + for v in seq_group_metadata.seq_data.values() + ]) for seq_group_metadata in seq_group_metadata_list + ]) # whoa, that's some spicy stuff right here + max_num_blocks = ( + (max_context_len - 1) // self.cache_config.block_size) + 1 + input_stats = (f'is_prompt: {is_prompt}, ' + f'num_seqs: {len(seq_group_metadata_list)}, ' + f'max_context_len: {max_context_len}, ' + f'max_num_blocks {max_num_blocks}') + gc_ctx = metric_localcontext( + "graph_compilation" + ) if log_graph_compilation else contextlib.nullcontext() + cpu_fallback_ctx = metric_localcontext( + "cpu_fallback" + ) if log_cpu_fallbacks else contextlib.nullcontext() + with gc_ctx as gc_local_metric, \ + cpu_fallback_ctx as cpu_fallback_local_metric: + output = LocalOrDistributedWorkerBase.execute_model( + self, execute_model_req) + if (log_graph_compilation and gc_local_metric.stats()[0][1] > 0 + ) or log_graph_compilation_all: + msg = ("VLLM_HPU_STEP_GRAPH_COMPILATION: " + f"{gc_local_metric.stats()}, {input_stats}") + logger.warning(msg) + if (log_cpu_fallbacks and cpu_fallback_local_metric.stats()[0][1] > + 0) or log_cpu_fallbacks_all: + msg = ("VLLM_HPU_STEP_CPU_FALLBACK: " + f"{cpu_fallback_local_metric.stats()}, {input_stats}") + logger.warning(msg) + + return output + + output = LocalOrDistributedWorkerBase.execute_model( + self, execute_model_req) + return output + 
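# A condensed sketch (not part of the patch) of the pattern used by the
# execute_model override above: enter a metrics context only when the matching
# VLLM_HPU_LOG_STEP_* flag is set, and fall back to a no-op context otherwise.
# `make_metric_context` stands in for Habana's `metric_localcontext` helper;
# the sketch only assumes the returned context manager yields an object with a
# .stats() method.
import contextlib
import os
from typing import Any, Callable, ContextManager


def run_step_with_optional_metric(
        step: Callable[[], Any],
        make_metric_context: Callable[[str], ContextManager],
        flag: str = "VLLM_HPU_LOG_STEP_GRAPH_COMPILATION",
        metric_name: str = "graph_compilation") -> Any:
    enabled = os.environ.get(flag, "0") != "0"
    ctx = make_metric_context(metric_name) if enabled \
        else contextlib.nullcontext()
    with ctx as local_metric:
        output = step()
    if enabled:
        stats = local_metric.stats()
        # Report only when the metric actually fired during this step.
        if stats and stats[0][1] > 0:
            print(f"{metric_name}: {stats}")
    return output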
@torch.inference_mode() def determine_num_available_blocks(self) -> Tuple[int, int]: """Profiles the peak memory usage of the model to determine how many diff --git a/vllm/worker/neuron_worker.py b/vllm/worker/neuron_worker.py index 3f6269684ac93..e02c72faace70 100644 --- a/vllm/worker/neuron_worker.py +++ b/vllm/worker/neuron_worker.py @@ -8,6 +8,7 @@ from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) from vllm.model_executor import set_random_seed +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.sequence import ExecuteModelRequest from vllm.worker.neuron_model_runner import NeuronModelRunner from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, @@ -25,6 +26,7 @@ def __init__( local_rank: int, rank: int, distributed_init_method: str, + is_driver_worker: bool = True, ) -> None: WorkerBase.__init__(self, vllm_config=vllm_config) self.local_rank = local_rank @@ -37,7 +39,22 @@ def __init__( self.model_runner: NeuronModelRunner = NeuronModelRunner( vllm_config=vllm_config) - self.is_driver_worker = True + self.is_driver_worker = is_driver_worker + + def execute_model( + self, + execute_model_req: Optional[ExecuteModelRequest] = None, + ) -> Optional[List[SamplerOutput]]: + assert execute_model_req is not None + assert (not execute_model_req.blocks_to_swap_in + and not execute_model_req.blocks_to_swap_out + and not execute_model_req.blocks_to_copy), ( + "Cache operations are not supported for Neuron backend.") + assert execute_model_req.num_lookahead_slots == 0, ( + "lookahead not supported for Neuron backend.") + output = LocalOrDistributedWorkerBase.execute_model( + self, execute_model_req) + return output def init_device(self) -> None: self.init_distributed_environment() @@ -103,13 +120,14 @@ def get_cache_block_size_bytes(self) -> int: def init_distributed_environment(self): """Neuron uses transformers-neuronx for tensor parallelism. - - vLLM still needs the environment inited when TP/PP > 1 + It has only one process to control multiple devices. + vLLM still needs the environment initialized when TP/PP > 1, + so we initialize a distributed environment with one process. """ init_distributed_environment( world_size=1, - rank=self.rank, - local_rank=self.local_rank, + rank=0, + local_rank=0, distributed_init_method=self.distributed_init_method, backend="gloo", ) diff --git a/vllm/worker/openvino_worker.py b/vllm/worker/openvino_worker.py index 3482073566215..50a155d22c666 100644 --- a/vllm/worker/openvino_worker.py +++ b/vllm/worker/openvino_worker.py @@ -211,16 +211,14 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase): def __init__( self, - ov_core: ov.Core, vllm_config: VllmConfig, local_rank: int, rank: int, distributed_init_method: str, - kv_cache_dtype: Optional[ov.Type] = ov.Type.undefined, is_driver_worker: bool = False, ) -> None: - self.ov_core = ov_core WorkerBase.__init__(self, vllm_config) + self.ov_core = ov.Core() self.parallel_config.rank = rank self.local_rank = local_rank self.rank = rank @@ -237,7 +235,7 @@ def __init__( self.model_runner = OpenVINOModelRunner( self.ov_core, vllm_config=self.vllm_config, - kv_cache_dtype=kv_cache_dtype, + kv_cache_dtype=self.vllm_config.cache_config.cache_dtype, is_driver_worker=is_driver_worker, ) # Uninitialized cache engine. 
Will be initialized by diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index a835718e1db19..7c14b8344b49e 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -88,7 +88,6 @@ def start_worker_execution_loop(self) -> None: if output is None: return None - @abstractmethod def execute_model( self, execute_model_req: Optional[ExecuteModelRequest] = None @@ -119,6 +118,58 @@ def list_loras(self) -> Set[int]: raise NotImplementedError +class DelegateWorkerBase(WorkerBase): + """ + A class that delegates all methods to another WorkerBase instance. This is + useful for creating a WorkerBase that wraps another WorkerBase instance, + e.g. speculative decoding. + """ + worker: WorkerBase + + def __init__( + self, + *args, + **kwargs, + ) -> None: + vllm_config: VllmConfig = kwargs.get("vllm_config") + cls = resolve_obj_by_qualname(vllm_config.parallel_config.worker_cls) + self.worker = cls(*args, **kwargs) + + def init_device(self) -> None: + self.worker.init_device() + + def determine_num_available_blocks(self) -> Tuple[int, int]: + return self.worker.determine_num_available_blocks() + + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + self.worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) + + def execute_model( + self, + execute_model_req: Optional[ExecuteModelRequest] = None + ) -> Optional[List[SamplerOutput]]: + return self.worker.execute_model(execute_model_req) + + def get_cache_block_size_bytes(self) -> int: + return self.worker.get_cache_block_size_bytes() + + def add_lora(self, lora_request: LoRARequest) -> bool: + return self.worker.add_lora(lora_request) + + def remove_lora(self, lora_id: int) -> bool: + return self.worker.remove_lora(lora_id) + + def pin_lora(self, lora_id: int) -> bool: + return self.worker.pin_lora(lora_id) + + def list_loras(self) -> Set[int]: + return self.worker.list_loras() + + def __getattr__(self, attr): + return getattr(self.worker, attr) + + class LoraNotSupportedWorkerBase(WorkerBase): """Partial implementation of WorkerBase that raises exceptions when LoRA methods are invoked. @@ -419,17 +470,31 @@ class WorkerWrapperBase: def __init__( self, vllm_config: VllmConfig, + rank: int = 0, ) -> None: + self.rank = rank self.vllm_config = vllm_config - trust_remote_code = vllm_config.model_config.trust_remote_code self.worker: Optional[WorkerBase] = None - if trust_remote_code: - # note: lazy import to avoid importing torch before initializing - from vllm.utils import init_cached_hf_modules - init_cached_hf_modules() + if vllm_config.model_config is not None: + # it can be None in tests + trust_remote_code = vllm_config.model_config.trust_remote_code + if trust_remote_code: + # note: lazy import to avoid importing torch before initializing + from vllm.utils import init_cached_hf_modules + init_cached_hf_modules() + + def adjust_rank(self, rank_mapping: Dict[int, int]) -> None: + """ + Adjust the rank based on the given mapping. + It is only used during the initialization of the executor, + to adjust the rank of workers after we create all workers. 
+ """ + if self.rank in rank_mapping: + self.rank = rank_mapping[self.rank] - @staticmethod - def update_environment_variables(envs: Dict[str, str]) -> None: + def update_environment_variables(self, envs_list: List[Dict[str, + str]]) -> None: + envs = envs_list[self.rank] key = 'CUDA_VISIBLE_DEVICES' if key in envs and key in os.environ: # overwriting CUDA_VISIBLE_DEVICES is desired behavior @@ -437,11 +502,12 @@ def update_environment_variables(envs: Dict[str, str]) -> None: del os.environ[key] update_environment_variables(envs) - def init_worker(self, *args, **kwargs): + def init_worker(self, all_kwargs: List[Dict[str, Any]]) -> None: """ Here we inject some common logic before initializing the worker. Arguments are passed to the worker class constructor. """ + kwargs = all_kwargs[self.rank] enable_trace_function_call_for_thread(self.vllm_config) # see https://github.com/NVIDIA/nccl/issues/1234 @@ -452,7 +518,7 @@ def init_worker(self, *args, **kwargs): worker_class = resolve_obj_by_qualname( self.vllm_config.parallel_config.worker_cls) - self.worker = worker_class(*args, **kwargs) + self.worker = worker_class(**kwargs) assert self.worker is not None def execute_method(self, method: str, *args, **kwargs):