Skip to content

Commit

Permalink
separate builder input and builder prepare
Browse files Browse the repository at this point in the history
Signed-off-by: youkaichao <[email protected]>
  • Loading branch information
youkaichao committed Jan 21, 2025
1 parent 2fc6944 commit bab73d1
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 20 deletions.
23 changes: 13 additions & 10 deletions vllm/worker/cpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,13 +140,9 @@ def __init__(self, use_mrope: bool):
self.input_mrope_positions: List[List[int]] = [[]
for _ in range(3)]

def __init__(self,
runner: "CPUModelRunner",
finished_requests_ids: Optional[List[str]] = None) -> None:
def __init__(self, runner: "CPUModelRunner") -> None:
super().__init__()
self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
self.runner = runner

self.chunked_prefill = (runner.scheduler_config.chunked_prefill_enabled
or runner.cache_config.enable_prefix_caching)
self.model_input_cls = self.runner._model_input_cls
Expand All @@ -156,11 +152,15 @@ def __init__(self,
self.device = self.runner.device
self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper
self.enable_lora = self.runner.lora_config is not None
self.input_data = ModelInputForCPUBuilder.ModelInputData(
self.runner.model_config.uses_mrope)
self.att_metadata_builder = self.runner.attn_backend.get_builder_cls()(
self)

def prepare(self,
            finished_requests_ids: Optional[List[str]] = None) -> None:
    """Reset the builder's per-step state before building a new input.

    The builder instance is created once in the runner's ``__init__`` and
    reused across scheduler steps, so this must be called at the start of
    each step to clear state accumulated by the previous one.

    Args:
        finished_requests_ids: Ids of requests that finished since the
            last step. Not used by this CPU builder's body; the parameter
            exists for interface compatibility with the abstract
            ``ModelRunnerInputBuilderBase.prepare``.
    """
    # Fresh list of sequence groups for this step; populated via
    # add_seq_group() / set_seq_group_list() before build().
    self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
    # Fresh per-step input accumulator. uses_mrope selects whether
    # multimodal rotary-position (mrope) inputs are tracked.
    self.input_data = ModelInputForCPUBuilder.ModelInputData(
        self.runner.model_config.uses_mrope)

def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
self.seq_group_metadata_list.append(seq_group_metadata)

Expand Down Expand Up @@ -431,6 +431,7 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
"""
_model_input_cls: Type[TModelInputForCPU]
_builder_cls: Type[ModelInputForCPUBuilder]
builder: ModelInputForCPUBuilder

def __init__(
self,
Expand Down Expand Up @@ -477,6 +478,8 @@ def __init__(
# Set after load_model.
self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None

self.builder = self._builder_cls(weakref.proxy(self))

def load_model(self) -> None:
self.model = get_model(vllm_config=self.vllm_config)

Expand Down Expand Up @@ -522,10 +525,10 @@ def _prepare_model_input_tensors(
metadata for possible additional steps, e.g., sampling.
"""
builder = self._builder_cls(weakref.proxy(self), finished_requests_ids)
builder.set_seq_group_list(seq_group_metadata_list)
self.builder.prepare(finished_requests_ids)
self.builder.set_seq_group_list(seq_group_metadata_list)

return builder.build() # type: ignore
return self.builder.build() # type: ignore

# sampler property will be used by spec_decode_worker
@property
Expand Down
26 changes: 16 additions & 10 deletions vllm/worker/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,14 +457,8 @@ def __init__(self,
self.enable_prompt_adapter = (self.runner.prompt_adapter_config
is not None)
self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper
self.finished_requests_ids = finished_requests_ids
self.decode_only = True

# Intermediate data (data in CPU before going to GPU) for
# the current sequence group.
self.inter_data_list: List[
ModelInputForGPUBuilder.InterDataForSeqGroup] = []

# Attention metadata inputs.
self.attn_metadata_builder = self.attn_backend.make_metadata_builder(
weakref.proxy(self))
Expand All @@ -479,6 +473,15 @@ def __init__(self,
self.block_aligned_sliding_window = \
self.sliding_window_blocks * self.block_size

def prepare(self,
            finished_requests_ids: Optional[List[str]] = None) -> None:
    """Reset the builder's per-step state before building a new input.

    The builder instance is constructed once in the runner's ``__init__``
    and reused across scheduler steps, so this must run at the start of
    each step to discard the previous step's accumulated data.

    Args:
        finished_requests_ids: Ids of requests that finished since the
            last step; stored for use while building this step's input.
    """
    self.finished_requests_ids = finished_requests_ids

    # Intermediate data (data in CPU before going to GPU) for
    # the current sequence group.
    self.inter_data_list: List[
        ModelInputForGPUBuilder.InterDataForSeqGroup] = []

def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int,
seq_group_metadata: SequenceGroupMetadata):
"""Compute context length, sequence length and tokens
Expand Down Expand Up @@ -993,6 +996,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
"""
_model_input_cls: Type[TModelInputForGPU]
_builder_cls: Type[ModelInputForGPUBuilder]
builder: ModelInputForGPUBuilder

def __init__(
self,
Expand Down Expand Up @@ -1093,6 +1097,8 @@ def __init__(
SamplingMetadataCache() \
if self.parallel_config.pipeline_parallel_size == 1 else None

self.builder = self._builder_cls(weakref.proxy(self))

def load_model(self) -> None:
logger.info("Starting to load model %s...", self.model_config.model)
with DeviceMemoryProfiler() as m:
Expand Down Expand Up @@ -1226,13 +1232,13 @@ def _prepare_model_input_tensors(
If cuda graph is required, this API automatically pads inputs.
"""
builder = self._builder_cls(weakref.proxy(self), finished_requests_ids)
self.builder.prepare(finished_requests_ids)
for seq_group_metadata in seq_group_metadata_list:
builder.add_seq_group(seq_group_metadata)
self.builder.add_seq_group(seq_group_metadata)

builder.reset_cached_inter_data()
self.builder.reset_cached_inter_data()

return builder.build() # type: ignore
return self.builder.build() # type: ignore

@contextmanager
def set_in_profile_run(self):
Expand Down
5 changes: 5 additions & 0 deletions vllm/worker/model_runner_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,11 @@ class ModelRunnerInputBuilderBase(ABC, Generic[T]):
"""A builder to create ModelRunnerInputBase objects.
"""

@abstractmethod
def prepare(self,
            finished_requests_ids: Optional[List[str]] = None) -> None:
    """Reset per-step state on a reused builder instance.

    Concrete builders are constructed once per runner and reused across
    scheduler steps; implementations must clear any state accumulated by
    the previous step. Call this before ``add_seq_group()``/``build()``.

    Args:
        finished_requests_ids: Ids of requests that finished since the
            previous step, or None.
    """
    raise NotImplementedError

@abstractmethod
def add_seq_group(self, seq_group_metadata):
    """Add one sequence group's metadata to the pending batch for build()."""
Expand Down

0 comments on commit bab73d1

Please sign in to comment.