Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[platforms] enable platform plugins #11602

Merged
merged 31 commits into from
Dec 30, 2024
Merged
Show file tree
Hide file tree
Changes from 29 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 19 additions & 6 deletions .buildkite/test-pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -106,14 +106,12 @@ steps:
source_file_dependencies:
- vllm/
commands:
- pip install -e ./plugins/vllm_add_dummy_model
- pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_generate_multiple_loras.py --ignore=entrypoints/llm/test_guided_generate.py
- pytest -v -s entrypoints/llm/test_lazy_outlines.py # it needs a clean process
- pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process
- pytest -v -s entrypoints/llm/test_generate_multiple_loras.py # it needs a clean process
- pytest -v -s entrypoints/llm/test_guided_generate.py # it needs a clean process
- pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_oot_registration.py
- pytest -v -s entrypoints/openai/test_oot_registration.py # it needs a clean process
- pytest -v -s entrypoints/test_chat_utils.py
- pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests

Expand Down Expand Up @@ -333,8 +331,6 @@ steps:
- vllm/
- tests/models
commands:
- pip install -e ./plugins/vllm_add_dummy_model
- pytest -v -s models/test_oot_registration.py # it needs a clean process
- pytest -v -s models/test_registry.py
- pytest -v -s models/test_initialization.py

Expand Down Expand Up @@ -469,11 +465,28 @@ steps:
- pytest models/encoder_decoder/vision_language/test_broadcast.py -v -s -m 'distributed(num_gpus=2)'
- pytest models/decoder_only/vision_language/test_models.py -v -s -m 'distributed(num_gpus=2)'
- pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
- pip install -e ./plugins/vllm_add_dummy_model
- pytest -v -s distributed/test_distributed_oot.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s kv_transfer/disagg_test.py

- label: Plugin Tests (2 GPUs) # 40min
working_dir: "/vllm-workspace/tests"
num_gpus: 2
fast_check: true
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I make this new test fast check, because I assume many PRs will break it easily by importing current_platform too early. Adding it as fast check gives people quick signal that they should lazy import current_platform.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I think the problem now is that it's hard for users to know when the current_platform can be imported globally or lazily in future develop work. I know that all the action early than current_platform initiation should be lazy import. while for other developers, once they hit this kind of problem, it's a little complex and hard for them to debug. Any note or check can be added? Otherwise the PR LGTM.

Copy link
Member

@DarkLight1337 DarkLight1337 Dec 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I prefer the approach in #11222 where we essentially lazily evaluate the attributes of current_platform (rather than current_platform itself) even if it's imported early.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have to point out that my PR #11222 has a disadvantage that current_platform should not be called directly in global namespace, for example https://github.com/vllm-project/vllm/pull/11222/files#diff-3bb4c705122f0ac754ce571780b1d75d5a8d444b2c99581385c1adfaeb8562e2R302

The __getattr__ design in this PR is a nice solution to fix the global namespace usage. https://github.com/vllm-project/vllm/pull/11602/files#diff-b353a4069aa142efe66166225c15f17b977a93fd3fc1d64482e18da127e420b8R195

So maybe another way is to combine them together, so that current_platform can be imported early and can be used in global namespace as well.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As @wangxiyuan pointed out, #11222 has a silent correctness issue. There are some code using @current_platform.inference_mode() to decorate functions, e.g.

@current_platform.inference_mode()

and in that PR, it directly resolves to Platform.inference_mode, even if we resolve the platform later to be a specific platform.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In that case, I suggest only exporting a get_current_platform method to indicate to other developers that current_platform is lazy evaluated.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The goal of this PR (platform plugins), is to make sure, when a platform plugin is registered and should be activated, all usage of current_platform should resolve to that platform.

To assure this goal, I added the test https://github.com/vllm-project/vllm/pull/11602/files#diff-b147065bfd9d20787e4c9d353c862732e9d1d797d5297ca36f6b131914d2cec1 , and also make sure the current_platform is resolved only once.

while for other developers, once they hit this kind of problem, it's a little complex and hard for them to debug

if they import current_platform too early, then this test will fail, and the error message will tell them which line to blame.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@DarkLight1337 I think get_current_platform still cannot solve the @current_platform.inference_mode() issue because it is indeed used during module import time. To make it work with platform plugins, we have to change the code to lazy import anyway.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I'm just suggesting the name change to make it more obvious that current_platform isn't available immediately.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if they import current_platform too early, then this test will fail, and the error message will tell them which line to blame.

I'm OK with this. Looks clear enough

source_file_dependencies:
- vllm/plugins/
- tests/plugins/
commands:
# begin platform plugin tests, all the code in-between runs on dummy platform
- pip install -e ./plugins/vllm_add_dummy_platform
- pytest -v -s plugins_tests/test_platform_plugins.py
- pip uninstall vllm_add_dummy_platform -y
# end platform plugin tests
# other tests continue here:
- pip install -e ./plugins/vllm_add_dummy_model
- pytest -v -s distributed/test_distributed_oot.py
- pytest -v -s entrypoints/openai/test_oot_registration.py # it needs a clean process
- pytest -v -s models/test_oot_registration.py # it needs a clean process

- label: Multi-step Tests (4 GPUs) # 36min
working_dir: "/vllm-workspace/tests"
num_gpus: 4
Expand Down
6 changes: 4 additions & 2 deletions docs/source/design/plugin_system.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,11 @@ Every plugin has three parts:
2. **Plugin name**: The name of the plugin. This is the value in the dictionary of the `entry_points` dictionary. In the example above, the plugin name is `register_dummy_model`. Plugins can be filtered by their names using the `VLLM_PLUGINS` environment variable. To load only a specific plugin, set `VLLM_PLUGINS` to the plugin name.
3. **Plugin value**: The fully qualified name of the function to register in the plugin system. In the example above, the plugin value is `vllm_add_dummy_model:register`, which refers to a function named `register` in the `vllm_add_dummy_model` module.

## What Can Plugins Do?
## Types of supported plugins

Currently, the primary use case for plugins is to register custom, out-of-the-tree models into vLLM. This is done by calling `ModelRegistry.register_model` to register the model. In the future, the plugin system may be extended to support more features, such as swapping in custom implementations for certain classes in vLLM.
- **General plugins** (with group name `vllm.general_plugins`): The primary use case for these plugins is to register custom, out-of-the-tree models into vLLM. This is done by calling `ModelRegistry.register_model` to register the model inside the plugin function.

- **Platform plugins** (with group name `vllm.platform_plugins`): The primary use case for these plugins is to register custom, out-of-the-tree platforms into vLLM. The plugin function should return `None` when the platform is not supported in the current environment, or the platform class's fully qualified name when the platform is supported.

## Guidelines for Writing Plugins

Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
to_enc_dec_tuple_list, zip_enc_dec_prompts)
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.platforms import current_platform
from vllm.sampling_params import BeamSearchParams
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
identity)
Expand Down Expand Up @@ -242,6 +241,7 @@ def video_assets() -> _VideoAssets:
class HfRunner:

def wrap_device(self, x: _T, device: Optional[str] = None) -> _T:
from vllm.platforms import current_platform
if x is None or isinstance(x, (bool, )):
return x

Expand Down
11 changes: 11 additions & 0 deletions tests/plugins/vllm_add_dummy_platform/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from setuptools import setup

setup(
name='vllm_add_dummy_platform',
version='0.1',
packages=['vllm_add_dummy_platform'],
entry_points={
'vllm.platform_plugins': [
"dummy_platform_plugin = vllm_add_dummy_platform:dummy_platform_plugin" # noqa
]
})
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from typing import Optional


def dummy_platform_plugin() -> Optional[str]:
return "vllm_add_dummy_platform.dummy_platform.DummyPlatform"
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from vllm.platforms.cuda import CudaPlatform


class DummyPlatform(CudaPlatform):
device_name = "DummyDevice"
16 changes: 16 additions & 0 deletions tests/plugins_tests/test_platform_plugins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
def test_platform_plugins():
# simulate workload by running an example
import runpy
current_file = __file__
import os
example_file = os.path.join(
os.path.dirname(os.path.dirname(os.path.dirname(current_file))),
"examples", "offline_inference.py")
runpy.run_path(example_file)

# check if the plugin is loaded correctly
from vllm.platforms import _init_trace, current_platform
assert current_platform.device_name == "DummyDevice", (
f"Expected DummyDevice, got {current_platform.device_name}, "
"possibly because current_platform is imported before the plugin"
f" is loaded. The first import:\n{_init_trace}")
15 changes: 12 additions & 3 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS,
get_quantization_config)
from vllm.model_executor.models import ModelRegistry
from vllm.platforms import current_platform, interface
from vllm.platforms import CpuArchEnum
from vllm.tracing import is_otel_available, otel_import_error_traceback
from vllm.transformers_utils.config import (
ConfigFormat, get_config, get_hf_image_processor_config,
Expand Down Expand Up @@ -343,6 +343,7 @@ def __init__(self,
self.is_hybrid = self._init_is_hybrid()
self.has_inner_state = self._init_has_inner_state()

from vllm.platforms import current_platform
if current_platform.is_neuron():
self.override_neuron_config = override_neuron_config
else:
Expand Down Expand Up @@ -583,6 +584,7 @@ def _verify_quantization(self) -> None:
raise ValueError(
f"Unknown quantization method: {self.quantization}. Must "
f"be one of {supported_quantization}.")
from vllm.platforms import current_platform
current_platform.verify_quantization(self.quantization)
if self.quantization not in optimized_quantization_methods:
logger.warning(
Expand Down Expand Up @@ -638,6 +640,7 @@ def verify_async_output_proc(self, parallel_config, speculative_config,

# Reminder: Please update docs/source/usage/compatibility_matrix.md
# If the feature combo become valid
from vllm.platforms import current_platform
if not current_platform.is_async_output_supported(self.enforce_eager):
logger.warning(
"Async output processing is not supported on the "
Expand Down Expand Up @@ -1006,6 +1009,7 @@ def _verify_args(self) -> None:
raise ValueError(
"GPU memory utilization must be less than 1.0. Got "
f"{self.gpu_memory_utilization}.")
from vllm.platforms import current_platform
if (current_platform.is_cuda() and self.block_size is not None
and self.block_size > 32):
raise ValueError("CUDA Paged Attention kernel only supports "
Expand Down Expand Up @@ -1273,6 +1277,7 @@ def __post_init__(self) -> None:
f"distributed executor backend "
f"'{self.distributed_executor_backend}'.")
ray_only_devices = ["tpu", "hpu"]
from vllm.platforms import current_platform
if (current_platform.device_type in ray_only_devices
and self.world_size > 1):
if self.distributed_executor_backend is None:
Expand Down Expand Up @@ -1321,7 +1326,7 @@ def use_ray(self) -> bool:
def _verify_args(self) -> None:
# Lazy import to avoid circular import
from vllm.executor.executor_base import ExecutorBase

from vllm.platforms import current_platform
if self.distributed_executor_backend not in (
"ray", "mp", None) and not (isinstance(
self.distributed_executor_backend, type) and issubclass(
Expand Down Expand Up @@ -1522,6 +1527,7 @@ def compute_hash(self) -> str:
def __init__(self, device: str = "auto") -> None:
if device == "auto":
# Automated device type detection
from vllm.platforms import current_platform
self.device_type = current_platform.device_type
if not self.device_type:
raise RuntimeError("Failed to infer device type")
Expand Down Expand Up @@ -2235,9 +2241,10 @@ def _get_and_verify_dtype(
else:
torch_dtype = config_dtype

from vllm.platforms import current_platform
if (current_platform.is_cpu()
and current_platform.get_cpu_architecture()
== interface.CpuArchEnum.POWERPC
== CpuArchEnum.POWERPC
and (config_dtype == torch.float16
or config_dtype == torch.float32)):
logger.info(
Expand Down Expand Up @@ -3052,6 +3059,7 @@ def _get_quantization_config(
model_config: ModelConfig,
load_config: LoadConfig) -> Optional[QuantizationConfig]:
"""Get the quantization config."""
from vllm.platforms import current_platform
if model_config.quantization is not None:
from vllm.model_executor.model_loader.weight_utils import (
get_quant_config)
Expand Down Expand Up @@ -3114,6 +3122,7 @@ def __post_init__(self):
self.quant_config = VllmConfig._get_quantization_config(
self.model_config, self.load_config)

from vllm.platforms import current_platform
if self.scheduler_config is not None and \
self.model_config is not None and \
self.scheduler_config.chunked_prefill_enabled and \
Expand Down
3 changes: 2 additions & 1 deletion vllm/distributed/parallel_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
import vllm.envs as envs
from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op, supports_custom_op

if TYPE_CHECKING:
Expand Down Expand Up @@ -194,6 +193,7 @@ def __init__(
assert self.cpu_group is not None
assert self.device_group is not None

from vllm.platforms import current_platform
if current_platform.is_cuda_alike():
self.device = torch.device(f"cuda:{local_rank}")
else:
Expand Down Expand Up @@ -1188,6 +1188,7 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
import ray # Lazy import Ray
ray.shutdown()
gc.collect()
from vllm.platforms import current_platform
if not current_platform.is_cpu():
torch.cuda.empty_cache()

Expand Down
2 changes: 1 addition & 1 deletion vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.platforms import current_platform
from vllm.transformers_utils.utils import check_gguf_file
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, StoreBoolean
Expand Down Expand Up @@ -1094,6 +1093,7 @@ def create_engine_config(self,
use_sliding_window = (model_config.get_sliding_window()
is not None)
use_spec_decode = self.speculative_model is not None
from vllm.platforms import current_platform
if (is_gpu and not use_sliding_window and not use_spec_decode
and not self.enable_lora
and not self.enable_prompt_adapter
Expand Down
2 changes: 1 addition & 1 deletion vllm/executor/ray_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from vllm.config import ParallelConfig
from vllm.executor.msgspec_utils import decode_hook, encode_hook
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.sequence import ExecuteModelRequest, IntermediateTensors
from vllm.utils import get_ip
from vllm.worker.worker_base import WorkerWrapperBase
Expand Down Expand Up @@ -229,6 +228,7 @@ def initialize_ray_cluster(
the default Ray cluster address.
"""
assert_ray_available()
from vllm.platforms import current_platform

# Connect to a ray cluster.
if current_platform.is_rocm() or current_platform.is_xpu():
Expand Down
3 changes: 2 additions & 1 deletion vllm/model_executor/guided_decoding/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from vllm.model_executor.guided_decoding.utils import (
convert_lark_to_gbnf, grammar_is_likely_lark,
has_lmf_unsupported_json_features, has_xgrammar_unsupported_json_features)
from vllm.platforms import CpuArchEnum, current_platform
from vllm.platforms import CpuArchEnum

if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
Expand Down Expand Up @@ -39,6 +39,7 @@ def maybe_backend_fallback(

if guided_params.backend == "xgrammar":
# xgrammar only has x86 wheels for linux, fallback to outlines
from vllm.platforms import current_platform
if current_platform.get_cpu_architecture() is not CpuArchEnum.X86:
logger.warning("xgrammar is only supported on x86 CPUs. "
"Falling back to use outlines instead.")
Expand Down
2 changes: 1 addition & 1 deletion vllm/model_executor/models/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import torch.nn as nn

from vllm.logger import init_logger
from vllm.platforms import current_platform

from .interfaces import (has_inner_state, is_attention_free, is_hybrid,
supports_cross_encoding, supports_multimodal,
Expand Down Expand Up @@ -272,6 +271,7 @@ def _try_load_model_cls(
model_arch: str,
model: _BaseRegisteredModel,
) -> Optional[Type[nn.Module]]:
from vllm.platforms import current_platform
current_platform.verify_model_arch(model_arch)
try:
return model.load_model_cls()
Expand Down
4 changes: 2 additions & 2 deletions vllm/model_executor/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@

import torch

from vllm.platforms import current_platform


def set_random_seed(seed: int) -> None:
from vllm.platforms import current_platform
current_platform.seed_everything(seed)


Expand Down Expand Up @@ -38,6 +37,7 @@ def set_weight_attrs(
# This sometimes causes OOM errors during model loading. To avoid this,
# we sync the param tensor after its weight loader is called.
# TODO(woosuk): Remove this hack once we have a better solution.
from vllm.platforms import current_platform
if current_platform.is_tpu() and key == "weight_loader":
value = _make_synced_weight_loader(value)
setattr(weight, key, value)
Expand Down
Loading
Loading