Commit cdf10db: intro common builder
comaniac committed Jul 11, 2024 (parent: cf3b724)
Showing 7 changed files with 202 additions and 297 deletions.
51 changes: 6 additions & 45 deletions vllm/attention/backends/blocksparse_attn.py
@@ -1,24 +1,16 @@
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
+from typing import Any, Dict, List, Optional, Tuple, Type
 
 import torch
 
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
-                                              AttentionMetadata,
-                                              AttentionMetadataBuilder,
-                                              AttentionType)
-from vllm.attention.backends.utils import (
-    metadata_builder_add_seq_group_common, metadata_builder_build_common)
+                                              AttentionMetadata, AttentionType)
+from vllm.attention.backends.utils import CommonMetadataBuilder
 from vllm.attention.ops.blocksparse_attention.interface import (
     LocalStridedBlockSparseAttn, get_head_sliding_step)
 from vllm.attention.ops.paged_attn import PagedAttention
 from vllm.distributed import (get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
-from vllm.sequence import SequenceGroupMetadata
-
-if TYPE_CHECKING:
-    from vllm.worker.model_runner import (GPUModelRunnerBase,
-                                          ModelInputForGPUBuilder)
 
 
 @dataclass
@@ -258,40 +250,9 @@ def decode_metadata(self) -> Optional["BlocksparseFlashAttentionMetadata"]:


 class BlocksparseFlashAttentionMetadataBuilder(
-        AttentionMetadataBuilder[BlocksparseFlashAttentionMetadata]):
-
-    def __init__(self, input_builder: "ModelInputForGPUBuilder"):
-        self.slot_mapping: List[int] = []
-        self.prefill_seq_lens: List[int] = []
-        self.context_lens: List[int] = []
-        self.block_tables: List[List[int]] = []
-        self.decode_seq_lens: List[int] = []
-        self.num_prefills = 0
-        self.num_prefill_tokens = 0
-        self.num_decode_tokens = 0
-
-        self.sliding_window = input_builder.sliding_window
-        self.block_size = input_builder.block_size
-        self.use_v2_block_manager = (
-            input_builder.scheduler_config.use_v2_block_manager)
-
-    def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
-                      token_lens: List[int], seq_lens: List[int],
-                      sliding_seq_lens: List[int], query_lens: List[int],
-                      context_lens: List[int],
-                      curr_sliding_window_blocks: List[int], prefix_cache_hit,
-                      chunked_prefill_enabled):
-        metadata_builder_add_seq_group_common(
-            self, seq_group_metadata, token_lens, seq_lens, sliding_seq_lens,
-            query_lens, context_lens, curr_sliding_window_blocks,
-            prefix_cache_hit, chunked_prefill_enabled)
-
-    def build(self, runner: "GPUModelRunnerBase", seq_lens, query_lens,
-              use_captured_graph: bool, cuda_graph_pad_size: int,
-              batch_size: int):
-        return metadata_builder_build_common(
-            self, BlocksparseFlashAttentionMetadata, runner, seq_lens,
-            query_lens, use_captured_graph, cuda_graph_pad_size, batch_size)
+        CommonMetadataBuilder[BlocksparseFlashAttentionMetadata]):
+
+    _metadata_cls = BlocksparseFlashAttentionMetadata
 
 
 class BlocksparseFlashAttentionImpl(AttentionImpl):
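The CommonMetadataBuilder that these backends now subclass is defined in vllm/attention/backends/utils.py, whose diff is not rendered on this page. Below is a minimal sketch of the pattern only, with toy names and heavily simplified signatures (the real add_seq_group and build take the full argument lists visible in the removed code above); everything other than the CommonMetadataBuilder/_metadata_cls names is illustrative.

from dataclasses import dataclass
from typing import Generic, List, Type, TypeVar

T = TypeVar("T")


class CommonMetadataBuilder(Generic[T]):
    """Shared bookkeeping; each backend subclass only selects the metadata class."""

    _metadata_cls: Type[T]

    def __init__(self) -> None:
        # Per-batch accumulators that every backend builder previously re-declared.
        self.prefill_seq_lens: List[int] = []
        self.decode_seq_lens: List[int] = []
        self.num_prefills = 0
        self.num_decode_tokens = 0

    def add_seq_group(self, seq_lens: List[int], is_prompt: bool) -> None:
        # Greatly simplified: the real method also tracks slot mappings,
        # block tables, context lengths, sliding windows, etc.
        if is_prompt:
            self.num_prefills += len(seq_lens)
            self.prefill_seq_lens.extend(seq_lens)
        else:
            self.num_decode_tokens += len(seq_lens)
            self.decode_seq_lens.extend(seq_lens)

    def build(self) -> T:
        # The only backend-specific piece is which metadata class gets instantiated.
        return self._metadata_cls(num_prefills=self.num_prefills,
                                  num_decode_tokens=self.num_decode_tokens)


@dataclass
class ToyAttentionMetadata:
    num_prefills: int
    num_decode_tokens: int


class ToyMetadataBuilder(CommonMetadataBuilder[ToyAttentionMetadata]):
    _metadata_cls = ToyAttentionMetadata


builder = ToyMetadataBuilder()
builder.add_seq_group([8, 8], is_prompt=True)   # one prefill group with two sequences
builder.add_seq_group([1], is_prompt=False)     # one decode token
print(builder.build())  # ToyAttentionMetadata(num_prefills=2, num_decode_tokens=1)

With this in place, each backend shrinks to a two-line subclass that sets _metadata_cls, exactly as the blocksparse and ROCm diffs on this page show.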
8 changes: 4 additions & 4 deletions vllm/attention/backends/flash_attn.py
@@ -219,17 +219,17 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"):

     def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
                       token_lens: List[int], seq_lens: List[int],
-                      sliding_seq_lens: List[int], query_lens: List[int],
+                      decode_seq_lens: List[int], query_lens: List[int],
                       context_lens: List[int],
                       curr_sliding_window_blocks: List[int], prefix_cache_hit,
                       chunked_prefill_enabled):
         is_prompt = seq_group_metadata.is_prompt
         block_tables = seq_group_metadata.block_tables
 
-        for (seq_id, token_len, seq_len, sliding_seq_len, query_len,
+        for (seq_id, token_len, seq_len, decode_seq_len, query_len,
              context_len, curr_sliding_window_block) in zip(
                  seq_group_metadata.seq_data.keys(), token_lens, seq_lens,
-                 sliding_seq_lens, query_lens, context_lens,
+                 decode_seq_lens, query_lens, context_lens,
                  curr_sliding_window_blocks):
             self.context_lens.append(context_len)

@@ -242,7 +242,7 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
"seq_len: {}, context_len: {}, query_len: {}".format(
seq_len, context_len, query_len))
self.num_decode_tokens += query_len
self.decode_seq_lens.append(sliding_seq_len)
self.decode_seq_lens.append(decode_seq_len)

# Compute block table.
# TODO(sang): Combine chunked prefill and prefix caching by
8 changes: 4 additions & 4 deletions vllm/attention/backends/flashinfer.py
@@ -239,18 +239,18 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"):

     def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
                       token_lens: List[int], seq_lens: List[int],
-                      sliding_seq_lens: List[int], query_lens: List[int],
+                      decode_seq_lens: List[int], query_lens: List[int],
                       context_lens: List[int],
                       curr_sliding_window_blocks: List[int], prefix_cache_hit,
                       chunked_prefill_enabled):
         is_prompt = seq_group_metadata.is_prompt
         block_tables = seq_group_metadata.block_tables
         computed_block_nums = seq_group_metadata.computed_block_nums
 
-        for (seq_id, token_len, seq_len, sliding_seq_len, query_len,
+        for (seq_id, token_len, seq_len, decode_seq_len, query_len,
              context_len, curr_sliding_window_block) in zip(
                  seq_group_metadata.seq_data.keys(), token_lens, seq_lens,
-                 sliding_seq_lens, query_lens, context_lens,
+                 decode_seq_lens, query_lens, context_lens,
                  curr_sliding_window_blocks):
             self.context_lens.append(context_len)
             if is_prompt:
@@ -262,7 +262,7 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
"seq_len: {}, context_len: {}, query_len: {}".format(
seq_len, context_len, query_len))
self.num_decode_tokens += query_len
self.decode_seq_lens.append(sliding_seq_len)
self.decode_seq_lens.append(decode_seq_len)

# Compute block table.
# TODO(sang): Combine chunked prefill and prefix caching by
52 changes: 6 additions & 46 deletions vllm/attention/backends/rocm_flash_attn.py
@@ -1,27 +1,19 @@
"""Attention layer ROCm GPUs."""
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
from typing import Any, Dict, List, Optional, Tuple, Type

import torch

import vllm.envs as envs
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata,
AttentionMetadataBuilder,
AttentionType)
from vllm.attention.backends.utils import (
metadata_builder_add_seq_group_common, metadata_builder_build_common)
AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import CommonMetadataBuilder
from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata)
from vllm.logger import init_logger
from vllm.sequence import SequenceGroupMetadata

logger = init_logger(__name__)

if TYPE_CHECKING:
from vllm.worker.model_runner import (GPUModelRunnerBase,
ModelInputForGPUBuilder)


class ROCmFlashAttentionBackend(AttentionBackend):

@@ -180,41 +172,9 @@ def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:


 class ROCmFlashAttentionMetadataBuilder(
-        AttentionMetadataBuilder[ROCmFlashAttentionMetadata]):
-
-    def __init__(self, input_builder: "ModelInputForGPUBuilder"):
-        self.slot_mapping: List[int] = []
-        self.prefill_seq_lens: List[int] = []
-        self.context_lens: List[int] = []
-        self.block_tables: List[List[int]] = []
-        self.decode_seq_lens: List[int] = []
-        self.num_prefills = 0
-        self.num_prefill_tokens = 0
-        self.num_decode_tokens = 0
-
-        self.sliding_window = input_builder.sliding_window
-        self.block_size = input_builder.block_size
-        self.use_v2_block_manager = (
-            input_builder.scheduler_config.use_v2_block_manager)
-
-    def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
-                      token_lens: List[int], seq_lens: List[int],
-                      sliding_seq_lens: List[int], query_lens: List[int],
-                      context_lens: List[int],
-                      curr_sliding_window_blocks: List[int], prefix_cache_hit,
-                      chunked_prefill_enabled):
-        metadata_builder_add_seq_group_common(
-            self, seq_group_metadata, token_lens, seq_lens, sliding_seq_lens,
-            query_lens, context_lens, curr_sliding_window_blocks,
-            prefix_cache_hit, chunked_prefill_enabled)
-
-    def build(self, runner: "GPUModelRunnerBase", seq_lens, query_lens,
-              use_captured_graph: bool, cuda_graph_pad_size: int,
-              batch_size: int):
-        return metadata_builder_build_common(self, ROCmFlashAttentionMetadata,
-                                             runner, seq_lens, query_lens,
-                                             use_captured_graph,
-                                             cuda_graph_pad_size, batch_size)
+        CommonMetadataBuilder[ROCmFlashAttentionMetadata]):
+
+    _metadata_cls = ROCmFlashAttentionMetadata
 
 
 def _make_alibi_bias(alibi_slopes: torch.Tensor,
