[core] separate builder init and builder prepare for each batch #12253

Merged · 15 commits · Jan 22, 2025
vllm/attention/backends/abstract.py (6 additions, 5 deletions)

@@ -65,11 +65,6 @@ def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
     def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
         raise NotImplementedError

-    @classmethod
-    def make_metadata_builder(cls, *args,
-                              **kwargs) -> "AttentionMetadataBuilder":
-        return cls.get_builder_cls()(*args, **kwargs)
-
     @staticmethod
     @abstractmethod
     def get_kv_cache_shape(
@@ -214,6 +209,12 @@ class AttentionMetadataBuilder(ABC, Generic[T]):

     @abstractmethod
     def __init__(self, input_builder: "ModelRunnerInputBuilderBase") -> None:
+        """Create the builder, remember some configuration and parameters."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def prepare(self) -> None:
+        """Prepare for one batch."""
         raise NotImplementedError

     @abstractmethod
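The new contract in a nutshell, as a minimal sketch (the class and attribute names below are hypothetical, not part of the patch): `__init__` runs once per model runner and caches static configuration, while `prepare()` resets per-batch state so a single builder instance can be reused across batches.

```python
from typing import List


class MyMetadataBuilder:
    """Hypothetical builder following the init/prepare split above."""

    def __init__(self, input_builder) -> None:
        # One-time setup: remember configuration and parameters.
        self.input_builder = input_builder
        self.block_size = input_builder.block_size

    def prepare(self) -> None:
        # Per-batch setup: reset every piece of mutable state so the
        # same instance can be reused for the next batch.
        self.slot_mapping: List[int] = []
        self.num_prefill_tokens = 0
        self.num_decode_tokens = 0
```

Each backend below applies exactly this split: configuration assignments move up into `__init__`, per-batch lists and counters move into `prepare()`.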
vllm/attention/backends/flash_attn.py (6 additions, 5 deletions)

@@ -375,6 +375,12 @@ class FlashAttentionMetadataBuilder(
         AttentionMetadataBuilder[FlashAttentionMetadata]):

     def __init__(self, input_builder: "ModelInputForGPUBuilder"):
+        self.input_builder = input_builder
+        self.runner = input_builder.runner
+        self.sliding_window = input_builder.sliding_window
+        self.block_size = input_builder.block_size
+
+    def prepare(self):
         self.slot_mapping: List[int] = []
         self.prefill_seq_lens: List[int] = []
         self.context_lens: List[int] = []
@@ -388,11 +394,6 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"):
         self.num_decode_tokens = 0
         self.has_prefix_cache_hit = False

-        self.input_builder = input_builder
-        self.runner = input_builder.runner
-        self.sliding_window = input_builder.sliding_window
-        self.block_size = input_builder.block_size
-
     def _add_seq_group(
             self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
             chunked_prefill_enabled: bool, prefix_cache_hit: bool):
vllm/attention/backends/flashinfer.py (8 additions, 6 deletions)

@@ -488,6 +488,14 @@ def advance_step(self,
 class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):

     def __init__(self, input_builder: "ModelInputForGPUBuilder"):
+
+        self.input_builder = input_builder
+        self.runner = input_builder.runner
+
+        self.sliding_window = input_builder.sliding_window
+        self.block_size = input_builder.block_size
+
+    def prepare(self):
         self.slot_mapping: List[int] = []
         self.prefill_seq_lens: List[int] = []
         self.context_lens: List[int] = []

Review thread on this hunk:

Contributor: We probably need sm_scale too, but can go with the Flashinfer PR.

Member (author): yep, this PR just makes it easier to add more values in the builder. we can add them in the flashinfer pr.

@@ -500,12 +508,6 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"):
         self.num_prefill_tokens = 0
         self.num_decode_tokens = 0

-        self.input_builder = input_builder
-        self.runner = input_builder.runner
-
-        self.sliding_window = input_builder.sliding_window
-        self.block_size = input_builder.block_size
-
         # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout
         # for the precise definition of the following fields.
         # An example:
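Following the review thread above, here is a hedged sketch of how a value like `sm_scale` could later be cached in `__init__` (the `sm_scale` attribute is hypothetical here and explicitly deferred to the FlashInfer PR). The point is that static values now have a natural home that is touched once per runner rather than once per batch.

```python
class FlashInferMetadataBuilderSketch:
    """Hypothetical variant; sm_scale is NOT part of this PR."""

    def __init__(self, input_builder) -> None:
        self.input_builder = input_builder
        self.runner = input_builder.runner
        self.sliding_window = input_builder.sliding_window
        self.block_size = input_builder.block_size
        # Hypothetical: a static value such as sm_scale would be cached
        # here, once per runner, instead of in prepare() once per batch.
        self.sm_scale = getattr(input_builder, "sm_scale", None)
```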
vllm/attention/backends/placeholder_attn.py (5 additions, 3 deletions)

@@ -253,6 +253,11 @@ class PlaceholderAttentionMetadataBuilder(
         AttentionMetadataBuilder[PlaceholderAttentionMetadata]):

     def __init__(self, input_builder: "ModelInputForGPUBuilder"):
+
+        self.input_builder = input_builder
+        self.runner = input_builder.runner
+
+    def prepare(self):
         self.prefill_seq_lens: List[int] = []
         self.context_lens: List[int] = []
         self.curr_seq_lens: List[int] = []
@@ -263,9 +268,6 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"):
         self.num_prefill_tokens = 0
         self.num_decode_tokens = 0

-        self.input_builder = input_builder
-        self.runner = input_builder.runner
-
     def _add_seq_group(
             self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
             chunked_prefill_enabled: bool):
vllm/attention/backends/torch_sdpa.py (4 additions, 1 deletion)

@@ -282,7 +282,10 @@ class TorchSDPAMetadataBuilder(AttentionMetadataBuilder[TorchSDPAMetadata]):

     def __init__(self, input_builder: ModelInputForCPUBuilder) -> None:
         self.chunked_prefill = input_builder.chunked_prefill
-        self.input_data = input_builder.input_data
+        self.input_builder = input_builder
+
+    def prepare(self):
+        self.input_data = self.input_builder.input_data

     def build(self, seq_lens: List[int], query_lens: List[int],
               cuda_graph_pad_size: int, batch_size: int) -> TorchSDPAMetadata:
vllm/attention/backends/utils.py (7 additions, 6 deletions)

@@ -122,6 +122,13 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
     _metadata_cls: Type[TAttentionMetadata]

     def __init__(self, input_builder: "ModelInputForGPUBuilder"):
+        self.input_builder = input_builder
+        self.runner = input_builder.runner
+
+        self.sliding_window = input_builder.sliding_window
+        self.block_size = input_builder.block_size
+
+    def prepare(self):
         self.slot_mapping: List[int] = []
         self.prefill_seq_lens: List[int] = []
         self.context_lens: List[int] = []
@@ -134,12 +141,6 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"):
         self.num_prefill_tokens = 0
         self.num_decode_tokens = 0

-        self.input_builder = input_builder
-        self.runner = input_builder.runner
-
-        self.sliding_window = input_builder.sliding_window
-        self.block_size = input_builder.block_size
-
     def _add_seq_group(
             self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
             chunked_prefill_enabled: bool):
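The reason every per-batch list must move into `prepare()`: the builder instance is now long-lived, so any state not reset there would leak into the next batch. A small self-contained illustration, reusing the hypothetical `MyMetadataBuilder` sketched earlier:

```python
class _StubInputBuilder:
    """Hypothetical stand-in for ModelInputForGPUBuilder."""
    block_size = 16


builder = MyMetadataBuilder(_StubInputBuilder())

builder.prepare()
builder.slot_mapping.extend([0, 1, 2])    # batch 1
assert builder.slot_mapping == [0, 1, 2]

builder.prepare()                          # reset before batch 2
builder.slot_mapping.extend([3, 4])
assert builder.slot_mapping == [3, 4]      # no stale entries from batch 1
```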
vllm/worker/cpu_model_runner.py (17 additions, 7 deletions)

@@ -144,9 +144,7 @@ def __init__(self,
                  runner: "CPUModelRunner",
                  finished_requests_ids: Optional[List[str]] = None) -> None:
         super().__init__()
-        self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
         self.runner = runner
-
         self.chunked_prefill = (runner.scheduler_config.chunked_prefill_enabled
                                 or runner.cache_config.enable_prefix_caching)
         self.model_input_cls = self.runner._model_input_cls
@@ -156,10 +154,17 @@ def __init__(
         self.device = self.runner.device
         self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper
         self.enable_lora = self.runner.lora_config is not None
+        if self.runner.attn_backend is not None:
+            # spec decode (e.g. Medusa) does not have atten backend
+            attn_backend = self.runner.attn_backend
+            self.att_metadata_builder = attn_backend.get_builder_cls()(self)
+
+    def prepare(self,
+                finished_requests_ids: Optional[List[str]] = None) -> None:
+        self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
         self.input_data = ModelInputForCPUBuilder.ModelInputData(
             self.runner.model_config.uses_mrope)
-        self.att_metadata_builder = self.runner.attn_backend.get_builder_cls()(
-            self)
+        self.att_metadata_builder.prepare()

     def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
         self.seq_group_metadata_list.append(seq_group_metadata)
@@ -431,6 +436,7 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
     """
     _model_input_cls: Type[TModelInputForCPU]
     _builder_cls: Type[ModelInputForCPUBuilder]
+    builder: ModelInputForCPUBuilder

     def __init__(
         self,
@@ -477,6 +483,10 @@ def __init__(
         # Set after load_model.
         self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None

+        if hasattr(self, "_builder_cls"):
+            # multi-step model runner does not have `_builder_cls`
+            self.builder = self._builder_cls(weakref.proxy(self))
+
     def load_model(self) -> None:
         self.model = get_model(vllm_config=self.vllm_config)

@@ -522,10 +532,10 @@ def _prepare_model_input_tensors(
         metadata for possible additional steps, e.g., sampling.

         """
-        builder = self._builder_cls(weakref.proxy(self), finished_requests_ids)
-        builder.set_seq_group_list(seq_group_metadata_list)
+        self.builder.prepare(finished_requests_ids)
+        self.builder.set_seq_group_list(seq_group_metadata_list)

-        return builder.build()  # type: ignore
+        return self.builder.build()  # type: ignore

     # sampler property will be used by spec_decode_worker
     @property
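Note the `weakref.proxy(self)` in the runner's `__init__`: now that the builder lives for the runner's whole lifetime as an attribute, a strong back-reference would create a runner/builder reference cycle. A minimal sketch of the pattern (class names hypothetical):

```python
import weakref


class Builder:
    def __init__(self, runner) -> None:
        # A weak proxy behaves like a normal reference, but it does
        # not keep the runner alive on its own.
        self.runner = runner


class Runner:
    def __init__(self) -> None:
        # runner -> builder is a strong reference; builder -> runner is
        # weak, so no strong reference cycle is formed.
        self.builder = Builder(weakref.proxy(self))


runner = Runner()
assert runner.builder.runner.builder is runner.builder
```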
vllm/worker/model_runner.py (24 additions, 12 deletions)

@@ -457,17 +457,13 @@ def __init__(self,
         self.enable_prompt_adapter = (self.runner.prompt_adapter_config
                                       is not None)
         self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper
-        self.finished_requests_ids = finished_requests_ids
         self.decode_only = True

-        # Intermediate data (data in CPU before going to GPU) for
-        # the current sequence group.
-        self.inter_data_list: List[
-            ModelInputForGPUBuilder.InterDataForSeqGroup] = []

-        # Attention metadata inputs.
-        self.attn_metadata_builder = self.attn_backend.make_metadata_builder(
-            weakref.proxy(self))
+        if self.attn_backend is not None:
+            # spec decode (e.g. Medusa) does not have atten backend
+            self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
+                weakref.proxy(self))

         # Engine/Model configurations.
         self.chunked_prefill_enabled = (
@@ -479,6 +475,17 @@ def __init__(self,
         self.block_aligned_sliding_window = \
             self.sliding_window_blocks * self.block_size

+    def prepare(self,
+                finished_requests_ids: Optional[List[str]] = None) -> None:
+        self.finished_requests_ids = finished_requests_ids
+
+        # Intermediate data (data in CPU before going to GPU) for
+        # the current sequence group.
+        self.inter_data_list: List[
+            ModelInputForGPUBuilder.InterDataForSeqGroup] = []
+
+        self.attn_metadata_builder.prepare()
+
     def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int,
                       seq_group_metadata: SequenceGroupMetadata):
         """Compute context length, sequence length and tokens
@@ -993,6 +1000,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
     """
     _model_input_cls: Type[TModelInputForGPU]
     _builder_cls: Type[ModelInputForGPUBuilder]
+    builder: ModelInputForGPUBuilder

     def __init__(
         self,
@@ -1093,6 +1101,10 @@ def __init__(
             SamplingMetadataCache() \
             if self.parallel_config.pipeline_parallel_size == 1 else None

+        if hasattr(self, "_builder_cls"):
+            # multi-step model runner does not have `_builder_cls`
+            self.builder = self._builder_cls(weakref.proxy(self))
+
     def load_model(self) -> None:
         logger.info("Starting to load model %s...", self.model_config.model)
         with DeviceMemoryProfiler() as m:
@@ -1226,13 +1238,13 @@ def _prepare_model_input_tensors(

         If cuda graph is required, this API automatically pads inputs.
         """
-        builder = self._builder_cls(weakref.proxy(self), finished_requests_ids)
+        self.builder.prepare(finished_requests_ids)
         for seq_group_metadata in seq_group_metadata_list:
-            builder.add_seq_group(seq_group_metadata)
+            self.builder.add_seq_group(seq_group_metadata)

-        builder.reset_cached_inter_data()
+        self.builder.reset_cached_inter_data()

-        return builder.build()  # type: ignore
+        return self.builder.build()  # type: ignore

     @contextmanager
     def set_in_profile_run(self):
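The effect on the GPU hot path: `_prepare_model_input_tensors` no longer constructs a fresh builder on every call, only resets one. An illustrative micro-benchmark of that pattern (not from the PR; the toy `Builder` below stands in for the real one, whose one-time setup does far more work, such as the config lookups seen above):

```python
import timeit


class Builder:
    def __init__(self):
        self.config = {"block_size": 16}  # stands in for one-time setup
        self.prepare()

    def prepare(self):
        self.slot_mapping = []            # per-batch reset


# Construct a new builder per "batch" vs. reuse one and call prepare().
fresh = timeit.timeit("Builder()", globals=globals(), number=100_000)
b = Builder()
reused = timeit.timeit("b.prepare()", globals=globals(), number=100_000)
print(f"fresh per batch: {fresh:.3f}s  reuse+prepare: {reused:.3f}s")
```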
vllm/worker/model_runner_base.py (5 additions, 0 deletions)

@@ -200,6 +200,11 @@ class ModelRunnerInputBuilderBase(ABC, Generic[T]):
     """A builder to create ModelRunnerInputBase objects.
     """

+    @abstractmethod
+    def prepare(self,
+                finished_requests_ids: Optional[List[str]] = None) -> None:
+        raise NotImplementedError
+
     @abstractmethod
     def add_seq_group(self, seq_group_metadata):
         """TBA"""
vllm/worker/xpu_model_runner.py (8 additions, 2 deletions)

@@ -113,14 +113,17 @@ def __init__(self,
                  runner: "XPUModelRunner",
                  finished_requests_ids: Optional[List[str]] = None) -> None:
         super().__init__()
-        self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
         self.runner = runner
         self.model_input_cls = self.runner._model_input_cls
         self.attn_backend = self.runner.attn_backend
         self.sliding_window = self.runner.sliding_window
         self.block_size = self.runner.block_size
         self.device = self.runner.device

+    def prepare(self,
+                finished_requests_ids: Optional[List[str]] = None) -> None:
+        self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
+
     def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
         self.seq_group_metadata_list.append(seq_group_metadata)

@@ -408,6 +411,8 @@ def __init__(
             SamplingMetadataCache() \
             if self.parallel_config.pipeline_parallel_size == 1 else None

+        self.builder = self._builder_cls(weakref.proxy(self))
+
     def load_model(self) -> None:
         with DeviceMemoryProfiler() as m:
             self.model = get_model(vllm_config=self.vllm_config)

@@ -517,7 +522,8 @@ def _prepare_model_input_tensors(
         metadata for possible additional steps, e.g., sampling.

         """
-        builder = self._builder_cls(weakref.proxy(self), finished_requests_ids)
+        builder = self.builder
+        builder.prepare(finished_requests_ids)
         for seq_group_metadata in seq_group_metadata_list:
             builder.add_seq_group(seq_group_metadata)