From bab73d1f81ed8d0f8c5b301cec79cb1ee752ee15 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Tue, 21 Jan 2025 15:42:21 +0800
Subject: [PATCH 01/14] separate builder input and builder prepare

Signed-off-by: youkaichao
---
 vllm/worker/cpu_model_runner.py  | 23 +++++++++++++----------
 vllm/worker/model_runner.py      | 26 ++++++++++++++++----------
 vllm/worker/model_runner_base.py |  5 +++++
 3 files changed, 34 insertions(+), 20 deletions(-)

diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py
index abbf6450ab7f6..c0045e647f104 100644
--- a/vllm/worker/cpu_model_runner.py
+++ b/vllm/worker/cpu_model_runner.py
@@ -140,13 +140,9 @@ def __init__(self, use_mrope: bool):
         self.input_mrope_positions: List[List[int]] = [[]
                                                        for _ in range(3)]
 
-    def __init__(self,
-                 runner: "CPUModelRunner",
-                 finished_requests_ids: Optional[List[str]] = None) -> None:
+    def __init__(self, runner: "CPUModelRunner") -> None:
         super().__init__()
-        self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
         self.runner = runner
-
         self.chunked_prefill = (runner.scheduler_config.chunked_prefill_enabled
                                 or runner.cache_config.enable_prefix_caching)
         self.model_input_cls = self.runner._model_input_cls
@@ -156,11 +152,15 @@ def __init__(self,
         self.device = self.runner.device
         self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper
         self.enable_lora = self.runner.lora_config is not None
-        self.input_data = ModelInputForCPUBuilder.ModelInputData(
-            self.runner.model_config.uses_mrope)
         self.att_metadata_builder = self.runner.attn_backend.get_builder_cls()(
             self)
 
+    def prepare(self,
+                finished_requests_ids: Optional[List[str]] = None) -> None:
+        self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
+        self.input_data = ModelInputForCPUBuilder.ModelInputData(
+            self.runner.model_config.uses_mrope)
+
     def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
         self.seq_group_metadata_list.append(seq_group_metadata)
 
@@ -431,6 +431,7 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
     """
     _model_input_cls: Type[TModelInputForCPU]
     _builder_cls: Type[ModelInputForCPUBuilder]
+    builder: ModelInputForCPUBuilder
 
     def __init__(
         self,
@@ -477,6 +478,8 @@ def __init__(
         # Set after load_model.
         self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None
 
+        self.builder = self._builder_cls(weakref.proxy(self))
+
     def load_model(self) -> None:
         self.model = get_model(vllm_config=self.vllm_config)
 
@@ -522,10 +525,10 @@ def _prepare_model_input_tensors(
             metadata for possible additional steps, e.g., sampling.
 
         """
-        builder = self._builder_cls(weakref.proxy(self), finished_requests_ids)
-        builder.set_seq_group_list(seq_group_metadata_list)
+        self.builder.prepare(finished_requests_ids)
+        self.builder.set_seq_group_list(seq_group_metadata_list)
 
-        return builder.build()  # type: ignore
+        return self.builder.build()  # type: ignore
 
     # sampler property will be used by spec_decode_worker
     @property

diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index cb2ff0c934da3..1ccc68e7c33b2 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -457,14 +457,8 @@ def __init__(self,
         self.enable_prompt_adapter = (self.runner.prompt_adapter_config
                                       is not None)
         self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper
-        self.finished_requests_ids = finished_requests_ids
         self.decode_only = True
 
-        # Intermediate data (data in CPU before going to GPU) for
-        # the current sequence group.
-        self.inter_data_list: List[
-            ModelInputForGPUBuilder.InterDataForSeqGroup] = []
-
         # Attention metadata inputs.
         self.attn_metadata_builder = self.attn_backend.make_metadata_builder(
             weakref.proxy(self))
@@ -479,6 +473,15 @@ def __init__(self,
         self.block_aligned_sliding_window = \
             self.sliding_window_blocks * self.block_size
 
+    def prepare(self,
+                finished_requests_ids: Optional[List[str]] = None) -> None:
+        self.finished_requests_ids = finished_requests_ids
+
+        # Intermediate data (data in CPU before going to GPU) for
+        # the current sequence group.
+        self.inter_data_list: List[
+            ModelInputForGPUBuilder.InterDataForSeqGroup] = []
+
     def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int,
                       seq_group_metadata: SequenceGroupMetadata):
         """Compute context length, sequence length and tokens
@@ -993,6 +996,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
     """
     _model_input_cls: Type[TModelInputForGPU]
    _builder_cls: Type[ModelInputForGPUBuilder]
+    builder: ModelInputForGPUBuilder
 
     def __init__(
         self,
@@ -1093,6 +1097,8 @@ def __init__(
             SamplingMetadataCache() \
             if self.parallel_config.pipeline_parallel_size == 1 else None
 
+        self.builder = self._builder_cls(weakref.proxy(self))
+
     def load_model(self) -> None:
         logger.info("Starting to load model %s...", self.model_config.model)
         with DeviceMemoryProfiler() as m:
@@ -1226,13 +1232,13 @@ def _prepare_model_input_tensors(
 
         If cuda graph is required, this API automatically pads inputs.
         """
-        builder = self._builder_cls(weakref.proxy(self), finished_requests_ids)
+        self.builder.prepare(finished_requests_ids)
         for seq_group_metadata in seq_group_metadata_list:
-            builder.add_seq_group(seq_group_metadata)
+            self.builder.add_seq_group(seq_group_metadata)
 
-        builder.reset_cached_inter_data()
+        self.builder.reset_cached_inter_data()
 
-        return builder.build()  # type: ignore
+        return self.builder.build()  # type: ignore
 
     @contextmanager
     def set_in_profile_run(self):

diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py
index acfd6d0b03f62..aef4bdcdd4bf9 100644
--- a/vllm/worker/model_runner_base.py
+++ b/vllm/worker/model_runner_base.py
@@ -200,6 +200,11 @@ class ModelRunnerInputBuilderBase(ABC, Generic[T]):
     """A builder to create ModelRunnerInputBase objects.
     """
 
+    @abstractmethod
+    def prepare(self,
+                finished_requests_ids: Optional[List[str]] = None) -> None:
+        raise NotImplementedError
+
     @abstractmethod
     def add_seq_group(self, seq_group_metadata):
         """TBA"""
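Patch 01 is the heart of the series: the input builder's lifecycle changes from
construct-per-batch to construct-once, reset-per-batch. A minimal sketch of the
resulting pattern follows; the class and attribute names are simplified
stand-ins for illustration, not the actual vLLM classes.

from typing import Any, List, Optional


class Builder:
    """Illustrative stand-in for the model-input builders."""

    def __init__(self, runner: Any) -> None:
        # One-time setup: keep only configuration that never changes
        # between batches.
        self.runner = runner

    def prepare(self,
                finished_requests_ids: Optional[List[str]] = None) -> None:
        # Per-batch setup: reset every piece of mutable state, so the
        # same instance can be reused without re-running __init__.
        self.finished_requests_ids = finished_requests_ids
        self.seq_group_metadata_list: List[Any] = []

    def add_seq_group(self, seq_group_metadata: Any) -> None:
        self.seq_group_metadata_list.append(seq_group_metadata)

    def build(self) -> List[Any]:
        return self.seq_group_metadata_list


class Runner:

    def __init__(self) -> None:
        self.builder = Builder(self)  # constructed exactly once

    def prepare_inputs(self,
                       seq_group_metadata_list: List[Any],
                       finished_requests_ids: Optional[List[str]] = None):
        self.builder.prepare(finished_requests_ids)  # reset per batch
        for seq_group_metadata in seq_group_metadata_list:
            self.builder.add_seq_group(seq_group_metadata)
        return self.builder.build()

The one-time half (__init__) holds configuration only; everything mutated while
assembling a batch lives in prepare(), so reusing the instance is equivalent to
constructing a fresh one.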
From 1100130ab64dd7da82b9d9b3d85067b88fe46774 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Tue, 21 Jan 2025 16:48:23 +0800
Subject: [PATCH 02/14] fix: call prepare() from the GPU builder constructor

Signed-off-by: youkaichao
---
 vllm/worker/model_runner.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 1ccc68e7c33b2..d5a9bbc79bb45 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -473,6 +473,8 @@ def __init__(self,
         self.block_aligned_sliding_window = \
             self.sliding_window_blocks * self.block_size
 
+        self.prepare(finished_requests_ids)
+
     def prepare(self,
                 finished_requests_ids: Optional[List[str]] = None) -> None:
         self.finished_requests_ids = finished_requests_ids

From 14a30db5dd092f94ff455019ad764cb821a34faf Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Tue, 21 Jan 2025 16:50:29 +0800
Subject: [PATCH 03/14] fix: restore the CPU builder constructor signature and
 call prepare() from it

Signed-off-by: youkaichao
---
 vllm/worker/cpu_model_runner.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py
index c0045e647f104..3d17bb2b3e748 100644
--- a/vllm/worker/cpu_model_runner.py
+++ b/vllm/worker/cpu_model_runner.py
@@ -140,7 +140,9 @@ def __init__(self, use_mrope: bool):
         self.input_mrope_positions: List[List[int]] = [[]
                                                        for _ in range(3)]
 
-    def __init__(self, runner: "CPUModelRunner") -> None:
+    def __init__(self,
+                 runner: "CPUModelRunner",
+                 finished_requests_ids: Optional[List[str]] = None) -> None:
         super().__init__()
         self.runner = runner
         self.chunked_prefill = (runner.scheduler_config.chunked_prefill_enabled
@@ -155,6 +157,8 @@ def __init__(self, runner: "CPUModelRunner") -> None:
         self.att_metadata_builder = self.runner.attn_backend.get_builder_cls()(
             self)
 
+        self.prepare(finished_requests_ids)
+
     def prepare(self,
                 finished_requests_ids: Optional[List[str]] = None) -> None:
         self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
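Patches 02 and 03 are transitional: some call sites still construct a fresh
builder per batch and pass finished_requests_ids to the constructor, so the
constructor ends by forwarding those arguments to prepare(). A sketch of that
shim, with simplified names (patch 10 later removes the call once construction
happens only in the runner):

from typing import List, Optional


class Builder:
    """Illustrative; the real constructors also take a runner proxy."""

    def __init__(self,
                 finished_requests_ids: Optional[List[str]] = None) -> None:
        # ...one-time setup elided...
        # Old call sites still pass per-batch arguments to the
        # constructor, so forward them to prepare() until every caller
        # is migrated.
        self.prepare(finished_requests_ids)

    def prepare(self,
                finished_requests_ids: Optional[List[str]] = None) -> None:
        self.finished_requests_ids = finished_requests_ids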
From 4f12b671b2ec742ea7c5079a072d386616ee1fbb Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Tue, 21 Jan 2025 16:54:45 +0800
Subject: [PATCH 04/14] fix: create the attention metadata builders in
 prepare()

Signed-off-by: youkaichao
---
 vllm/worker/cpu_model_runner.py | 4 ++--
 vllm/worker/model_runner.py     | 8 ++++----
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py
index 3d17bb2b3e748..62dc0bfbd0a24 100644
--- a/vllm/worker/cpu_model_runner.py
+++ b/vllm/worker/cpu_model_runner.py
@@ -154,8 +154,6 @@ def __init__(self,
         self.device = self.runner.device
         self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper
         self.enable_lora = self.runner.lora_config is not None
-        self.att_metadata_builder = self.runner.attn_backend.get_builder_cls()(
-            self)
 
         self.prepare(finished_requests_ids)
 
@@ -164,6 +162,8 @@ def prepare(self,
         self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
         self.input_data = ModelInputForCPUBuilder.ModelInputData(
             self.runner.model_config.uses_mrope)
+        self.att_metadata_builder = self.runner.attn_backend.get_builder_cls()(
+            self)
 
     def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
         self.seq_group_metadata_list.append(seq_group_metadata)

diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index d5a9bbc79bb45..432738d5a5a7a 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -459,10 +459,6 @@ def __init__(self,
         self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper
         self.decode_only = True
 
-        # Attention metadata inputs.
-        self.attn_metadata_builder = self.attn_backend.make_metadata_builder(
-            weakref.proxy(self))
-
         # Engine/Model configurations.
         self.chunked_prefill_enabled = (
             self.scheduler_config is not None
@@ -484,6 +480,10 @@ def prepare(self,
         self.inter_data_list: List[
             ModelInputForGPUBuilder.InterDataForSeqGroup] = []
 
+        # Attention metadata inputs.
+        self.attn_metadata_builder = self.attn_backend.make_metadata_builder(
+            weakref.proxy(self))
+
     def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int,
                       seq_group_metadata: SequenceGroupMetadata):
         """Compute context length, sequence length and tokens

From a8983f0a160c988ca836ebdc9bb6f0452318fb11 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Tue, 21 Jan 2025 17:05:05 +0800
Subject: [PATCH 05/14] separate init and prepare

Signed-off-by: youkaichao
---
 vllm/attention/backends/abstract.py         |  6 ++++++
 vllm/attention/backends/flash_attn.py       | 11 ++++++-----
 vllm/attention/backends/flashinfer.py       | 14 ++++++++------
 vllm/attention/backends/placeholder_attn.py |  8 +++++---
 vllm/attention/backends/torch_sdpa.py       |  5 ++++-
 vllm/attention/backends/utils.py            | 13 +++++++------
 vllm/worker/cpu_model_runner.py             |  4 ++--
 vllm/worker/model_runner.py                 |  8 ++++----
 8 files changed, 42 insertions(+), 27 deletions(-)

diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py
index e6ddca69bf01b..01bd01f8fd96f 100644
--- a/vllm/attention/backends/abstract.py
+++ b/vllm/attention/backends/abstract.py
@@ -214,6 +214,12 @@ class AttentionMetadataBuilder(ABC, Generic[T]):
 
     @abstractmethod
     def __init__(self, input_builder: "ModelRunnerInputBuilderBase") -> None:
+        """Create the builder, remember some configuration and parameters."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def prepare(self) -> None:
+        """Prepare for one batch."""
         raise NotImplementedError
 
     @abstractmethod

diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py
index 40250ef08b595..60ed09d0cc44f 100644
--- a/vllm/attention/backends/flash_attn.py
+++ b/vllm/attention/backends/flash_attn.py
@@ -375,6 +375,12 @@ class FlashAttentionMetadataBuilder(
         AttentionMetadataBuilder[FlashAttentionMetadata]):
 
     def __init__(self, input_builder: "ModelInputForGPUBuilder"):
+        self.input_builder = input_builder
+        self.runner = input_builder.runner
+        self.sliding_window = input_builder.sliding_window
+        self.block_size = input_builder.block_size
+
+    def prepare(self):
         self.slot_mapping: List[int] = []
         self.prefill_seq_lens: List[int] = []
         self.context_lens: List[int] = []
@@ -388,11 +394,6 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"):
         self.num_decode_tokens = 0
         self.has_prefix_cache_hit = False
 
-        self.input_builder = input_builder
-        self.runner = input_builder.runner
-        self.sliding_window = input_builder.sliding_window
-        self.block_size = input_builder.block_size
-
     def _add_seq_group(
             self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
             chunked_prefill_enabled: bool, prefix_cache_hit: bool):

diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py
index b9cd805e81b45..b8ffbe6dd64dd 100644
--- a/vllm/attention/backends/flashinfer.py
+++ b/vllm/attention/backends/flashinfer.py
@@ -488,6 +488,14 @@ def advance_step(self,
 class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
 
     def __init__(self, input_builder: "ModelInputForGPUBuilder"):
+
+        self.input_builder = input_builder
+        self.runner = input_builder.runner
+
+        self.sliding_window = input_builder.sliding_window
+        self.block_size = input_builder.block_size
+
+    def prepare(self):
         self.slot_mapping: List[int] = []
         self.prefill_seq_lens: List[int] = []
         self.context_lens: List[int] = []
@@ -500,12 +508,6 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"):
         self.num_prefill_tokens = 0
         self.num_decode_tokens = 0
 
-        self.input_builder = input_builder
-        self.runner = input_builder.runner
-
-        self.sliding_window = input_builder.sliding_window
-        self.block_size = input_builder.block_size
-
         # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout
         # for the precise definition of the following fields.
         # An example:

diff --git a/vllm/attention/backends/placeholder_attn.py b/vllm/attention/backends/placeholder_attn.py
index 534f79b3a60bf..37860494702cf 100644
--- a/vllm/attention/backends/placeholder_attn.py
+++ b/vllm/attention/backends/placeholder_attn.py
@@ -253,6 +253,11 @@ class PlaceholderAttentionMetadataBuilder(
         AttentionMetadataBuilder[PlaceholderAttentionMetadata]):
 
     def __init__(self, input_builder: "ModelInputForGPUBuilder"):
+
+        self.input_builder = input_builder
+        self.runner = input_builder.runner
+
+    def prepare(self):
         self.prefill_seq_lens: List[int] = []
         self.context_lens: List[int] = []
         self.curr_seq_lens: List[int] = []
@@ -263,9 +268,6 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"):
         self.num_prefill_tokens = 0
         self.num_decode_tokens = 0
 
-        self.input_builder = input_builder
-        self.runner = input_builder.runner
-
     def _add_seq_group(
             self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
             chunked_prefill_enabled: bool):

diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py
index 7cd2049f0c0a5..8722d7376795a 100644
--- a/vllm/attention/backends/torch_sdpa.py
+++ b/vllm/attention/backends/torch_sdpa.py
@@ -282,7 +282,10 @@ class TorchSDPAMetadataBuilder(AttentionMetadataBuilder[TorchSDPAMetadata]):
 
     def __init__(self, input_builder: ModelInputForCPUBuilder) -> None:
         self.chunked_prefill = input_builder.chunked_prefill
-        self.input_data = input_builder.input_data
+        self.input_builder = input_builder
+
+    def prepare(self):
+        self.input_data = self.input_builder.input_data
 
     def build(self, seq_lens: List[int], query_lens: List[int],
               cuda_graph_pad_size: int, batch_size: int) -> TorchSDPAMetadata:

diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py
index 56cc43430301f..3df7f54cbd8d2 100644
--- a/vllm/attention/backends/utils.py
+++ b/vllm/attention/backends/utils.py
@@ -122,6 +122,13 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
     _metadata_cls: Type[TAttentionMetadata]
 
     def __init__(self, input_builder: "ModelInputForGPUBuilder"):
+        self.input_builder = input_builder
+        self.runner = input_builder.runner
+
+        self.sliding_window = input_builder.sliding_window
+        self.block_size = input_builder.block_size
+
+    def prepare(self):
         self.slot_mapping: List[int] = []
         self.prefill_seq_lens: List[int] = []
         self.context_lens: List[int] = []
@@ -134,12 +141,6 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"):
         self.num_prefill_tokens = 0
         self.num_decode_tokens = 0
 
-        self.input_builder = input_builder
-        self.runner = input_builder.runner
-
-        self.sliding_window = input_builder.sliding_window
-        self.block_size = input_builder.block_size
-
     def _add_seq_group(
             self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
             chunked_prefill_enabled: bool):

diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py
index 62dc0bfbd0a24..3d17bb2b3e748 100644
--- a/vllm/worker/cpu_model_runner.py
+++ b/vllm/worker/cpu_model_runner.py
@@ -154,6 +154,8 @@ def __init__(self,
         self.device = self.runner.device
         self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper
         self.enable_lora = self.runner.lora_config is not None
+        self.att_metadata_builder = self.runner.attn_backend.get_builder_cls()(
+            self)
 
         self.prepare(finished_requests_ids)
 
@@ -162,8 +164,6 @@ def prepare(self,
         self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
         self.input_data = ModelInputForCPUBuilder.ModelInputData(
             self.runner.model_config.uses_mrope)
-        self.att_metadata_builder = self.runner.attn_backend.get_builder_cls()(
-            self)
 
     def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
         self.seq_group_metadata_list.append(seq_group_metadata)

diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 432738d5a5a7a..d5a9bbc79bb45 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -459,6 +459,10 @@ def __init__(self,
         self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper
         self.decode_only = True
 
+        # Attention metadata inputs.
+        self.attn_metadata_builder = self.attn_backend.make_metadata_builder(
+            weakref.proxy(self))
+
         # Engine/Model configurations.
         self.chunked_prefill_enabled = (
             self.scheduler_config is not None
@@ -480,10 +484,6 @@ def prepare(self,
         self.inter_data_list: List[
             ModelInputForGPUBuilder.InterDataForSeqGroup] = []
 
-        # Attention metadata inputs.
-        self.attn_metadata_builder = self.attn_backend.make_metadata_builder(
-            weakref.proxy(self))
-
     def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int,
                       seq_group_metadata: SequenceGroupMetadata):
         """Compute context length, sequence length and tokens
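Patch 05 pushes the same split down into the attention backends:
AttentionMetadataBuilder.__init__ now holds only configuration, and the new
abstract prepare() resets the per-batch accumulators. A sketch of a conforming
backend builder follows; the class name is illustrative, and the fields mirror
the FlashAttention builder in the diff above.

from typing import Any, List


class SketchBackendMetadataBuilder:
    """Illustrative backend builder following the new contract."""

    def __init__(self, input_builder: Any) -> None:
        # Configuration only: safe to run once and keep for the
        # lifetime of the runner.
        self.input_builder = input_builder
        self.runner = input_builder.runner
        self.sliding_window = input_builder.sliding_window
        self.block_size = input_builder.block_size

    def prepare(self) -> None:
        # Per-batch accumulators: reset before every build().
        self.slot_mapping: List[int] = []
        self.prefill_seq_lens: List[int] = []
        self.curr_seq_lens: List[int] = []
        self.num_prefills = 0
        self.num_prefill_tokens = 0
        self.num_decode_tokens = 0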
From 3d0d593064724b651647825700af4ef5e0a2b66d Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Tue, 21 Jan 2025 17:09:01 +0800
Subject: [PATCH 06/14] add prepare

Signed-off-by: youkaichao
---
 vllm/worker/cpu_model_runner.py | 1 +
 vllm/worker/model_runner.py     | 2 ++
 2 files changed, 3 insertions(+)

diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py
index 3d17bb2b3e748..9f45387b55e19 100644
--- a/vllm/worker/cpu_model_runner.py
+++ b/vllm/worker/cpu_model_runner.py
@@ -164,6 +164,7 @@ def prepare(self,
         self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
         self.input_data = ModelInputForCPUBuilder.ModelInputData(
             self.runner.model_config.uses_mrope)
+        self.att_metadata_builder.prepare()
 
     def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
         self.seq_group_metadata_list.append(seq_group_metadata)

diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index d5a9bbc79bb45..5825ad21e572b 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -484,6 +484,8 @@ def prepare(self,
         self.inter_data_list: List[
             ModelInputForGPUBuilder.InterDataForSeqGroup] = []
 
+        self.attn_metadata_builder.prepare()
+
     def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int,
                       seq_group_metadata: SequenceGroupMetadata):
         """Compute context length, sequence length and tokens
From 06db992fb89247330c9ad358a579c8f58d058259 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Tue, 21 Jan 2025 18:09:22 +0800
Subject: [PATCH 07/14] fix for multi-step

Signed-off-by: youkaichao
---
 vllm/worker/cpu_model_runner.py | 3 ++-
 vllm/worker/model_runner.py     | 3 ++-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py
index 9f45387b55e19..7348096f90b22 100644
--- a/vllm/worker/cpu_model_runner.py
+++ b/vllm/worker/cpu_model_runner.py
@@ -483,7 +483,8 @@ def __init__(
         # Set after load_model.
         self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None
 
-        self.builder = self._builder_cls(weakref.proxy(self))
+        if hasattr(self, "_builder_cls"):
+            self.builder = self._builder_cls(weakref.proxy(self))
 
     def load_model(self) -> None:
         self.model = get_model(vllm_config=self.vllm_config)

diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 5825ad21e572b..ca4bba1b9b4f8 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -1101,7 +1101,8 @@ def __init__(
             SamplingMetadataCache() \
             if self.parallel_config.pipeline_parallel_size == 1 else None
 
-        self.builder = self._builder_cls(weakref.proxy(self))
+        if hasattr(self, "_builder_cls"):
+            self.builder = self._builder_cls(weakref.proxy(self))
 
     def load_model(self) -> None:
         logger.info("Starting to load model %s...", self.model_config.model)

From 2463a795c35d6b4716f312642fe5c67b922e9a49 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Wed, 22 Jan 2025 11:52:51 +0800
Subject: [PATCH 08/14] comment the case of _builder_cls

Signed-off-by: youkaichao
---
 vllm/worker/cpu_model_runner.py | 1 +
 vllm/worker/model_runner.py     | 1 +
 2 files changed, 2 insertions(+)

diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py
index 7348096f90b22..b0ad304d4c600 100644
--- a/vllm/worker/cpu_model_runner.py
+++ b/vllm/worker/cpu_model_runner.py
@@ -484,6 +484,7 @@ def __init__(
         self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None
 
         if hasattr(self, "_builder_cls"):
+            # multi-step model runner does not have `_builder_cls`
            self.builder = self._builder_cls(weakref.proxy(self))
 
     def load_model(self) -> None:

diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index ca4bba1b9b4f8..5f4509ffbc3ee 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -1102,6 +1102,7 @@ def __init__(
             if self.parallel_config.pipeline_parallel_size == 1 else None
 
         if hasattr(self, "_builder_cls"):
+            # multi-step model runner does not have `_builder_cls`
            self.builder = self._builder_cls(weakref.proxy(self))
 
     def load_model(self) -> None:
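Patches 07 and 08 guard the one-time construction because some subclasses never
define _builder_cls. A sketch of the guard, with an illustrative class name:

import weakref


class SketchRunnerBase:
    # Subclasses normally set _builder_cls as a class attribute; some
    # (e.g. the multi-step model runner) intentionally leave it unset,
    # so construction is guarded with hasattr() rather than assumed.

    def __init__(self) -> None:
        if hasattr(self, "_builder_cls"):
            self.builder = self._builder_cls(weakref.proxy(self))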
From 72488c2948bcc0e5f4bd519153b084058883775c Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Wed, 22 Jan 2025 12:01:20 +0800
Subject: [PATCH 09/14] comment the case of spec decode

Signed-off-by: youkaichao
---
 vllm/attention/backends/abstract.py | 5 -----
 vllm/worker/cpu_model_runner.py     | 6 ++++--
 vllm/worker/model_runner.py         | 6 ++++--
 3 files changed, 8 insertions(+), 9 deletions(-)

diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py
index 01bd01f8fd96f..2efe142a17b69 100644
--- a/vllm/attention/backends/abstract.py
+++ b/vllm/attention/backends/abstract.py
@@ -65,11 +65,6 @@ def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
     def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
         raise NotImplementedError
 
-    @classmethod
-    def make_metadata_builder(cls, *args,
-                              **kwargs) -> "AttentionMetadataBuilder":
-        return cls.get_builder_cls()(*args, **kwargs)
-
     @staticmethod
     @abstractmethod
     def get_kv_cache_shape(

diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py
index b0ad304d4c600..8f41b7313efa3 100644
--- a/vllm/worker/cpu_model_runner.py
+++ b/vllm/worker/cpu_model_runner.py
@@ -154,8 +154,10 @@ def __init__(self,
         self.device = self.runner.device
         self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper
         self.enable_lora = self.runner.lora_config is not None
-        self.att_metadata_builder = self.runner.attn_backend.get_builder_cls()(
-            self)
+        if self.runner.attn_backend is not None:
+            # spec decode (e.g. Medusa) does not have an attention backend
+            attn_backend = self.runner.attn_backend
+            self.att_metadata_builder = attn_backend.get_builder_cls()(self)
 
         self.prepare(finished_requests_ids)
 
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 5f4509ffbc3ee..512f0fae5b08c 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -460,8 +460,10 @@ def __init__(self,
         self.decode_only = True
 
         # Attention metadata inputs.
-        self.attn_metadata_builder = self.attn_backend.make_metadata_builder(
-            weakref.proxy(self))
+        if self.attn_backend is not None:
+            # spec decode (e.g. Medusa) does not have an attention backend
+            self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
+                weakref.proxy(self))
 
         # Engine/Model configurations.

From 0689affd43cb5e89a544af3d2d96775782324e51 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Wed, 22 Jan 2025 12:04:53 +0800
Subject: [PATCH 10/14] fix init

Signed-off-by: youkaichao
---
 vllm/worker/cpu_model_runner.py | 2 --
 vllm/worker/model_runner.py     | 2 --
 2 files changed, 4 deletions(-)

diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py
index 8f41b7313efa3..4b429b67b36f8 100644
--- a/vllm/worker/cpu_model_runner.py
+++ b/vllm/worker/cpu_model_runner.py
@@ -159,8 +159,6 @@ def __init__(self,
             attn_backend = self.runner.attn_backend
             self.att_metadata_builder = attn_backend.get_builder_cls()(self)
 
-        self.prepare(finished_requests_ids)
-
     def prepare(self,
                 finished_requests_ids: Optional[List[str]] = None) -> None:
         self.seq_group_metadata_list: List[SequenceGroupMetadata] = []

diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 512f0fae5b08c..e311c14111d49 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -475,8 +475,6 @@ def __init__(self,
         self.block_aligned_sliding_window = \
             self.sliding_window_blocks * self.block_size
 
-        self.prepare(finished_requests_ids)
-
     def prepare(self,
                 finished_requests_ids: Optional[List[str]] = None) -> None:
         self.finished_requests_ids = finished_requests_ids
From acc0dc8640eae59a07b1ebbb8e86ba5d83eeb759 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Wed, 22 Jan 2025 12:23:26 +0800
Subject: [PATCH 11/14] fix xpu

Signed-off-by: youkaichao
---
 vllm/worker/xpu_model_runner.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py
index 25a2fea1e8eac..e7a7e4ef30980 100644
--- a/vllm/worker/xpu_model_runner.py
+++ b/vllm/worker/xpu_model_runner.py
@@ -121,6 +121,9 @@ def __init__(self,
         self.block_size = self.runner.block_size
         self.device = self.runner.device
 
+    def prepare(self) -> None:
+        pass
+
     def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
         self.seq_group_metadata_list.append(seq_group_metadata)

From ce49aa9231aff54a079585a167df5302d2f95feb Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Wed, 22 Jan 2025 12:27:30 +0800
Subject: [PATCH 12/14] fix xpu

Signed-off-by: youkaichao
---
 vllm/worker/xpu_model_runner.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py
index e7a7e4ef30980..b0235b7849e6f 100644
--- a/vllm/worker/xpu_model_runner.py
+++ b/vllm/worker/xpu_model_runner.py
@@ -121,7 +121,8 @@ def __init__(self,
         self.block_size = self.runner.block_size
         self.device = self.runner.device
 
-    def prepare(self) -> None:
+    def prepare(self,
+                finished_requests_ids: Optional[List[str]] = None) -> None:
         pass
 
     def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
From a677f83a5781b4f0c240f8f188a9da98b8e6db29 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Wed, 22 Jan 2025 12:29:50 +0800
Subject: [PATCH 13/14] fix xpu

Signed-off-by: youkaichao
---
 vllm/worker/xpu_model_runner.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py
index b0235b7849e6f..d1b35f8b027e1 100644
--- a/vllm/worker/xpu_model_runner.py
+++ b/vllm/worker/xpu_model_runner.py
@@ -412,6 +412,8 @@ def __init__(
             SamplingMetadataCache() \
             if self.parallel_config.pipeline_parallel_size == 1 else None
 
+        self.builder = self._builder_cls(weakref.proxy(self))
+
     def load_model(self) -> None:
         with DeviceMemoryProfiler() as m:
             self.model = get_model(vllm_config=self.vllm_config)
@@ -521,7 +523,8 @@ def _prepare_model_input_tensors(
             metadata for possible additional steps, e.g., sampling.
 
         """
-        builder = self._builder_cls(weakref.proxy(self), finished_requests_ids)
+        builder = self.builder
+        builder.prepare(finished_requests_ids)
 
         for seq_group_metadata in seq_group_metadata_list:
             builder.add_seq_group(seq_group_metadata)

From 86335897b093045bbf624d66de95951f7848f832 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Wed, 22 Jan 2025 14:02:46 +0800
Subject: [PATCH 14/14] fix xpu

Signed-off-by: youkaichao
---
 vllm/worker/xpu_model_runner.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py
index d1b35f8b027e1..053658d047311 100644
--- a/vllm/worker/xpu_model_runner.py
+++ b/vllm/worker/xpu_model_runner.py
@@ -113,7 +113,6 @@ def __init__(self,
                  runner: "XPUModelRunner",
                  finished_requests_ids: Optional[List[str]] = None) -> None:
         super().__init__()
-        self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
         self.runner = runner
         self.model_input_cls = self.runner._model_input_cls
         self.attn_backend = self.runner.attn_backend
@@ -122,7 +122,7 @@ def __init__(self,
 
     def prepare(self,
                 finished_requests_ids: Optional[List[str]] = None) -> None:
-        pass
+        self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
 
     def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
         self.seq_group_metadata_list.append(seq_group_metadata)
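After patch 14 the CPU, GPU and XPU runners share one per-batch flow. A
condensed sketch of that end state follows; this is a free function with
simplified names for illustration, not the actual method.

from typing import Any, List, Optional


def prepare_model_input_tensors(
        runner: Any,
        seq_group_metadata_list: List[Any],
        finished_requests_ids: Optional[List[str]] = None):
    # The builder is a long-lived attribute of the runner; only its
    # per-batch state is reset here, which also cascades into the
    # attention metadata builder's prepare().
    runner.builder.prepare(finished_requests_ids)
    for seq_group_metadata in seq_group_metadata_list:
        runner.builder.add_seq_group(seq_group_metadata)
    return runner.builder.build()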