From cdf10db849581d4b59a74d6a8b1e332468e7e23f Mon Sep 17 00:00:00 2001
From: Cody Yu
Date: Thu, 11 Jul 2024 13:15:41 -0700
Subject: [PATCH] Introduce a common attention metadata builder

---
 vllm/attention/backends/blocksparse_attn.py |  51 +---
 vllm/attention/backends/flash_attn.py       |   8 +-
 vllm/attention/backends/flashinfer.py       |   8 +-
 vllm/attention/backends/rocm_flash_attn.py  |  52 +---
 vllm/attention/backends/utils.py            | 322 +++++++++++---------
 vllm/attention/backends/xformers.py         |  52 +---
 vllm/worker/model_runner.py                 |   6 +-
 7 files changed, 202 insertions(+), 297 deletions(-)

diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py
index c72ff44fdbbd6..8f03ac0cf49e5 100644
--- a/vllm/attention/backends/blocksparse_attn.py
+++ b/vllm/attention/backends/blocksparse_attn.py
@@ -1,24 +1,16 @@
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
+from typing import Any, Dict, List, Optional, Tuple, Type
 
 import torch
 
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
-                                              AttentionMetadata,
-                                              AttentionMetadataBuilder,
-                                              AttentionType)
-from vllm.attention.backends.utils import (
-    metadata_builder_add_seq_group_common, metadata_builder_build_common)
+                                              AttentionMetadata, AttentionType)
+from vllm.attention.backends.utils import CommonMetadataBuilder
 from vllm.attention.ops.blocksparse_attention.interface import (
     LocalStridedBlockSparseAttn, get_head_sliding_step)
 from vllm.attention.ops.paged_attn import PagedAttention
 from vllm.distributed import (get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
-from vllm.sequence import SequenceGroupMetadata
-
-if TYPE_CHECKING:
-    from vllm.worker.model_runner import (GPUModelRunnerBase,
-                                          ModelInputForGPUBuilder)
 
 
 @dataclass
@@ -258,40 +250,9 @@ def decode_metadata(self) -> Optional["BlocksparseFlashAttentionMetadata"]:
 
 
 class BlocksparseFlashAttentionMetadataBuilder(
-        AttentionMetadataBuilder[BlocksparseFlashAttentionMetadata]):
-
-    def __init__(self, input_builder: "ModelInputForGPUBuilder"):
-        self.slot_mapping: List[int] = []
-        self.prefill_seq_lens: List[int] = []
-        self.context_lens: List[int] = []
-        self.block_tables: List[List[int]] = []
-        self.decode_seq_lens: List[int] = []
-        self.num_prefills = 0
-        self.num_prefill_tokens = 0
-        self.num_decode_tokens = 0
-
-        self.sliding_window = input_builder.sliding_window
-        self.block_size = input_builder.block_size
-        self.use_v2_block_manager = (
-            input_builder.scheduler_config.use_v2_block_manager)
-
-    def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
-                      token_lens: List[int], seq_lens: List[int],
-                      sliding_seq_lens: List[int], query_lens: List[int],
-                      context_lens: List[int],
-                      curr_sliding_window_blocks: List[int], prefix_cache_hit,
-                      chunked_prefill_enabled):
-        metadata_builder_add_seq_group_common(
-            self, seq_group_metadata, token_lens, seq_lens, sliding_seq_lens,
-            query_lens, context_lens, curr_sliding_window_blocks,
-            prefix_cache_hit, chunked_prefill_enabled)
-
-    def build(self, runner: "GPUModelRunnerBase", seq_lens, query_lens,
-              use_captured_graph: bool, cuda_graph_pad_size: int,
-              batch_size: int):
-        return metadata_builder_build_common(
-            self, BlocksparseFlashAttentionMetadata, runner, seq_lens,
-            query_lens, use_captured_graph, cuda_graph_pad_size, batch_size)
+        CommonMetadataBuilder[BlocksparseFlashAttentionMetadata]):
+
+    _metadata_cls = BlocksparseFlashAttentionMetadata
 
 
 class BlocksparseFlashAttentionImpl(AttentionImpl):
diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py
index c42318c927d18..caafa16684edf 100644
--- a/vllm/attention/backends/flash_attn.py
+++ b/vllm/attention/backends/flash_attn.py
@@ -219,17 +219,17 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"):
 
     def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
                       token_lens: List[int], seq_lens: List[int],
-                      sliding_seq_lens: List[int], query_lens: List[int],
+                      decode_seq_lens: List[int], query_lens: List[int],
                       context_lens: List[int],
                       curr_sliding_window_blocks: List[int], prefix_cache_hit,
                       chunked_prefill_enabled):
         is_prompt = seq_group_metadata.is_prompt
         block_tables = seq_group_metadata.block_tables
 
-        for (seq_id, token_len, seq_len, sliding_seq_len, query_len,
+        for (seq_id, token_len, seq_len, decode_seq_len, query_len,
              context_len, curr_sliding_window_block) in zip(
                  seq_group_metadata.seq_data.keys(), token_lens, seq_lens,
-                 sliding_seq_lens, query_lens, context_lens,
+                 decode_seq_lens, query_lens, context_lens,
                  curr_sliding_window_blocks):
             self.context_lens.append(context_len)
@@ -242,7 +242,7 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
                     "seq_len: {}, context_len: {}, query_len: {}".format(
                         seq_len, context_len, query_len))
                 self.num_decode_tokens += query_len
-                self.decode_seq_lens.append(sliding_seq_len)
+                self.decode_seq_lens.append(decode_seq_len)
 
             # Compute block table.
             # TODO(sang): Combine chunked prefill and prefix caching by
diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py
index eb774b9455379..edd24ae2a6fe2 100644
--- a/vllm/attention/backends/flashinfer.py
+++ b/vllm/attention/backends/flashinfer.py
@@ -239,7 +239,7 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"):
 
     def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
                       token_lens: List[int], seq_lens: List[int],
-                      sliding_seq_lens: List[int], query_lens: List[int],
+                      decode_seq_lens: List[int], query_lens: List[int],
                       context_lens: List[int],
                       curr_sliding_window_blocks: List[int], prefix_cache_hit,
                       chunked_prefill_enabled):
@@ -247,10 +247,10 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
         block_tables = seq_group_metadata.block_tables
         computed_block_nums = seq_group_metadata.computed_block_nums
 
-        for (seq_id, token_len, seq_len, sliding_seq_len, query_len,
+        for (seq_id, token_len, seq_len, decode_seq_len, query_len,
              context_len, curr_sliding_window_block) in zip(
                  seq_group_metadata.seq_data.keys(), token_lens, seq_lens,
-                 sliding_seq_lens, query_lens, context_lens,
+                 decode_seq_lens, query_lens, context_lens,
                  curr_sliding_window_blocks):
             self.context_lens.append(context_len)
             if is_prompt:
@@ -262,7 +262,7 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
                     "seq_len: {}, context_len: {}, query_len: {}".format(
                         seq_len, context_len, query_len))
                 self.num_decode_tokens += query_len
-                self.decode_seq_lens.append(sliding_seq_len)
+                self.decode_seq_lens.append(decode_seq_len)
 
             # Compute block table.
             # TODO(sang): Combine chunked prefill and prefix caching by
diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py
index 807efcdccccd6..f7bf0051fdf45 100644
--- a/vllm/attention/backends/rocm_flash_attn.py
+++ b/vllm/attention/backends/rocm_flash_attn.py
@@ -1,27 +1,19 @@
 """Attention layer ROCm GPUs."""
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
+from typing import Any, Dict, List, Optional, Tuple, Type
 
 import torch
 
 import vllm.envs as envs
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
-                                              AttentionMetadata,
-                                              AttentionMetadataBuilder,
-                                              AttentionType)
-from vllm.attention.backends.utils import (
-    metadata_builder_add_seq_group_common, metadata_builder_build_common)
+                                              AttentionMetadata, AttentionType)
+from vllm.attention.backends.utils import CommonMetadataBuilder
 from vllm.attention.ops.paged_attn import (PagedAttention,
                                            PagedAttentionMetadata)
 from vllm.logger import init_logger
-from vllm.sequence import SequenceGroupMetadata
 
 logger = init_logger(__name__)
 
-if TYPE_CHECKING:
-    from vllm.worker.model_runner import (GPUModelRunnerBase,
-                                          ModelInputForGPUBuilder)
-
 
 class ROCmFlashAttentionBackend(AttentionBackend):
@@ -180,41 +172,9 @@ def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
 
 
 class ROCmFlashAttentionMetadataBuilder(
-        AttentionMetadataBuilder[ROCmFlashAttentionMetadata]):
-
-    def __init__(self, input_builder: "ModelInputForGPUBuilder"):
-        self.slot_mapping: List[int] = []
-        self.prefill_seq_lens: List[int] = []
-        self.context_lens: List[int] = []
-        self.block_tables: List[List[int]] = []
-        self.decode_seq_lens: List[int] = []
-        self.num_prefills = 0
-        self.num_prefill_tokens = 0
-        self.num_decode_tokens = 0
-
-        self.sliding_window = input_builder.sliding_window
-        self.block_size = input_builder.block_size
-        self.use_v2_block_manager = (
-            input_builder.scheduler_config.use_v2_block_manager)
-
-    def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
-                      token_lens: List[int], seq_lens: List[int],
-                      sliding_seq_lens: List[int], query_lens: List[int],
-                      context_lens: List[int],
-                      curr_sliding_window_blocks: List[int], prefix_cache_hit,
-                      chunked_prefill_enabled):
-        metadata_builder_add_seq_group_common(
-            self, seq_group_metadata, token_lens, seq_lens, sliding_seq_lens,
-            query_lens, context_lens, curr_sliding_window_blocks,
-            prefix_cache_hit, chunked_prefill_enabled)
-
-    def build(self, runner: "GPUModelRunnerBase", seq_lens, query_lens,
-              use_captured_graph: bool, cuda_graph_pad_size: int,
-              batch_size: int):
-        return metadata_builder_build_common(self, ROCmFlashAttentionMetadata,
-                                             runner, seq_lens, query_lens,
-                                             use_captured_graph,
-                                             cuda_graph_pad_size, batch_size)
+        CommonMetadataBuilder[ROCmFlashAttentionMetadata]):
+
+    _metadata_cls = ROCmFlashAttentionMetadata
 
 
 def _make_alibi_bias(alibi_slopes: torch.Tensor,
diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py
index df83b822820d8..cb27d56fa1985 100644
--- a/vllm/attention/backends/utils.py
+++ b/vllm/attention/backends/utils.py
@@ -1,8 +1,9 @@
 """Attention backend utils"""
-from typing import TYPE_CHECKING, Dict, List, Union
+from typing import TYPE_CHECKING, Dict, List, Type, TypeVar, Union
 
 import torch
 
+from vllm.attention import AttentionMetadata, AttentionMetadataBuilder
 from vllm.sequence import SequenceGroupMetadata
 from vllm.utils import make_tensor_with_pad
 
@@ -14,7 +15,8 @@
 PAD_SLOT_ID = -1
 
 if TYPE_CHECKING:
-    from vllm.worker.model_runner import GPUModelRunnerBase
+    from vllm.worker.model_runner import (GPUModelRunnerBase,
+                                          ModelInputForGPUBuilder)
 
 
 def is_block_tables_empty(block_tables: Union[None, Dict]):
@@ -29,13 +31,15 @@ def is_block_tables_empty(block_tables: Union[None, Dict]):
     return False
 
 
-def compute_slot_mapping_start_idx(is_prompt, query_len, context_len,
-                                   sliding_window, use_v2_block_manager):
-    """TBA."""
+def compute_slot_mapping_start_idx(is_prompt: bool, query_len: int,
+                                   context_len: int, sliding_window: int,
+                                   use_v2_block_manager: bool):
+    """
+    Compute the start index of slot mapping.
+    """
     start_idx = 0
     if is_prompt and sliding_window is not None:
-        assert use_v2_block_manager \
-            or context_len == 0, (
+        assert use_v2_block_manager or context_len == 0, (
             "Prefix caching is currently not supported with "
             "sliding window attention in V1 block manager")
         # When prefill, we use it to not write slots to kv cache
@@ -44,9 +48,13 @@ def compute_slot_mapping_start_idx(is_prompt, query_len, context_len,
     return start_idx
 
 
-def compute_slot_mapping(is_profile_run, slot_mapping, seq_id, seq_len,
-                         context_len, start_idx, block_size, block_tables):
-    """TBA."""
+def compute_slot_mapping(is_profile_run: bool, slot_mapping: List[int],
+                         seq_id: int, seq_len: int, context_len: int,
+                         start_idx: int, block_size: int,
+                         block_tables: Dict[int, List[int]]):
+    """
+    Compute slot mapping.
+    """
     if is_profile_run:
         # During memory profiling, the block tables are not
         # initialized yet. In this case, we just use a dummy
@@ -70,145 +78,161 @@ def compute_slot_mapping(is_profile_run, slot_mapping, seq_id, seq_len,
         slot_mapping.append(slot)
 
 
-def metadata_builder_add_seq_group_common(
-        attn_metadata_builder, seq_group_metadata: SequenceGroupMetadata,
-        token_lens: List[int], seq_lens: List[int],
-        sliding_seq_lens: List[int], query_lens: List[int],
-        context_lens: List[int], curr_sliding_window_blocks: List[int],
-        prefix_cache_hit, chunked_prefill_enabled):
-    is_prompt = seq_group_metadata.is_prompt
-    block_tables = seq_group_metadata.block_tables
-    computed_block_nums = seq_group_metadata.computed_block_nums
-
-    for (seq_id, token_len, seq_len, sliding_seq_len, query_len, context_len,
-         curr_sliding_window_block) in zip(seq_group_metadata.seq_data.keys(),
-                                           token_lens, seq_lens,
-                                           sliding_seq_lens, query_lens,
-                                           context_lens,
-                                           curr_sliding_window_blocks):
-        attn_metadata_builder.context_lens.append(context_len)
-        if is_prompt:
-            attn_metadata_builder.num_prefills += 1
-            attn_metadata_builder.num_prefill_tokens += token_len
-            attn_metadata_builder.prefill_seq_lens.append(seq_len)
+TAttentionMetadata = TypeVar("TAttentionMetadata", bound='AttentionMetadata')
+
+
+class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
+
+    _metadata_cls: Type[TAttentionMetadata]
+
+    def __init__(self, input_builder: "ModelInputForGPUBuilder"):
+        self.slot_mapping: List[int] = []
+        self.prefill_seq_lens: List[int] = []
+        self.context_lens: List[int] = []
+        self.block_tables: List[List[int]] = []
+        self.decode_seq_lens: List[int] = []
+        self.num_prefills = 0
+        self.num_prefill_tokens = 0
+        self.num_decode_tokens = 0
+
+        self.sliding_window = input_builder.sliding_window
+        self.block_size = input_builder.block_size
+        self.use_v2_block_manager = (
+            input_builder.scheduler_config.use_v2_block_manager)
+
+    def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
+                      token_lens: List[int], seq_lens: List[int],
+                      decode_seq_lens: List[int], query_lens: List[int],
+                      context_lens: List[int],
+                      curr_sliding_window_blocks: List[int], prefix_cache_hit,
+                      chunked_prefill_enabled):
+        is_prompt = seq_group_metadata.is_prompt
+        block_tables = seq_group_metadata.block_tables
+        computed_block_nums = seq_group_metadata.computed_block_nums
+
+        for (seq_id, token_len, seq_len, decode_seq_len, query_len,
+             context_len, curr_sliding_window_block) in zip(
+                 seq_group_metadata.seq_data.keys(), token_lens, seq_lens,
+                 decode_seq_lens, query_lens, context_lens,
+                 curr_sliding_window_blocks):
+            self.context_lens.append(context_len)
+            if is_prompt:
+                self.num_prefills += 1
+                self.num_prefill_tokens += token_len
+                self.prefill_seq_lens.append(seq_len)
+            else:
+                assert query_len == 1, (
+                    "seq_len: {}, context_len: {}, query_len: {}".format(
+                        seq_len, context_len, query_len))
+                self.num_decode_tokens += query_len
+                self.decode_seq_lens.append(decode_seq_len)
+
+            # Compute block table.
+            # TODO(sang): Combine chunked prefill and prefix caching by
+            # only allowing multiple of block_size chunk size.
+            # NOTE: This only works for oooooooxxx style attention.
+            block_table = []
+            if prefix_cache_hit:
+                block_table = computed_block_nums
+            elif ((chunked_prefill_enabled or not is_prompt)
+                  and block_tables is not None):
+                block_table = block_tables[seq_id][-curr_sliding_window_block:]
+            self.block_tables.append(block_table)
+
+            # Compute slot mapping.
+            is_profile_run = is_block_tables_empty(block_tables)
+            start_idx = compute_slot_mapping_start_idx(
+                is_prompt, query_len, context_len, self.sliding_window,
+                self.use_v2_block_manager)
+            compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
+                                 seq_len, context_len, start_idx,
+                                 self.block_size,
+                                 seq_group_metadata.block_tables)
+
+    def build(self, runner: "GPUModelRunnerBase", seq_lens: List[int],
+              query_lens: List[int], use_captured_graph: bool,
+              cuda_graph_pad_size: int, batch_size: int):
+        device = runner.device
+
+        logits_soft_cap = getattr(runner.model_config.hf_config,
+                                  "attn_logit_softcapping", None)
+        if logits_soft_cap is not None:
+            raise ValueError(
+                "Please use Flashinfer backend for models with logits_soft_cap "
+                "(i.e., Gemma-2). Otherwise, the output might be wrong. "
+                "Set Flashinfer backend by "
+                "export VLLM_ATTENTION_BACKEND=FLASHINFER.")
+
+        max_query_len = max(query_lens)
+        max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
+        max_decode_seq_len = max(self.decode_seq_lens, default=0)
+        num_decode_tokens = self.num_decode_tokens
+
+        if use_captured_graph:
+            self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
+            self.block_tables.extend([] * cuda_graph_pad_size)
+            num_decode_tokens = batch_size + cuda_graph_pad_size
+
+            # The shape of graph_block_tables is
+            # [max batch size, max context len // block size].
+            input_block_tables = runner.graph_block_tables[:batch_size]
+            for i, block_table in enumerate(self.block_tables):
+                if block_table:
+                    input_block_tables[i, :len(block_table)] = block_table
+            block_tables = torch.tensor(input_block_tables, device=device)
         else:
-            assert query_len == 1, (
-                "seq_len: {}, context_len: {}, query_len: {}".format(
-                    seq_len, context_len, query_len))
-            attn_metadata_builder.num_decode_tokens += query_len
-            attn_metadata_builder.decode_seq_lens.append(sliding_seq_len)
-
-        # Compute block table.
-        # TODO(sang): Combine chunked prefill and prefix caching by
-        # only allowing multiple of block_size chunk size.
-        # NOTE: This only works for oooooooxxx style attention.
-        block_table = []
-        if prefix_cache_hit:
-            block_table = computed_block_nums
-        elif ((chunked_prefill_enabled or not is_prompt)
-              and block_tables is not None):
-            block_table = block_tables[seq_id][-curr_sliding_window_block:]
-        attn_metadata_builder.block_tables.append(block_table)
-
-        # Compute slot mapping.
-        is_profile_run = is_block_tables_empty(block_tables)
-        start_idx = compute_slot_mapping_start_idx(
-            is_prompt, query_len, context_len,
-            attn_metadata_builder.sliding_window,
-            attn_metadata_builder.use_v2_block_manager)
-        compute_slot_mapping(is_profile_run,
-                             attn_metadata_builder.slot_mapping, seq_id,
-                             seq_len, context_len, start_idx,
-                             attn_metadata_builder.block_size,
-                             seq_group_metadata.block_tables)
-
-
-def metadata_builder_build_common(attn_metadata_builder, metadata_cls,
-                                  runner: "GPUModelRunnerBase", seq_lens,
-                                  query_lens, use_captured_graph: bool,
-                                  cuda_graph_pad_size: int, batch_size: int):
-    device = runner.device
-
-    logits_soft_cap = getattr(runner.model_config.hf_config,
-                              "attn_logit_softcapping", None)
-    if logits_soft_cap is not None:
-        raise ValueError(
-            "Please use Flashinfer backend for models with logits_soft_cap"
-            " (i.e., Gemma-2). Otherwise, the output might be wrong."
-            " Set Flashinfer backend by "
-            "export VLLM_ATTENTION_BACKEND=FLASHINFER.")
-
-    max_query_len = max(query_lens)
-    max_prefill_seq_len = max(attn_metadata_builder.prefill_seq_lens,
-                              default=0)
-    max_decode_seq_len = max(attn_metadata_builder.decode_seq_lens, default=0)
-    num_decode_tokens = attn_metadata_builder.num_decode_tokens
-
-    if use_captured_graph:
-        attn_metadata_builder.slot_mapping.extend([PAD_SLOT_ID] *
-                                                  cuda_graph_pad_size)
-        attn_metadata_builder.block_tables.extend([] * cuda_graph_pad_size)
-        num_decode_tokens = batch_size + cuda_graph_pad_size
-
-        # The shape of graph_block_tables is
-        # [max batch size, max context len // block size].
-        input_block_tables = runner.graph_block_tables[:batch_size]
-        for i, block_table in enumerate(attn_metadata_builder.block_tables):
-            if block_table:
-                input_block_tables[i, :len(block_table)] = block_table
-        block_tables = torch.tensor(input_block_tables, device=device)
-    else:
-        max_block_table_len = max(
-            len(block_table)
-            for block_table in attn_metadata_builder.block_tables)
-        block_tables = make_tensor_with_pad(
-            attn_metadata_builder.block_tables,
-            max_len=max_block_table_len,
-            pad=0,
-            dtype=torch.int,
-            device=device,
-        )
-    assert max_query_len > 0, ("query_lens: {}".format(query_lens))
-
-    context_lens_tensor = torch.tensor(attn_metadata_builder.context_lens,
+            max_block_table_len = max(
+                len(block_table) for block_table in self.block_tables)
+            block_tables = make_tensor_with_pad(
+                self.block_tables,
+                max_len=max_block_table_len,
+                pad=0,
+                dtype=torch.int,
+                device=device,
+            )
+        assert max_query_len > 0, "query_lens: {}".format(query_lens)
+
+        context_lens_tensor = torch.tensor(self.context_lens,
+                                           dtype=torch.int,
+                                           device=device)
+        seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int, device=device)
-                                       dtype=torch.int,
-                                       device=device)
-    seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int, device=device)
-    query_lens_tensor = torch.tensor(query_lens,
-                                     dtype=torch.long,
-                                     device=device)
-    query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
-                                  dtype=torch.int32,
-                                  device=device)
-    seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
-                                dtype=torch.int32,
-                                device=device)
-    torch.cumsum(seq_lens_tensor,
-                 dim=0,
-                 dtype=seq_start_loc.dtype,
-                 out=seq_start_loc[1:])
-    torch.cumsum(query_lens_tensor,
-                 dim=0,
-                 dtype=query_start_loc.dtype,
-                 out=query_start_loc[1:])
-
-    slot_mapping_tensor = torch.tensor(attn_metadata_builder.slot_mapping,
-                                       dtype=torch.long,
-                                       device=device)
-
-    return metadata_cls(
-        num_prefills=attn_metadata_builder.num_prefills,
-        slot_mapping=slot_mapping_tensor,
-        num_prefill_tokens=attn_metadata_builder.num_prefill_tokens,
-        num_decode_tokens=num_decode_tokens,
-        seq_lens=seq_lens,
-        seq_lens_tensor=seq_lens_tensor,
-        max_query_len=max_query_len,
-        max_prefill_seq_len=max_prefill_seq_len,
-        max_decode_seq_len=max_decode_seq_len,
-        query_start_loc=query_start_loc,
-        seq_start_loc=seq_start_loc,
-        context_lens_tensor=context_lens_tensor,
-        block_tables=block_tables,
-        use_cuda_graph=use_captured_graph,
-    )
+        query_lens_tensor = torch.tensor(query_lens,
+                                         dtype=torch.long,
+                                         device=device)
+        query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
+                                      dtype=torch.int32,
+                                      device=device)
+        seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
+                                    dtype=torch.int32,
+                                    device=device)
+        torch.cumsum(seq_lens_tensor,
+                     dim=0,
+                     dtype=seq_start_loc.dtype,
+                     out=seq_start_loc[1:])
+        torch.cumsum(query_lens_tensor,
+                     dim=0,
+                     dtype=query_start_loc.dtype,
+                     out=query_start_loc[1:])
+
+        slot_mapping_tensor = torch.tensor(self.slot_mapping,
+                                           dtype=torch.long,
+                                           device=device)
+
+        return self._metadata_cls(  # type: ignore
+            num_prefills=self.num_prefills,
+            slot_mapping=slot_mapping_tensor,
+            num_prefill_tokens=self.num_prefill_tokens,
+            num_decode_tokens=num_decode_tokens,
+            seq_lens=seq_lens,
+            seq_lens_tensor=seq_lens_tensor,
+            max_query_len=max_query_len,
+            max_prefill_seq_len=max_prefill_seq_len,
+            max_decode_seq_len=max_decode_seq_len,
+            query_start_loc=query_start_loc,
+            seq_start_loc=seq_start_loc,
+            context_lens_tensor=context_lens_tensor,
+            block_tables=block_tables,
+            use_cuda_graph=use_captured_graph,
+        )
diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py
index 1437ee51aeddd..07edd8f959607 100644
--- a/vllm/attention/backends/xformers.py
+++ b/vllm/attention/backends/xformers.py
@@ -1,6 +1,6 @@
 """Attention layer with xFormers and PagedAttention."""
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
+from typing import Any, Dict, List, Optional, Tuple, Type
 
 import torch
 from xformers import ops as xops
@@ -10,22 +10,14 @@
                                          LowerTriangularMaskWithTensorBias)
 
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
-                                              AttentionMetadata,
-                                              AttentionMetadataBuilder,
-                                              AttentionType)
-from vllm.attention.backends.utils import (
-    metadata_builder_add_seq_group_common, metadata_builder_build_common)
+                                              AttentionMetadata, AttentionType)
+from vllm.attention.backends.utils import CommonMetadataBuilder
 from vllm.attention.ops.paged_attn import (PagedAttention,
                                            PagedAttentionMetadata)
 from vllm.logger import init_logger
-from vllm.sequence import SequenceGroupMetadata
 
 logger = init_logger(__name__)
 
-if TYPE_CHECKING:
-    from vllm.worker.model_runner import (GPUModelRunnerBase,
-                                          ModelInputForGPUBuilder)
-
 
 class XFormersBackend(AttentionBackend):
@@ -375,41 +367,9 @@ def _get_seq_len_block_table_args(
         raise AttributeError(f"Invalid attention type {str(attn_type)}")
 
 
-class XFormersMetadataBuilder(AttentionMetadataBuilder[XFormersMetadata]):
-
-    def __init__(self, input_builder: "ModelInputForGPUBuilder"):
-        self.slot_mapping: List[int] = []
-        self.prefill_seq_lens: List[int] = []
-        self.context_lens: List[int] = []
-        self.block_tables: List[List[int]] = []
-        self.decode_seq_lens: List[int] = []
-        self.num_prefills = 0
-        self.num_prefill_tokens = 0
-        self.num_decode_tokens = 0
-
-        self.sliding_window = input_builder.sliding_window
-        self.block_size = input_builder.block_size
-        self.use_v2_block_manager = (
-            input_builder.scheduler_config.use_v2_block_manager)
-
-    def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
-                      token_lens: List[int], seq_lens: List[int],
-                      sliding_seq_lens: List[int], query_lens: List[int],
-                      context_lens: List[int],
-                      curr_sliding_window_blocks: List[int], prefix_cache_hit,
-                      chunked_prefill_enabled):
-        metadata_builder_add_seq_group_common(
-            self, seq_group_metadata, token_lens, seq_lens, sliding_seq_lens,
-            query_lens, context_lens, curr_sliding_window_blocks,
-            prefix_cache_hit, chunked_prefill_enabled)
-
-    def build(self, runner: "GPUModelRunnerBase", seq_lens, query_lens,
-              use_captured_graph: bool, cuda_graph_pad_size: int,
-              batch_size: int):
-        return metadata_builder_build_common(self, XFormersMetadata, runner,
-                                             seq_lens, query_lens,
-                                             use_captured_graph,
-                                             cuda_graph_pad_size, batch_size)
+class XFormersMetadataBuilder(CommonMetadataBuilder[XFormersMetadata]):
+
+    _metadata_cls = XFormersMetadata
 
 
 class XFormersImpl(AttentionImpl[XFormersMetadata]):
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 70da1eeceba6d..5c66f0cbbdc96 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -252,7 +252,7 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
         self.request_ids_to_seq_ids[seq_group_metadata.request_id] = []
 
         token_lens = []
-        sliding_seq_lens = []
+        decode_seq_lens = []
        context_lens = []
         curr_sliding_window_blocks = []
         orig_seq_lens = []
@@ -320,12 +320,12 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
             # the attention metadata.
             token_lens.append(len(tokens))
             context_lens.append(context_len)
-            sliding_seq_lens.append(sliding_seq_len)
+            decode_seq_lens.append(sliding_seq_len)
             curr_sliding_window_blocks.append(curr_sliding_window_block)
             orig_seq_lens.append(seq_len)
 
         self.attn_metadata_builder.add_seq_group(
-            seq_group_metadata, token_lens, orig_seq_lens, sliding_seq_lens,
+            seq_group_metadata, token_lens, orig_seq_lens, decode_seq_lens,
             self.query_lens[-n_seqs:], context_lens,
             curr_sliding_window_blocks, prefix_cache_hit,
             self.chunked_prefill_enabled)
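
Editor's note (not part of the patch): with this change, a backend-specific metadata builder reduces to subclassing CommonMetadataBuilder and pointing _metadata_cls at the backend's metadata class, exactly as the blocksparse, ROCm, and xFormers hunks above do. The sketch below is a minimal illustration of that pattern for a hypothetical new backend; MyBackendMetadata and MyBackendMetadataBuilder are illustrative names only, and the metadata class here simply reuses XFormersMetadata's field definitions so that the shared build() can construct it.

    from vllm.attention.backends.utils import CommonMetadataBuilder
    from vllm.attention.backends.xformers import XFormersMetadata


    class MyBackendMetadata(XFormersMetadata):
        """Hypothetical metadata type; reuses the XFormers field set."""


    class MyBackendMetadataBuilder(CommonMetadataBuilder[MyBackendMetadata]):

        # The single backend-specific hook: CommonMetadataBuilder.build()
        # instantiates this class with the slot mapping, block tables, and
        # query/sequence start-location tensors assembled by the shared
        # add_seq_group()/build() logic.
        _metadata_cls = MyBackendMetadata

During input preparation the model runner calls add_seq_group() once per sequence group (see the model_runner.py hunk above) and then build() once to materialize the tensors on the target device; none of that driver code needs to change per backend.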