diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py
index abbf6450ab7f6..c0045e647f104 100644
--- a/vllm/worker/cpu_model_runner.py
+++ b/vllm/worker/cpu_model_runner.py
@@ -140,13 +140,9 @@ def __init__(self, use_mrope: bool):
             self.input_mrope_positions: List[List[int]] = [[]
                                                            for _ in range(3)]
 
-    def __init__(self,
-                 runner: "CPUModelRunner",
-                 finished_requests_ids: Optional[List[str]] = None) -> None:
+    def __init__(self, runner: "CPUModelRunner") -> None:
         super().__init__()
-        self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
         self.runner = runner
-
         self.chunked_prefill = (runner.scheduler_config.chunked_prefill_enabled
                                 or runner.cache_config.enable_prefix_caching)
         self.model_input_cls = self.runner._model_input_cls
@@ -156,11 +152,15 @@ def __init__(self,
         self.device = self.runner.device
         self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper
         self.enable_lora = self.runner.lora_config is not None
-        self.input_data = ModelInputForCPUBuilder.ModelInputData(
-            self.runner.model_config.uses_mrope)
         self.att_metadata_builder = self.runner.attn_backend.get_builder_cls()(
             self)
 
+    def prepare(self,
+                finished_requests_ids: Optional[List[str]] = None) -> None:
+        self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
+        self.input_data = ModelInputForCPUBuilder.ModelInputData(
+            self.runner.model_config.uses_mrope)
+
     def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
         self.seq_group_metadata_list.append(seq_group_metadata)
 
@@ -431,6 +431,7 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
     """
     _model_input_cls: Type[TModelInputForCPU]
     _builder_cls: Type[ModelInputForCPUBuilder]
+    builder: ModelInputForCPUBuilder
 
     def __init__(
         self,
@@ -477,6 +478,8 @@ def __init__(
         # Set after load_model.
         self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None
 
+        self.builder = self._builder_cls(weakref.proxy(self))
+
     def load_model(self) -> None:
         self.model = get_model(vllm_config=self.vllm_config)
 
@@ -522,10 +525,10 @@ def _prepare_model_input_tensors(
             metadata for possible additional steps, e.g., sampling.
 
         """
-        builder = self._builder_cls(weakref.proxy(self), finished_requests_ids)
-        builder.set_seq_group_list(seq_group_metadata_list)
+        self.builder.prepare(finished_requests_ids)
+        self.builder.set_seq_group_list(seq_group_metadata_list)
 
-        return builder.build()  # type: ignore
+        return self.builder.build()  # type: ignore
 
     # sampler property will be used by spec_decode_worker
     @property
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index cb2ff0c934da3..1ccc68e7c33b2 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -457,14 +457,8 @@ def __init__(self,
         self.enable_prompt_adapter = (self.runner.prompt_adapter_config
                                       is not None)
         self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper
-        self.finished_requests_ids = finished_requests_ids
         self.decode_only = True
 
-        # Intermediate data (data in CPU before going to GPU) for
-        # the current sequence group.
-        self.inter_data_list: List[
-            ModelInputForGPUBuilder.InterDataForSeqGroup] = []
-
         # Attention metadata inputs.
         self.attn_metadata_builder = self.attn_backend.make_metadata_builder(
             weakref.proxy(self))
@@ -479,6 +473,15 @@ def __init__(self,
             self.block_aligned_sliding_window = \
                 self.sliding_window_blocks * self.block_size
 
+    def prepare(self,
+                finished_requests_ids: Optional[List[str]] = None) -> None:
+        self.finished_requests_ids = finished_requests_ids
+
+        # Intermediate data (data in CPU before going to GPU) for
+        # the current sequence group.
+        self.inter_data_list: List[
+            ModelInputForGPUBuilder.InterDataForSeqGroup] = []
+
     def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int,
                       seq_group_metadata: SequenceGroupMetadata):
         """Compute context length, sequence length and tokens
@@ -993,6 +996,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
     """
     _model_input_cls: Type[TModelInputForGPU]
    _builder_cls: Type[ModelInputForGPUBuilder]
+    builder: ModelInputForGPUBuilder
 
     def __init__(
         self,
@@ -1093,6 +1097,8 @@ def __init__(
             SamplingMetadataCache() \
                 if self.parallel_config.pipeline_parallel_size == 1 else None
 
+        self.builder = self._builder_cls(weakref.proxy(self))
+
     def load_model(self) -> None:
         logger.info("Starting to load model %s...", self.model_config.model)
         with DeviceMemoryProfiler() as m:
@@ -1226,13 +1232,13 @@ def _prepare_model_input_tensors(
 
             If cuda graph is required, this API automatically pads inputs.
         """
-        builder = self._builder_cls(weakref.proxy(self), finished_requests_ids)
+        self.builder.prepare(finished_requests_ids)
         for seq_group_metadata in seq_group_metadata_list:
-            builder.add_seq_group(seq_group_metadata)
+            self.builder.add_seq_group(seq_group_metadata)
 
-        builder.reset_cached_inter_data()
+        self.builder.reset_cached_inter_data()
 
-        return builder.build()  # type: ignore
+        return self.builder.build()  # type: ignore
 
     @contextmanager
     def set_in_profile_run(self):
diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py
index acfd6d0b03f62..aef4bdcdd4bf9 100644
--- a/vllm/worker/model_runner_base.py
+++ b/vllm/worker/model_runner_base.py
@@ -200,6 +200,11 @@ class ModelRunnerInputBuilderBase(ABC, Generic[T]):
    """A builder to create ModelRunnerInputBase objects.
    """
 
+    @abstractmethod
+    def prepare(self,
+                finished_requests_ids: Optional[List[str]] = None) -> None:
+        raise NotImplementedError
+
     @abstractmethod
     def add_seq_group(self, seq_group_metadata):
         """TBA"""
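For context, the pattern this diff applies to both runners: the input builder is constructed once in the runner's __init__ (holding only a weakref.proxy back to the runner, so the pair does not form a strong reference cycle), and the per-step state that used to be initialized in the builder's constructor moves into a new prepare() hook that _prepare_model_input_tensors calls before each batch. The following is a minimal self-contained sketch of that lifecycle; the names Runner, InputBuilder, and prepare_inputs are illustrative stand-ins, not vLLM's actual classes.

# Minimal sketch of the builder-reuse lifecycle introduced above.
# All names here (Runner, InputBuilder, prepare_inputs) are illustrative
# stand-ins, not vLLM's actual APIs.
import weakref
from typing import List, Optional


class InputBuilder:

    def __init__(self, runner: "Runner") -> None:
        # One-time setup: keep only a weak proxy to the runner so the
        # runner <-> builder pair does not form a strong reference cycle.
        self.runner = runner

    def prepare(self,
                finished_requests_ids: Optional[List[str]] = None) -> None:
        # Per-step setup: reset exactly the state that changes per batch,
        # instead of paying for a full object construction every step.
        self.finished_requests_ids = finished_requests_ids
        self.seq_group_metadata_list: List[str] = []

    def add_seq_group(self, seq_group_metadata: str) -> None:
        self.seq_group_metadata_list.append(seq_group_metadata)

    def build(self) -> List[str]:
        return list(self.seq_group_metadata_list)


class Runner:

    def __init__(self) -> None:
        # Constructed once, reused for every call to prepare_inputs().
        self.builder = InputBuilder(weakref.proxy(self))

    def prepare_inputs(self, seq_groups: List[str]) -> List[str]:
        self.builder.prepare(finished_requests_ids=None)
        for sg in seq_groups:
            self.builder.add_seq_group(sg)
        return self.builder.build()


runner = Runner()
assert runner.prepare_inputs(["req-0", "req-1"]) == ["req-0", "req-1"]
assert runner.prepare_inputs(["req-2"]) == ["req-2"]  # stale state was reset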