diff --git a/tests/worker/test_model_input.py b/tests/worker/test_model_input.py index ae818ee360f19..2126fafb2323b 100644 --- a/tests/worker/test_model_input.py +++ b/tests/worker/test_model_input.py @@ -3,7 +3,7 @@ import torch -from vllm.attention import AttentionMetadata +from vllm.attention import AttentionMetadata, AttentionMetadataBuilder from vllm.attention.backends.abstract import AttentionBackend from vllm.model_executor import SamplingMetadata from vllm.model_executor.pooling_metadata import PoolingMetadata @@ -26,6 +26,10 @@ def get_impl_cls(): def get_metadata_cls() -> Type["AttentionMetadata"]: return AttentionMetadata + @staticmethod + def get_builder_cls() -> Type["AttentionMetadataBuilder"]: + raise AttentionMetadataBuilder + @staticmethod def get_kv_cache_shape( num_blocks: int, diff --git a/vllm/attention/__init__.py b/vllm/attention/__init__.py index f6bce9a187c64..44bfae44cfddd 100644 --- a/vllm/attention/__init__.py +++ b/vllm/attention/__init__.py @@ -1,5 +1,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, - AttentionMetadata) + AttentionMetadata, + AttentionMetadataBuilder) from vllm.attention.layer import Attention from vllm.attention.selector import get_attn_backend @@ -7,6 +8,7 @@ "Attention", "AttentionBackend", "AttentionMetadata", + "AttentionMetadataBuilder", "Attention", "get_attn_backend", ] diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index adb8325168cdf..dbda1019b45e2 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -1,11 +1,15 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, fields from enum import Enum, auto -from typing import (Any, Dict, Generic, List, Optional, Set, Tuple, Type, - TypeVar) +from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set, + Tuple, Type, TypeVar) import torch +if TYPE_CHECKING: + from vllm.sequence import SequenceGroupMetadata + from vllm.worker.model_runner_base import ModelRunnerInputBuilderBase + class AttentionType(Enum): DECODER = auto() # Decoder attention between previous layer Q/K/V @@ -35,6 +39,16 @@ def get_metadata_cls() -> Type["AttentionMetadata"]: def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata": return cls.get_metadata_cls()(*args, **kwargs) + @staticmethod + @abstractmethod + def get_builder_cls() -> Type["AttentionMetadataBuilder"]: + raise NotImplementedError + + @classmethod + def make_metadata_builder(cls, *args, + **kwargs) -> "AttentionMetadataBuilder": + return cls.get_builder_cls()(*args, **kwargs) + @staticmethod @abstractmethod def get_kv_cache_shape( @@ -110,6 +124,33 @@ def asdict_zerocopy(self, T = TypeVar("T", bound=AttentionMetadata) +class AttentionMetadataBuilder(ABC, Generic[T]): + """Abstract class for attention metadata builders.""" + + @abstractmethod + def __init__(self, input_builder) -> None: + raise NotImplementedError + + @abstractmethod + def add_seq_group(self, seq_group_metadata: "SequenceGroupMetadata", + token_lens: List[int], seq_lens: List[int], + curr_seq_lens: List[int], query_lens: List[int], + context_lens: List[int], + curr_sliding_window_blocks: List[int], + prefix_cache_hit: bool, chunked_prefill_enabled: bool): + """Add a sequence group to the metadata and update + corresponding fields (in Python objects). + """ + raise NotImplementedError + + @abstractmethod + def build(self, runner: "ModelRunnerInputBuilderBase", seq_lens: List[int], + query_lens: List[int], cuda_graph_pad_size: int, + batch_size: int) -> T: + """Build attention metadata with on-device tensors.""" + raise NotImplementedError + + class AttentionImpl(ABC, Generic[T]): @abstractmethod diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py index fe4c4a45dca0d..8f03ac0cf49e5 100644 --- a/vllm/attention/backends/blocksparse_attn.py +++ b/vllm/attention/backends/blocksparse_attn.py @@ -5,6 +5,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) +from vllm.attention.backends.utils import CommonMetadataBuilder from vllm.attention.ops.blocksparse_attention.interface import ( LocalStridedBlockSparseAttn, get_head_sliding_step) from vllm.attention.ops.paged_attn import PagedAttention @@ -93,6 +94,10 @@ def get_impl_cls() -> Type["BlocksparseFlashAttentionImpl"]: def get_metadata_cls() -> Type["AttentionMetadata"]: return BlocksparseFlashAttentionMetadata + @staticmethod + def get_builder_cls() -> Type["BlocksparseFlashAttentionMetadataBuilder"]: + return BlocksparseFlashAttentionMetadataBuilder + @staticmethod def get_kv_cache_shape( num_blocks: int, @@ -244,6 +249,12 @@ def decode_metadata(self) -> Optional["BlocksparseFlashAttentionMetadata"]: return self._cached_decode_metadata +class BlocksparseFlashAttentionMetadataBuilder( + CommonMetadataBuilder[BlocksparseFlashAttentionMetadata]): + + _metadata_cls = BlocksparseFlashAttentionMetadata + + class BlocksparseFlashAttentionImpl(AttentionImpl): """ If the input tensors contain prompt tokens, the layout is as follows: diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 048abed48d2e9..3be6c8bc20de3 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -1,13 +1,24 @@ """Attention layer with FlashAttention.""" from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Type +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type import torch from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, AttentionType) + AttentionMetadata, + AttentionMetadataBuilder, + AttentionType) +from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping, + compute_slot_mapping_start_idx, + is_block_tables_empty) +from vllm.sequence import SequenceGroupMetadata +from vllm.utils import make_tensor_with_pad + +if TYPE_CHECKING: + from vllm.worker.model_runner import (GPUModelRunnerBase, + ModelInputForGPUBuilder) class FlashAttentionBackend(AttentionBackend): @@ -28,6 +39,10 @@ def get_impl_cls() -> Type["FlashAttentionImpl"]: def get_metadata_cls() -> Type["AttentionMetadata"]: return FlashAttentionMetadata + @staticmethod + def get_builder_cls() -> Type["FlashAttentionMetadataBuilder"]: + return FlashAttentionMetadataBuilder + @staticmethod def get_kv_cache_shape( num_blocks: int, @@ -184,6 +199,170 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]: return self._cached_decode_metadata +class FlashAttentionMetadataBuilder( + AttentionMetadataBuilder[FlashAttentionMetadata]): + + def __init__(self, input_builder: "ModelInputForGPUBuilder"): + self.slot_mapping: List[int] = [] + self.prefill_seq_lens: List[int] = [] + self.context_lens: List[int] = [] + self.block_tables: List[List[int]] = [] + self.curr_seq_lens: List[int] = [] + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.num_decode_tokens = 0 + + self.sliding_window = input_builder.sliding_window + self.block_size = input_builder.block_size + self.use_v2_block_manager = ( + input_builder.scheduler_config.use_v2_block_manager) + + def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata, + token_lens: List[int], seq_lens: List[int], + curr_seq_lens: List[int], query_lens: List[int], + context_lens: List[int], + curr_sliding_window_blocks: List[int], + prefix_cache_hit: bool, chunked_prefill_enabled: bool): + """Add a sequence group to the metadata. Specifically update/append + 1. context length. + 2. block table. + 3. slot mapping. + """ + is_prompt = seq_group_metadata.is_prompt + block_tables = seq_group_metadata.block_tables + + for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, + curr_sliding_window_block) in zip( + seq_group_metadata.seq_data.keys(), token_lens, seq_lens, + curr_seq_lens, query_lens, context_lens, + curr_sliding_window_blocks): + self.context_lens.append(context_len) + + if is_prompt: + self.num_prefills += 1 + self.num_prefill_tokens += token_len + self.prefill_seq_lens.append(seq_len) + else: + assert query_len == 1, ( + "seq_len: {}, context_len: {}, query_len: {}".format( + seq_len, context_len, query_len)) + self.num_decode_tokens += query_len + self.curr_seq_lens.append(curr_seq_len) + + # Compute block table. + # TODO(sang): Combine chunked prefill and prefix caching by + # only allowing multiple of block_size chunk size. + # NOTE: This only works for oooooooxxx style attention. + block_table = [] + if prefix_cache_hit: + # NOTE(woosuk): For flash-attn, the block table should + # include the entries for the incoming prefill tokens. + block_table = block_tables[seq_id] + elif ((chunked_prefill_enabled or not is_prompt) + and block_tables is not None): + block_table = block_tables[seq_id][-curr_sliding_window_block:] + self.block_tables.append(block_table) + + # Compute slot mapping. + is_profile_run = is_block_tables_empty(block_tables) + start_idx = compute_slot_mapping_start_idx( + is_prompt, query_len, context_len, self.sliding_window, + self.use_v2_block_manager) + compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, + seq_len, context_len, start_idx, + self.block_size, + seq_group_metadata.block_tables) + + def build(self, runner: "GPUModelRunnerBase", seq_lens, query_lens, + cuda_graph_pad_size: int, batch_size: int): + """Build attention metadata with on-device tensors.""" + device = runner.device + use_captured_graph = cuda_graph_pad_size != -1 + + logits_soft_cap = getattr(runner.model_config.hf_config, + "attn_logit_softcapping", None) + if logits_soft_cap is not None: + raise ValueError( + "Please use Flashinfer backend for models with logits_soft_cap" + " (i.e., Gemma-2). Otherwise, the output might be wrong." + " Set Flashinfer backend by " + "export VLLM_ATTENTION_BACKEND=FLASHINFER.") + + max_query_len = max(query_lens) + max_prefill_seq_len = max(self.prefill_seq_lens, default=0) + max_decode_seq_len = max(self.curr_seq_lens, default=0) + num_decode_tokens = self.num_decode_tokens + + if use_captured_graph: + self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) + self.block_tables.extend([] * cuda_graph_pad_size) + num_decode_tokens = batch_size + cuda_graph_pad_size + + # The shape of graph_block_tables is + # [max batch size, max context len // block size]. + input_block_tables = runner.graph_block_tables[:batch_size] + for i, block_table in enumerate(self.block_tables): + if block_table: + input_block_tables[i, :len(block_table)] = block_table + block_tables = torch.tensor(input_block_tables, device=device) + else: + max_block_table_len = max( + len(block_table) for block_table in self.block_tables) + block_tables = make_tensor_with_pad( + self.block_tables, + max_len=max_block_table_len, + pad=0, + dtype=torch.int, + device=device, + ) + assert max_query_len > 0, ("query_lens: {}".format(query_lens)) + + context_lens_tensor = torch.tensor(self.context_lens, + dtype=torch.int, + device=device) + seq_lens_tensor = torch.tensor(seq_lens, + dtype=torch.int, + device=device) + query_lens_tensor = torch.tensor(query_lens, + dtype=torch.long, + device=device) + query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=device) + seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=device) + torch.cumsum(seq_lens_tensor, + dim=0, + dtype=seq_start_loc.dtype, + out=seq_start_loc[1:]) + torch.cumsum(query_lens_tensor, + dim=0, + dtype=query_start_loc.dtype, + out=query_start_loc[1:]) + + slot_mapping_tensor = torch.tensor(self.slot_mapping, + dtype=torch.long, + device=device) + + return FlashAttentionMetadata( + num_prefills=self.num_prefills, + slot_mapping=slot_mapping_tensor, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_query_len=max_query_len, + max_prefill_seq_len=max_prefill_seq_len, + max_decode_seq_len=max_decode_seq_len, + query_start_loc=query_start_loc, + seq_start_loc=seq_start_loc, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + use_cuda_graph=use_captured_graph, + ) + + class FlashAttentionImpl(AttentionImpl): """ If the input tensors contain prompt tokens, the layout is as follows: diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index b27e3e40f566d..b7e568f78f818 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Set, Tuple, Type +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type try: from flashinfer import BatchDecodeWithPagedKVCacheWrapper @@ -14,7 +14,18 @@ from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, AttentionType) + AttentionMetadata, + AttentionMetadataBuilder, + AttentionType) +from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping, + compute_slot_mapping_start_idx, + is_block_tables_empty) +from vllm.sequence import SequenceGroupMetadata +from vllm.utils import get_kv_cache_torch_dtype, make_tensor_with_pad + +if TYPE_CHECKING: + from vllm.worker.model_runner import (GPUModelRunnerBase, + ModelInputForGPUBuilder) class FlashInferBackend(AttentionBackend): @@ -31,6 +42,10 @@ def get_impl_cls() -> Type["FlashInferImpl"]: def get_metadata_cls() -> Type["AttentionMetadata"]: return FlashInferMetadata + @staticmethod + def get_builder_cls() -> Type["FlashInferMetadataBuilder"]: + return FlashInferMetadataBuilder + @staticmethod def get_kv_cache_shape( num_blocks: int, @@ -188,6 +203,225 @@ def decode_metadata(self) -> Optional["FlashInferMetadata"]: return self +class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): + + def __init__(self, input_builder: "ModelInputForGPUBuilder"): + self.slot_mapping: List[int] = [] + self.prefill_seq_lens: List[int] = [] + self.context_lens: List[int] = [] + self.block_tables: List[List[int]] = [] + self.curr_seq_lens: List[int] = [] + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.num_decode_tokens = 0 + + self.sliding_window = input_builder.sliding_window + self.block_size = input_builder.block_size + self.use_v2_block_manager = ( + input_builder.scheduler_config.use_v2_block_manager) + + # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout + # for the precise definition of the following fields. + # An example: + # request 1, page indices [0, 5, 8] + # request 2, page indices [1, 6, 7] + # request 3, page indices [3, 4] + # paged_kv_indices is a concatenation of page indices of all requests: + # [0, 5, 8, 1, 6, 7, 3, 4] + # paged_kv_indptr is used to index into paged_kv_indices: + # [0, 3, 6, 8] + self.paged_kv_indices: List[int] = [] + # 0 at the beginning of paged_kv_indptr indicates the start of the + # first request’s page indices in the paged_kv_indices list. + self.paged_kv_indptr: List[int] = [0] + # paged_kv_last_page_len is the length of the last page of each request + self.paged_kv_last_page_len: List[int] = [] + + def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata, + token_lens: List[int], seq_lens: List[int], + curr_seq_lens: List[int], query_lens: List[int], + context_lens: List[int], + curr_sliding_window_blocks: List[int], + prefix_cache_hit: bool, chunked_prefill_enabled: bool): + """Add a sequence group to the metadata. Specifically update/append + 1. context length. + 2. block table. + 3. slot mapping. + """ + is_prompt = seq_group_metadata.is_prompt + block_tables = seq_group_metadata.block_tables + computed_block_nums = seq_group_metadata.computed_block_nums + + for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, + curr_sliding_window_block) in zip( + seq_group_metadata.seq_data.keys(), token_lens, seq_lens, + curr_seq_lens, query_lens, context_lens, + curr_sliding_window_blocks): + self.context_lens.append(context_len) + if is_prompt: + self.num_prefills += 1 + self.num_prefill_tokens += token_len + self.prefill_seq_lens.append(seq_len) + else: + assert query_len == 1, ( + "seq_len: {}, context_len: {}, query_len: {}".format( + seq_len, context_len, query_len)) + self.num_decode_tokens += query_len + self.curr_seq_lens.append(curr_seq_len) + + # Compute block table. + # TODO(sang): Combine chunked prefill and prefix caching by + # only allowing multiple of block_size chunk size. + # NOTE: This only works for oooooooxxx style attention. + block_table = [] + if prefix_cache_hit: + block_table = computed_block_nums + elif ((chunked_prefill_enabled or not is_prompt) + and block_tables is not None): + block_table = block_tables[seq_id][-curr_sliding_window_block:] + self.block_tables.append(block_table) + + is_profile_run = is_block_tables_empty(block_tables) + + # Compute slot mapping. + start_idx = compute_slot_mapping_start_idx( + is_prompt, query_len, context_len, self.sliding_window, + self.use_v2_block_manager) + compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, + seq_len, context_len, start_idx, + self.block_size, + seq_group_metadata.block_tables) + + # It is not necessary to add paged_kv_indices, paged_kv_indptr, + # and paged_kv_last_page_len for profile run because we will + # create dummy inputs. + if is_profile_run: + return + + # Get the number of valid blocks based on sequence length. + # If seq_len = 16, block_size = 16, + # block_table_bound is 1 with 1 valid block. + # If seq_len = 15, block_size = 16, + # block_table_bound is 0 + 1 with 1 valid block. + block_table_bound = seq_len // self.block_size + 1 \ + if seq_len % self.block_size != 0 \ + else seq_len // self.block_size + block_table = block_tables[seq_id] + self.paged_kv_indices.extend(block_table[:block_table_bound]) + self.paged_kv_indptr.append(self.paged_kv_indptr[-1] + + block_table_bound) + + last_page_len = seq_len % self.block_size + if last_page_len == 0: + last_page_len = self.block_size + self.paged_kv_last_page_len.append(last_page_len) + + def build(self, runner: "GPUModelRunnerBase", seq_lens, query_lens, + cuda_graph_pad_size: int, batch_size: int): + device = runner.device + use_captured_graph = cuda_graph_pad_size != -1 + + max_query_len = max(query_lens) + max_prefill_seq_len = max(self.prefill_seq_lens, default=0) + num_decode_tokens = self.num_decode_tokens + + if use_captured_graph: + self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) + self.block_tables.extend([] * cuda_graph_pad_size) + num_decode_tokens = batch_size + cuda_graph_pad_size + + # The shape of graph_block_tables is + # [max batch size, max context len // block size]. + input_block_tables = runner.graph_block_tables[:batch_size] + for i, block_table in enumerate(self.block_tables): + if block_table: + input_block_tables[i, :len(block_table)] = block_table + block_tables = torch.tensor(input_block_tables, device=device) + + last_paged_kv_indptr = self.paged_kv_indptr[-1] + self.paged_kv_indptr.extend([last_paged_kv_indptr] * + cuda_graph_pad_size) + self.paged_kv_last_page_len.extend([0] * cuda_graph_pad_size) + else: + max_block_table_len = max( + len(block_table) for block_table in self.block_tables) + block_tables = make_tensor_with_pad( + self.block_tables, + max_len=max_block_table_len, + pad=0, + dtype=torch.int, + device=device, + ) + assert max_query_len > 0, ("query_lens: {}".format(query_lens)) + + seq_lens_tensor = torch.tensor(seq_lens, + dtype=torch.int, + device=device) + query_lens_tensor = torch.tensor(query_lens, + dtype=torch.long, + device=device) + query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=device) + seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=device) + torch.cumsum(seq_lens_tensor, + dim=0, + dtype=seq_start_loc.dtype, + out=seq_start_loc[1:]) + torch.cumsum(query_lens_tensor, + dim=0, + dtype=query_start_loc.dtype, + out=query_start_loc[1:]) + + slot_mapping_tensor = torch.tensor(self.slot_mapping, + dtype=torch.long, + device=device) + + logits_soft_cap = getattr(runner.model_config.hf_config, + "attn_logit_softcapping", None) + + if len(self.paged_kv_indptr) > 0: + paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices, + device="cpu", + dtype=torch.int) + paged_kv_indptr_tensor = torch.tensor(self.paged_kv_indptr, + device="cpu", + dtype=torch.int) + paged_kv_last_page_len_tensor = torch.tensor( + self.paged_kv_last_page_len, device="cpu", dtype=torch.int) + else: + paged_kv_indices_tensor = None + paged_kv_indptr_tensor = None + paged_kv_last_page_len_tensor = None + + kv_cache_dtype = get_kv_cache_torch_dtype(runner.kv_cache_dtype, + runner.model_config.dtype) + return FlashInferMetadata( + num_prefills=self.num_prefills, + slot_mapping=slot_mapping_tensor, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + max_prefill_seq_len=max_prefill_seq_len, + block_tables=block_tables, + paged_kv_indptr=paged_kv_indptr_tensor, + paged_kv_indices=paged_kv_indices_tensor, + paged_kv_last_page_len=paged_kv_last_page_len_tensor, + num_qo_heads=runner.model_config.get_num_attention_heads( + runner.parallel_config), + num_kv_heads=runner.model_config.get_num_kv_heads( + runner.parallel_config), + head_dim=runner.model_config.get_head_size(), + page_size=self.block_size, + seq_start_loc=seq_start_loc, + query_start_loc=query_start_loc, + device=device, + data_type=kv_cache_dtype, + use_cuda_graph=use_captured_graph, + logits_soft_cap=logits_soft_cap) + + class FlashInferImpl(AttentionImpl): def __init__( diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 81b546c65c819..f7bf0051fdf45 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -7,6 +7,7 @@ import vllm.envs as envs from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) +from vllm.attention.backends.utils import CommonMetadataBuilder from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) from vllm.logger import init_logger @@ -28,6 +29,10 @@ def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]: def get_metadata_cls() -> Type["AttentionMetadata"]: return ROCmFlashAttentionMetadata + @staticmethod + def get_builder_cls() -> Type["ROCmFlashAttentionMetadataBuilder"]: + return ROCmFlashAttentionMetadataBuilder + @staticmethod def get_kv_cache_shape( num_blocks: int, @@ -166,6 +171,12 @@ def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: return self._cached_decode_metadata +class ROCmFlashAttentionMetadataBuilder( + CommonMetadataBuilder[ROCmFlashAttentionMetadata]): + + _metadata_cls = ROCmFlashAttentionMetadata + + def _make_alibi_bias(alibi_slopes: torch.Tensor, dtype: torch.dtype, seq_lens: Optional[List[int]], diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index a3cfc6e20748b..62d0eeb249bd4 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -1,7 +1,239 @@ """Attention backend utils""" +from typing import TYPE_CHECKING, Dict, List, Type, TypeVar, Union + +import torch + +from vllm.attention import AttentionMetadata, AttentionMetadataBuilder +from vllm.sequence import SequenceGroupMetadata +from vllm.utils import make_tensor_with_pad # Error string(s) for encoder/decoder # unsupported attention scenarios - STR_NOT_IMPL_ENC_DEC_ROCM_HIP = ("ROCm/HIP is not currently supported " "with encoder/decoder models.") + +PAD_SLOT_ID = -1 + +if TYPE_CHECKING: + from vllm.worker.model_runner import (GPUModelRunnerBase, + ModelInputForGPUBuilder) + + +def is_block_tables_empty(block_tables: Union[None, Dict]): + """ + Check if block_tables is None or a dictionary with all None values. + """ + if block_tables is None: + return True + if isinstance(block_tables, dict) and all( + value is None for value in block_tables.values()): + return True + return False + + +def compute_slot_mapping_start_idx(is_prompt: bool, query_len: int, + context_len: int, sliding_window: int, + use_v2_block_manager: bool): + """ + Compute the start index of slot mapping. + """ + start_idx = 0 + if is_prompt and sliding_window is not None: + assert use_v2_block_manager or context_len == 0, ( + "Prefix caching is currently not supported with " + "sliding window attention in V1 block manager") + # When prefill, we use it to not write slots to kv cache + # to save memory. + start_idx = max(0, query_len - sliding_window) + return start_idx + + +def compute_slot_mapping(is_profile_run: bool, slot_mapping: List[int], + seq_id: int, seq_len: int, context_len: int, + start_idx: int, block_size: int, + block_tables: Dict[int, List[int]]): + """ + Compute slot mapping. + """ + if is_profile_run: + # During memory profiling, the block tables are not + # initialized yet. In this case, we just use a dummy + # slot mapping. + # In embeddings, the block tables are {seq_id: None}. + slot_mapping.extend([PAD_SLOT_ID] * seq_len) + return + + # Mask the [0, start_idx) tokens of the prompt with + # PAD_SLOT_ID, where start_idx is max(0, seq_len - + # sliding_window). For example, if the prompt len is 10, + # sliding window is 8, and block size is 4, the first two + # tokens are masked and the slot mapping will be + # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. + block_table = block_tables[seq_id] + slot_mapping.extend([PAD_SLOT_ID] * max(0, start_idx - context_len)) + for i in range(max(start_idx, context_len), seq_len): + block_number = block_table[i // block_size] + block_offset = i % block_size + slot = block_number * block_size + block_offset + slot_mapping.append(slot) + + +TAttentionMetadata = TypeVar("TAttentionMetadata", bound='AttentionMetadata') + + +class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]): + + _metadata_cls: Type[TAttentionMetadata] + + def __init__(self, input_builder: "ModelInputForGPUBuilder"): + self.slot_mapping: List[int] = [] + self.prefill_seq_lens: List[int] = [] + self.context_lens: List[int] = [] + self.block_tables: List[List[int]] = [] + self.curr_seq_lens: List[int] = [] + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.num_decode_tokens = 0 + + self.sliding_window = input_builder.sliding_window + self.block_size = input_builder.block_size + self.use_v2_block_manager = ( + input_builder.scheduler_config.use_v2_block_manager) + + def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata, + token_lens: List[int], seq_lens: List[int], + curr_seq_lens: List[int], query_lens: List[int], + context_lens: List[int], + curr_sliding_window_blocks: List[int], prefix_cache_hit, + chunked_prefill_enabled): + is_prompt = seq_group_metadata.is_prompt + block_tables = seq_group_metadata.block_tables + computed_block_nums = seq_group_metadata.computed_block_nums + + for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, + curr_sliding_window_block) in zip( + seq_group_metadata.seq_data.keys(), token_lens, seq_lens, + curr_seq_lens, query_lens, context_lens, + curr_sliding_window_blocks): + self.context_lens.append(context_len) + if is_prompt: + self.num_prefills += 1 + self.num_prefill_tokens += token_len + self.prefill_seq_lens.append(seq_len) + else: + assert query_len == 1, ( + "seq_len: {}, context_len: {}, query_len: {}".format( + seq_len, context_len, query_len)) + self.num_decode_tokens += query_len + self.curr_seq_lens.append(curr_seq_len) + + # Compute block table. + # TODO(sang): Combine chunked prefill and prefix caching by + # only allowing multiple of block_size chunk size. + # NOTE: This only works for oooooooxxx style attention. + block_table = [] + if prefix_cache_hit: + block_table = computed_block_nums + elif ((chunked_prefill_enabled or not is_prompt) + and block_tables is not None): + block_table = block_tables[seq_id][-curr_sliding_window_block:] + self.block_tables.append(block_table) + + # Compute slot mapping. + is_profile_run = is_block_tables_empty(block_tables) + start_idx = compute_slot_mapping_start_idx( + is_prompt, query_len, context_len, self.sliding_window, + self.use_v2_block_manager) + compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, + seq_len, context_len, start_idx, + self.block_size, + seq_group_metadata.block_tables) + + def build(self, runner: "GPUModelRunnerBase", seq_lens: List[int], + query_lens: List[int], cuda_graph_pad_size: int, + batch_size: int): + device = runner.device + use_captured_graph = cuda_graph_pad_size != -1 + + logits_soft_cap = getattr(runner.model_config.hf_config, + "attn_logit_softcapping", None) + if logits_soft_cap is not None: + raise ValueError( + "Please use Flashinfer backend for models with logits_soft_cap " + "(i.e., Gemma-2). Otherwise, the output might be wrong. " + "Set Flashinfer backend by " + "export VLLM_ATTENTION_BACKEND=FLASHINFER.") + + max_query_len = max(query_lens) + max_prefill_seq_len = max(self.prefill_seq_lens, default=0) + max_decode_seq_len = max(self.curr_seq_lens, default=0) + num_decode_tokens = self.num_decode_tokens + + if use_captured_graph: + self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) + self.block_tables.extend([] * cuda_graph_pad_size) + num_decode_tokens = batch_size + cuda_graph_pad_size + + # The shape of graph_block_tables is + # [max batch size, max context len // block size]. + input_block_tables = runner.graph_block_tables[:batch_size] + for i, block_table in enumerate(self.block_tables): + if block_table: + input_block_tables[i, :len(block_table)] = block_table + block_tables = torch.tensor(input_block_tables, device=device) + else: + max_block_table_len = max( + len(block_table) for block_table in self.block_tables) + block_tables = make_tensor_with_pad( + self.block_tables, + max_len=max_block_table_len, + pad=0, + dtype=torch.int, + device=device, + ) + assert max_query_len > 0, "query_lens: {}".format(query_lens) + + context_lens_tensor = torch.tensor(self.context_lens, + dtype=torch.int, + device=device) + seq_lens_tensor = torch.tensor(seq_lens, + dtype=torch.int, + device=device) + query_lens_tensor = torch.tensor(query_lens, + dtype=torch.long, + device=device) + query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=device) + seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=device) + torch.cumsum(seq_lens_tensor, + dim=0, + dtype=seq_start_loc.dtype, + out=seq_start_loc[1:]) + torch.cumsum(query_lens_tensor, + dim=0, + dtype=query_start_loc.dtype, + out=query_start_loc[1:]) + + slot_mapping_tensor = torch.tensor(self.slot_mapping, + dtype=torch.long, + device=device) + + return self._metadata_cls( # type: ignore + num_prefills=self.num_prefills, + slot_mapping=slot_mapping_tensor, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_query_len=max_query_len, + max_prefill_seq_len=max_prefill_seq_len, + max_decode_seq_len=max_decode_seq_len, + query_start_loc=query_start_loc, + seq_start_loc=seq_start_loc, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + use_cuda_graph=use_captured_graph, + ) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 6cc5f1d1477ae..07edd8f959607 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -11,6 +11,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) +from vllm.attention.backends.utils import CommonMetadataBuilder from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) from vllm.logger import init_logger @@ -32,6 +33,10 @@ def get_impl_cls() -> Type["XFormersImpl"]: def get_metadata_cls() -> Type["AttentionMetadata"]: return XFormersMetadata + @staticmethod + def get_builder_cls() -> Type["XFormersMetadataBuilder"]: + return XFormersMetadataBuilder + @staticmethod def get_kv_cache_shape( num_blocks: int, @@ -362,6 +367,11 @@ def _get_seq_len_block_table_args( raise AttributeError(f"Invalid attention type {str(attn_type)}") +class XFormersMetadataBuilder(CommonMetadataBuilder[XFormersMetadata]): + + _metadata_cls = XFormersMetadata + + class XFormersImpl(AttentionImpl[XFormersMetadata]): """ If the input tensors contain prompt tokens, the layout is as follows: diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 084100f6c1135..8fcd85585a18f 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -7,6 +7,7 @@ import vllm.envs as envs from vllm.attention.backends.abstract import AttentionBackend from vllm.logger import init_logger +from vllm.platforms import current_platform from vllm.utils import is_cpu, is_hip, is_openvino, is_tpu, is_xpu logger = init_logger(__name__) @@ -136,7 +137,7 @@ def which_attn_to_use( selected_backend = (_Backend.ROCM_FLASH if selected_backend == _Backend.FLASH_ATTN else selected_backend) if selected_backend == _Backend.ROCM_FLASH: - if torch.cuda.get_device_capability()[0] != 9: + if current_platform.get_device_capability()[0] != 9: # not Instinct series GPUs. logger.info("flash_attn is not supported on NAVI GPUs.") else: @@ -145,7 +146,7 @@ def which_attn_to_use( # FlashAttn in NVIDIA GPUs. if selected_backend == _Backend.FLASH_ATTN: - if torch.cuda.get_device_capability()[0] < 8: + if current_platform.get_device_capability()[0] < 8: # Volta and Turing NVIDIA GPUs. logger.info( "Cannot use FlashAttention-2 backend for Volta and Turing " diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 205b4f58f7a83..75a2607d0d9c4 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -2,6 +2,7 @@ import gc import time import warnings +import weakref from collections import defaultdict from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set, Tuple, Type, TypeVar, Union) @@ -48,9 +49,9 @@ from vllm.sequence import (IntermediateTensors, SamplerOutput, SequenceGroupMetadata) from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip, - is_pin_memory_available, make_tensor_with_pad) + is_pin_memory_available) from vllm.worker.model_runner_base import ( - ModelRunnerBase, ModelRunnerInputBase, + ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase, _add_attn_metadata_broadcastable_dict, _add_sampling_metadata_broadcastable_dict, _init_attn_metadata_from_tensor_dict, @@ -165,6 +166,298 @@ def from_broadcasted_tensor_dict( return cls(**tensor_dict) +class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): + """TBA""" + + def __init__(self, + runner: "GPUModelRunnerBase", + finished_requests_ids: Optional[List[str]] = None): + super().__init__() + self.runner = runner + self.model_input_cls = self.runner._model_input_cls + self.attn_backend = self.runner.attn_backend + self.scheduler_config = self.runner.scheduler_config + self.sliding_window = self.runner.sliding_window + self.block_size = self.runner.block_size + self.enable_lora = self.runner.lora_config is not None + self.enable_prompt_adapter = (self.runner.prompt_adapter_config + is not None) + self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper + self.finished_requests_ids = finished_requests_ids + self.decode_only = True + + # Common inputs. + self.input_tokens: List[int] = [] + self.input_positions: List[int] = [] + self.seq_lens: List[int] = [] + self.query_lens: List[int] = [] + self.max_decode_seq_len: int = 0 + self.request_ids_to_seq_ids: Dict[str, List[int]] = defaultdict(list) + + # LoRA inputs. + self.lora_index_mapping: List[int] = [] + self.lora_prompt_mapping: List[int] = [] + self.lora_requests: Set[LoRARequest] = set() + + # Prompt adapter inputs. + self.prompt_adapter_index_mapping: List[int] = [] + self.prompt_adapter_prompt_mapping: List[int] = [] + self.prompt_adapter_requests: Set[PromptAdapterRequest] = set() + + # Multi-modal inputs. + self.multi_modal_inputs_list: List[MultiModalInputs] = [] + + # Attention metadata inputs. + self.attn_metadata_builder = self.attn_backend.make_metadata_builder( + self) + + # Engine/Model configurations. + self.chunked_prefill_enabled = ( + self.scheduler_config is not None + and self.scheduler_config.chunked_prefill_enabled) + if self.sliding_window is not None: + self.sliding_window_blocks = ( + self.sliding_window + self.block_size - 1) // self.block_size + self.block_aligned_sliding_window = \ + self.sliding_window_blocks * self.block_size + + def _compute_len_for_sliding_window(self, seq_len: int): + curr_sliding_window_blocks = 0 + sliding_seq_len = seq_len + + # TODO(sang): This is a hack to make sliding window work with + # paged attn. We can remove it if we make paged attn kernel + # to properly handle slinding window attn. + if self.sliding_window is not None: + curr_sliding_window_blocks = self.sliding_window_blocks + if self.scheduler_config.use_v2_block_manager: + # number of elements in last block + suff_len = seq_len % self.block_size + sliding_seq_len = min( + seq_len, self.block_aligned_sliding_window + suff_len) + if suff_len > 0: + curr_sliding_window_blocks += 1 + else: + sliding_seq_len = min(seq_len, self.sliding_window) + return curr_sliding_window_blocks, sliding_seq_len + + def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata): + seq_ids = list(seq_group_metadata.seq_data.keys()) + n_seqs = len(seq_ids) + is_prompt = seq_group_metadata.is_prompt + token_chunk_size = seq_group_metadata.token_chunk_size + + if is_prompt: + assert n_seqs == 1 + self.decode_only = False + + # Mapping from request IDs to sequence IDs. Used for Jamba models + # that manages the cache by itself. + self.request_ids_to_seq_ids[seq_group_metadata.request_id] = [] + # The number of input tokens in each sequence. + token_lens: List[int] = [] + # The number of tokens that are already computed. + context_lens: List[int] = [] + # The current sliding window block for each sequence. + curr_sliding_window_blocks: List[int] = [] + # The original sequence length (before applying sliding window) + # for each sequence. + orig_seq_lens: List[int] = [] + # The sequence length (may be capped to the sliding window). + curr_seq_lens: List[int] = [] + for seq_id in seq_ids: + seq_data = seq_group_metadata.seq_data[seq_id] + self.request_ids_to_seq_ids[seq_group_metadata.request_id].append( + seq_id) + computed_block_nums = seq_group_metadata.computed_block_nums + + # Check if hit prefix cache (i.e., some blocks are already computed) + # Note that prefix caching does not support sliding window. + prefix_cache_hit = (computed_block_nums is not None + and len(computed_block_nums) > 0 + and self.sliding_window is None and is_prompt) + if self.chunked_prefill_enabled and prefix_cache_hit: + raise RuntimeError( + "chunked prefill cannot be used with prefix caching now.") + + # Compute context length (the number of tokens that are + # already computed) and sequence length (total number of tokens). + seq_len = seq_data.get_len() + if is_prompt: + context_len = seq_data.get_num_computed_tokens() + else: + # get_num_computed_tokens is incorrect for spec decoding. + # So, we should have a special logic here. + # TODO(sang): Fix it. + context_len = seq_len - 1 + seq_len = min(seq_len, context_len + token_chunk_size) + + # Compute tokens. + if is_prompt: + tokens = seq_data.get_token_ids()[context_len:seq_len] + else: + # Optimization. get_token_ids requires the entire copy of + # tokens. + tokens = [seq_data.get_last_token_id()] + if prefix_cache_hit: + assert computed_block_nums is not None + context_len = len(computed_block_nums) * self.block_size + tokens = tokens[context_len:] + + # These are seq_len/context_len capped to the sliding window. + # They are passed to decode kernel. + # We still need original seq_len/context_len to compute slot + # mapping (and input position) below. + if is_prompt: + curr_sliding_window_block = 0 + sliding_seq_len = seq_len + query_len = seq_len - context_len + else: + curr_sliding_window_block, sliding_seq_len = ( + self._compute_len_for_sliding_window(seq_len)) + query_len = 1 + + self.seq_lens.append(sliding_seq_len) + if not is_prompt: + self.max_decode_seq_len = max(self.max_decode_seq_len, + sliding_seq_len) + self.query_lens.append(query_len) + self.input_tokens.extend(tokens) + self.input_positions.extend(list(range(context_len, seq_len))) + + # Intermediate data of the current sequence group for + # the attention metadata. + token_lens.append(len(tokens)) + context_lens.append(context_len) + curr_seq_lens.append(sliding_seq_len) + curr_sliding_window_blocks.append(curr_sliding_window_block) + orig_seq_lens.append(seq_len) + + # Update attention metadata. Note that input builder attributes + # (self.xxx) include all added sequences, so we need to slice + # the last n_seqs sequences. + self.attn_metadata_builder.add_seq_group( + seq_group_metadata, token_lens, orig_seq_lens, curr_seq_lens, + self.query_lens[-n_seqs:], context_lens, + curr_sliding_window_blocks, prefix_cache_hit, + self.chunked_prefill_enabled) + + # LoRA data. + if self.enable_lora: + lora_id = seq_group_metadata.lora_int_id + for query_len in self.query_lens[-n_seqs:]: + if lora_id > 0: + self.lora_requests.add(seq_group_metadata.lora_request) + self.lora_index_mapping += [lora_id] * query_len + self.lora_prompt_mapping.extend( + [lora_id] * + (query_len if seq_group_metadata.sampling_params + and seq_group_metadata.sampling_params.prompt_logprobs + is not None else 1)) + + # Prompt adapter data. Note that when is_prompt=True, + # we expect only one sequence in the group. + if self.enable_prompt_adapter: + prompt_adapter_id = seq_group_metadata.prompt_adapter_id + if prompt_adapter_id > 0 and is_prompt: + query_len = self.query_lens[-1] + self.prompt_adapter_requests.add( + seq_group_metadata.prompt_adapter_request) + + num_tokens = seq_group_metadata.\ + prompt_adapter_num_virtual_tokens + pm = [prompt_adapter_id + ] * num_tokens + [0] * (query_len - num_tokens) + self.prompt_adapter_index_mapping += pm + self.prompt_adapter_prompt_mapping.extend( + [prompt_adapter_id] * + (query_len if seq_group_metadata.sampling_params + and seq_group_metadata.sampling_params.prompt_logprobs + else 1)) + + # Multi-modal data. + mm_data = seq_group_metadata.multi_modal_data + if mm_data: + mm_kwargs = self.multi_modal_input_mapper(mm_data) + self.multi_modal_inputs_list.append(mm_kwargs) + + def build(self) -> ModelInputForGPU: + if not self.input_tokens: + return self.model_input_cls() + + batch_size = len(self.input_tokens) + use_captured_graph = ( + self.decode_only and not self.runner.model_config.enforce_eager + and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1] + and self.max_decode_seq_len <= self.runner.max_seq_len_to_capture) + + # If cuda graph can be used, pad tensors accordingly. + # See `capture_model` API for more details. + # vLLM uses cuda graph only for decoding requests. + cuda_graph_pad_size = -1 + if use_captured_graph: + graph_batch_size = _get_graph_batch_size(batch_size) + assert graph_batch_size >= batch_size + cuda_graph_pad_size = graph_batch_size - batch_size + batch_size = graph_batch_size + + # Tokens and positions. + self.input_tokens.extend([0] * cuda_graph_pad_size) + self.input_positions.extend([0] * cuda_graph_pad_size) + input_tokens_tensor = torch.tensor(self.input_tokens, + dtype=torch.long, + device=self.runner.device) + input_positions_tensor = torch.tensor(self.input_positions, + dtype=torch.long, + device=self.runner.device) + + # Sequence and query lengths. + self.seq_lens.extend([1] * cuda_graph_pad_size) + + # Attention metadata. + attn_metadata = self.attn_metadata_builder.build( + self.runner, self.seq_lens, self.query_lens, cuda_graph_pad_size, + batch_size) + + # LoRA data. + if self.enable_lora: + self.lora_index_mapping.extend([0] * cuda_graph_pad_size) + lora_mapping = LoRAMapping( + self.lora_index_mapping, + self.lora_prompt_mapping, + ) + else: + lora_mapping = None + + # Prompt adapter data. + if self.enable_prompt_adapter: + self.prompt_adapter_index_mapping.extend([0] * cuda_graph_pad_size) + prompt_adapter_mapping = PromptAdapterMapping( + self.prompt_adapter_index_mapping, + self.prompt_adapter_prompt_mapping, + ) + else: + prompt_adapter_mapping = None + + # Multi-modal data. + multi_modal_kwargs = MultiModalInputs.batch( + self.multi_modal_inputs_list, device=self.runner.device) + + return self.model_input_cls( + input_tokens=input_tokens_tensor, + input_positions=input_positions_tensor, + attn_metadata=attn_metadata, + seq_lens=self.seq_lens, + query_lens=self.query_lens, + lora_mapping=lora_mapping, + lora_requests=self.lora_requests, + multi_modal_kwargs=multi_modal_kwargs, + request_ids_to_seq_ids=self.request_ids_to_seq_ids, + finished_requests_ids=self.finished_requests_ids, + prompt_adapter_mapping=prompt_adapter_mapping, + prompt_adapter_requests=self.prompt_adapter_requests) + + class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): """ Helper class for shared methods between GPU model runners. @@ -368,464 +661,11 @@ def _prepare_model_input_tensors( If cuda graph is required, this API automatically pads inputs. """ - input_tokens: List[int] = [] - input_positions: List[int] = [] - slot_mapping: List[int] = [] - lora_index_mapping: List[int] = [] - lora_prompt_mapping: List[int] = [] - lora_requests: Set[LoRARequest] = set() - prompt_adapter_index_mapping: List[int] = [] - prompt_adapter_prompt_mapping: List[int] = [] - prompt_adapter_requests: Set[PromptAdapterRequest] = set() - - seq_lens: List[int] = [] - prefill_seq_lens: List[int] = [] - decode_seq_lens: List[int] = [] - context_lens: List[int] = [] - query_lens: List[int] = [] - block_tables: List[List[int]] = [] - multi_modal_inputs_list: List[MultiModalInputs] = [] - request_ids_to_seq_ids: Dict[str, List[int]] = defaultdict(list) - decode_only = True - num_prefills = 0 - num_prefill_tokens = 0 - num_decode_tokens = 0 - - # The following fields are only for flashinfer - # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout - # for the precise definition of the following fields. - # An example: - # request 1, page indices [0, 5, 8] - # request 2, page indices [1, 6, 7] - # request 3, page indices [3, 4] - # paged_kv_indices is a concatenation of page indices of all requests: - # [0, 5, 8, 1, 6, 7, 3, 4] - # paged_kv_indptr is used to index into paged_kv_indices: - # [0, 3, 6, 8] - paged_kv_indices: List[int] = [] - # 0 at the beginning of paged_kv_indptr indicates the start of the - # first request’s page indices in the paged_kv_indices list. - paged_kv_indptr: List[int] = [0] - # paged_kv_last_page_len is the length of the last page of each request - paged_kv_last_page_len: List[int] = [] - - if len(seq_group_metadata_list) == 0: - return self._model_input_cls() - - if self.sliding_window is not None: - sliding_window_blocks = (self.sliding_window + self.block_size - - 1) // self.block_size - block_aligned_sliding_window = \ - sliding_window_blocks * self.block_size - + builder = ModelInputForGPUBuilder(weakref.proxy(self), + finished_requests_ids) for seq_group_metadata in seq_group_metadata_list: - seq_ids = list(seq_group_metadata.seq_data.keys()) - is_prompt = seq_group_metadata.is_prompt - - for seq_id in seq_ids: - computed_block_nums = seq_group_metadata.computed_block_nums - if (self.scheduler_config is not None - and self.scheduler_config.chunked_prefill_enabled - and not (computed_block_nums is None - or computed_block_nums == [])): - raise RuntimeError( - "chunked prefill cannot be used with prefix caching " - "now.") - - seq_data = seq_group_metadata.seq_data[seq_id] - if is_prompt: - context_len = seq_data.get_num_computed_tokens() - else: - # get_num_computed_tokens is incorrect for spec decoding. - # So, we should have a special logic here. - # TODO(sang): Fix it. - context_len = seq_data.get_len() - 1 - - seq_len = min( - seq_data.get_len(), - context_len + seq_group_metadata.token_chunk_size) - if is_prompt: - tokens = seq_data.get_token_ids()[context_len:seq_len] - else: - # Optimization. get_token_ids requires the entire copy of - # tokens. - tokens = [seq_data.get_last_token_id()] - - # Prefix cache was hit. - # Prefix is not supported with sliding_window - prefix_cache_hit = (computed_block_nums is not None - and len(computed_block_nums) > 0 - and self.sliding_window is None - and is_prompt) - - # These are seq_len/context_len capped to the sliding window. - # They are passed to decode kernel. - # We still need original seq_len/context_len to compute slot - # mapping (and input position) below. - curr_sliding_window_blocks = None - sliding_seq_len = seq_len - sliding_context_len = context_len - - # TODO(sang): This is a hack to make sliding window work with - # paged attn. We can remove it if we make paged attn kernel - # to properly handle slinding window attn. - if (self.sliding_window is not None and not is_prompt): - curr_sliding_window_blocks = sliding_window_blocks - if self.scheduler_config.use_v2_block_manager: - # number of elements in last block - suff_len = seq_len % self.block_size - sliding_seq_len = min( - seq_len, block_aligned_sliding_window + suff_len) - if suff_len > 0: - curr_sliding_window_blocks += 1 - else: - sliding_seq_len = min(seq_len, self.sliding_window) - sliding_context_len = sliding_seq_len - 1 - - # TODO(sang): Combine chunked prefill and prefix caching by - # only allowing multiple of block_size chunk size. - # NOTE: This only works for oooooooxxx style attention. - if prefix_cache_hit: - assert computed_block_nums is not None - context_len = len(computed_block_nums) * self.block_size - tokens = tokens[context_len:] - - # need to think what to set it to when we have both sliding - # window and prefix caching... - assert self.sliding_window is None, \ - "Prefix caching is not supported with sliding window" - sliding_context_len = context_len - - if self.attn_backend.get_name() == "flash-attn": - # NOTE(woosuk): For flash-attn, the block table should - # include the entries for the incoming prefill tokens. - # TODO(woosuk): This is a temporary fix. We should - # provide a unified interface for different backends. - block_table = seq_group_metadata.block_tables[seq_id] - else: - block_table = computed_block_nums - elif (self.scheduler_config.chunked_prefill_enabled - or not is_prompt): - if seq_group_metadata.block_tables is not None: - # chunked prefill or decode - block_table = seq_group_metadata.block_tables[seq_id] - if curr_sliding_window_blocks is not None: - block_table = block_table[ - -curr_sliding_window_blocks:] - else: - # Only happens when memory profiling runs. - block_table = [] - else: - # Prefill without chunked prefill or memory profiling. - block_table = [] - block_tables.append(block_table) - - seq_lens.append(sliding_seq_len) - context_lens.append(sliding_context_len) - query_len = sliding_seq_len - sliding_context_len - query_lens.append(query_len) - input_tokens.extend(tokens) - input_positions.extend(list(range(context_len, seq_len))) - lora_id = seq_group_metadata.lora_int_id - prompt_adapter_id = seq_group_metadata.prompt_adapter_id - - if is_prompt: - assert len(seq_ids) == 1 - num_prefills += 1 - num_prefill_tokens += len(tokens) - decode_only = False - prefill_seq_lens.append(seq_len) - else: - assert query_len == 1, ( - "seq_len: {}, context_len: {}, query_len: {}".format( - seq_len, context_len, query_len)) - num_decode_tokens += query_len - decode_seq_lens.append(sliding_seq_len) - - if lora_id > 0: - lora_requests.add(seq_group_metadata.lora_request) - - lora_index_mapping += [lora_id] * query_len - lora_prompt_mapping.extend( - [lora_id] * - (query_len if seq_group_metadata.sampling_params - and seq_group_metadata.sampling_params.prompt_logprobs - is not None else 1)) - - mm_data = seq_group_metadata.multi_modal_data - if mm_data: - # Process multi-modal data - mm_kwargs = self.multi_modal_input_mapper(mm_data) - multi_modal_inputs_list.append(mm_kwargs) - - if prompt_adapter_id > 0 and is_prompt: - prompt_adapter_requests.add( - seq_group_metadata.prompt_adapter_request) - - num_tokens = seq_group_metadata.\ - prompt_adapter_num_virtual_tokens - pm = [prompt_adapter_id - ] * num_tokens + [0] * (query_len - num_tokens) - prompt_adapter_index_mapping += pm - prompt_adapter_prompt_mapping.extend( - [prompt_adapter_id] * - (query_len if seq_group_metadata.sampling_params - and seq_group_metadata.sampling_params.prompt_logprobs - else 1)) - - is_profile_run = _is_block_tables_empty( - seq_group_metadata.block_tables) - if is_profile_run: - # During memory profiling, the block tables are not - # initialized yet. In this case, we just use a dummy - # slot mapping. - # In embeddings, the block tables are {seq_id: None}. - slot_mapping.extend([_PAD_SLOT_ID] * seq_len) - continue - - # Compute the slot mapping. - block_table = seq_group_metadata.block_tables[seq_id] - - # Mask the [0, start_idx) tokens of the prompt with - # _PAD_SLOT_ID, where start_idx is max(0, seq_len - - # sliding_window). For example, if the prompt len is 10, - # sliding window is 8, and block size is 4, the first two - # tokens are masked and the slot mapping will be - # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. - start_idx = 0 - if self.sliding_window is not None: - if is_prompt: - assert self.scheduler_config.use_v2_block_manager \ - or context_len == 0, ( - "Prefix caching is currently not supported with " - "sliding window attention in V1 block manager") - # It is an optimization. When it is decoding, it is always - # 0. When prefill, we use it to not write slots to kv cache - # to save memory. - start_idx = max(0, query_len - self.sliding_window) - - for i in range(context_len, seq_len): - if i < start_idx: - slot_mapping.append(_PAD_SLOT_ID) - continue - - block_number = block_table[i // self.block_size] - block_offset = i % self.block_size - slot = block_number * self.block_size + block_offset - slot_mapping.append(slot) - - # Prepare input tensors for flashinfer - if self.attn_backend.get_name() == "flashinfer": - seq_len = seq_data.get_len() - # Get the number of valid blocks based on sequence length. - # If seq_len = 16, block_size = 16, - # block_table_bound is 1 with 1 valid block. - # If seq_len = 15, block_size = 16, - # block_table_bound is 0 + 1 with 1 valid block. - block_table_bound = seq_len // self.block_size + 1 \ - if seq_len % self.block_size != 0 \ - else seq_len // self.block_size - - paged_kv_indices.extend(block_table[:block_table_bound]) - paged_kv_indptr.append(paged_kv_indptr[-1] + - block_table_bound) - - last_page_len = seq_len % self.block_size - if last_page_len == 0: - last_page_len = self.block_size - paged_kv_last_page_len.append(last_page_len) - - batch_size = len(input_tokens) - max_query_len = max(query_lens) - max_prefill_seq_len = max(prefill_seq_lens, default=0) - max_decode_seq_len = max(decode_seq_lens, default=0) - - # If cuda graph can be used, pad tensors accordingly. - # See `capture_model` API for more details. - # vLLM uses cuda graph only for decoding requests. - use_captured_graph = ( - decode_only and not self.model_config.enforce_eager - and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1] - and max_decode_seq_len <= self.max_seq_len_to_capture) - if use_captured_graph: - graph_batch_size = _get_graph_batch_size(batch_size) - assert graph_batch_size >= batch_size - for _ in range(graph_batch_size - batch_size): - input_tokens.append(0) - input_positions.append(0) - slot_mapping.append(_PAD_SLOT_ID) - seq_lens.append(1) - block_tables.append([]) - lora_index_mapping.append(0) - prompt_adapter_index_mapping.append(0) - if self.attn_backend.get_name() == "flashinfer": - last_paged_kv_indptr = paged_kv_indptr[-1] - paged_kv_indptr.append(last_paged_kv_indptr) - paged_kv_last_page_len.append(0) - batch_size = graph_batch_size - num_decode_tokens = batch_size - - if use_captured_graph: - # The shape of graph_block_tables is - # [max batch size, max context len // block size]. - input_block_tables = self.graph_block_tables[:batch_size] - for i, block_table in enumerate(block_tables): - if block_table: - input_block_tables[i, :len(block_table)] = block_table - block_tables = torch.tensor(input_block_tables, device=self.device) - else: - max_block_table_len = max( - len(block_table) for block_table in block_tables) - block_tables = make_tensor_with_pad( - block_tables, - max_len=max_block_table_len, - pad=0, - dtype=torch.int, - device=self.device, - ) - assert max_query_len > 0, ("query_lens: {}".format(query_lens)) - - context_lens_tensor = torch.tensor(context_lens, - dtype=torch.int, - device=self.device) - - seq_lens_tensor = torch.tensor(seq_lens, - dtype=torch.int, - device=self.device) - query_lens_tensor = torch.tensor(query_lens, - dtype=torch.long, - device=self.device) - query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, - dtype=torch.int32, - device=self.device) - seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, - dtype=torch.int32, - device=self.device) - - torch.cumsum(seq_lens_tensor, - dim=0, - dtype=seq_start_loc.dtype, - out=seq_start_loc[1:]) - torch.cumsum(query_lens_tensor, - dim=0, - dtype=query_start_loc.dtype, - out=query_start_loc[1:]) - - input_tokens_tensor = torch.tensor(input_tokens, - dtype=torch.long, - device=self.device) - input_positions_tensor = torch.tensor(input_positions, - dtype=torch.long, - device=self.device) - slot_mapping_tensor = torch.tensor(slot_mapping, - dtype=torch.long, - device=self.device) - - logits_soft_cap = getattr(self.model_config.hf_config, - 'attn_logit_softcapping', None) - if logits_soft_cap is not None and self.attn_backend.get_name( - ) != "flashinfer": - raise ValueError("Please use Flashinfer backend for models with" - "logits_soft_cap (i.e., Gemma-2)." - " Otherwise, the output might be wrong." - " Set Flashinfer backend by " - "export VLLM_ATTENTION_BACKEND=FLASHINFER.") - - if self.attn_backend.get_name() == "flashinfer": - if len(paged_kv_indptr) > 0: - paged_kv_indices_tensor = torch.tensor(paged_kv_indices, - device='cpu', - dtype=torch.int) - paged_kv_indptr_tensor = torch.tensor(paged_kv_indptr, - device='cpu', - dtype=torch.int) - paged_kv_last_page_len_tensor = torch.tensor( - paged_kv_last_page_len, device='cpu', dtype=torch.int) - else: - paged_kv_indices_tensor = None - paged_kv_indptr_tensor = None - paged_kv_last_page_len_tensor = None - - kv_cache_dtype = get_kv_cache_torch_dtype(self.kv_cache_dtype, - self.model_config.dtype) - attn_metadata = self.attn_backend.make_metadata( - num_prefills=num_prefills, - slot_mapping=slot_mapping_tensor, - num_prefill_tokens=num_prefill_tokens, - num_decode_tokens=num_decode_tokens, - max_prefill_seq_len=max_prefill_seq_len, - block_tables=block_tables, - paged_kv_indptr=paged_kv_indptr_tensor, - paged_kv_indices=paged_kv_indices_tensor, - paged_kv_last_page_len=paged_kv_last_page_len_tensor, - num_qo_heads=self.model_config.get_num_attention_heads( - self.parallel_config), - num_kv_heads=self.model_config.get_num_kv_heads( - self.parallel_config), - head_dim=self.model_config.get_head_size(), - page_size=self.block_size, - seq_start_loc=seq_start_loc, - query_start_loc=query_start_loc, - device=self.device, - data_type=kv_cache_dtype, - use_cuda_graph=use_captured_graph, - logits_soft_cap=logits_soft_cap) - - else: - attn_metadata = self.attn_backend.make_metadata( - num_prefills=num_prefills, - slot_mapping=slot_mapping_tensor, - num_prefill_tokens=num_prefill_tokens, - num_decode_tokens=num_decode_tokens, - seq_lens=seq_lens, - seq_lens_tensor=seq_lens_tensor, - max_query_len=max_query_len, - max_prefill_seq_len=max_prefill_seq_len, - max_decode_seq_len=max_decode_seq_len, - query_start_loc=query_start_loc, - seq_start_loc=seq_start_loc, - context_lens_tensor=context_lens_tensor, - block_tables=block_tables, - use_cuda_graph=use_captured_graph, - ) - - if self.lora_config: - lora_mapping = LoRAMapping( - lora_index_mapping, - lora_prompt_mapping, - ) - else: - lora_mapping = None - - if self.prompt_adapter_config: - prompt_adapter_mapping = PromptAdapterMapping( - prompt_adapter_index_mapping, - prompt_adapter_prompt_mapping, - ) - else: - prompt_adapter_mapping = None - - multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list, - device=self.device) - request_ids_to_seq_ids = { - seq_group_metadata.request_id: - list(seq_group_metadata.seq_data.keys()) - for seq_group_metadata in seq_group_metadata_list - } - return self._model_input_cls( - input_tokens=input_tokens_tensor, - input_positions=input_positions_tensor, - attn_metadata=attn_metadata, - seq_lens=seq_lens, - query_lens=query_lens, - lora_mapping=lora_mapping, - lora_requests=lora_requests, - multi_modal_kwargs=multi_modal_kwargs, - request_ids_to_seq_ids=request_ids_to_seq_ids, - finished_requests_ids=finished_requests_ids, - prompt_adapter_mapping=prompt_adapter_mapping, - prompt_adapter_requests=prompt_adapter_requests, - ) + builder.add_seq_group(seq_group_metadata) + return builder.build() # type: ignore @torch.inference_mode() def profile_run(self) -> None: diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index bc0960fa16221..bc7a6a73b17c4 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -113,6 +113,21 @@ def from_broadcasted_tensor_dict( raise NotImplementedError +class ModelRunnerInputBuilderBase(ABC, Generic[T]): + """A builder to create ModelRunnerInputBase objects. + """ + + @abstractmethod + def add_seq_group(self, seq_group_metadata): + """TBA""" + raise NotImplementedError + + @abstractmethod + def build(self, *args, **kwargs) -> T: + """Build metadata with on-device tensors.""" + raise NotImplementedError + + class ModelRunnerBase(ABC, Generic[T]): """ Model runner interface that abstracts a particular hardware and/or type of