[Draft][Core] Refactor _prepare_model_input_tensors #5972

Closed
wants to merge 9 commits
35 changes: 35 additions & 0 deletions vllm/attention/backends/abstract.py
@@ -3,6 +3,7 @@
from typing import (Any, Dict, Generic, List, Optional, Set, Tuple, Type,
TypeVar)

import numpy as np
import torch


@@ -28,6 +29,16 @@ def get_metadata_cls() -> Type["AttentionMetadata"]:
def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
return cls.get_metadata_cls()(*args, **kwargs)

@staticmethod
@abstractmethod
def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
raise NotImplementedError

@classmethod
def make_metadata_builder(cls, *args,
**kwargs) -> "AttentionMetadataBuilder":
return cls.get_builder_cls()(*args, **kwargs)

@staticmethod
@abstractmethod
def get_kv_cache_shape(
@@ -103,6 +114,30 @@ def asdict_zerocopy(self,
T = TypeVar("T", bound=AttentionMetadata)


class AttentionMetadataBuilder(ABC, Generic[T]):
"""Abstract class for attention metadata builders."""

@abstractmethod
def __init__(self, input_builder) -> None:
raise NotImplementedError

@abstractmethod
def add_prefill_seq_group(self, *args, **kwargs) -> None:
raise NotImplementedError

@abstractmethod
def add_decode_seq_group(self, *args, **kwargs) -> None:
raise NotImplementedError

@abstractmethod
def build(self, model_config: Any, parallel_config: Any,
kv_cache_dtype: Any, seq_lens: Any, query_lens: Any,
decode_seq_lens: Any, use_captured_graph: bool,
cuda_graph_pad_size: int, graph_block_tables: np.ndarray,
batch_size: int, device: Any) -> T:
raise NotImplementedError


class AttentionImpl(ABC, Generic[T]):

@abstractmethod
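For reference, this is how a backend is expected to plug into the new interface: get_builder_cls() exposes the backend-specific builder class, and make_metadata_builder() instantiates it. A minimal sketch follows (the Dummy* names are hypothetical and not part of this PR; the other abstract methods of AttentionBackend are omitted):

from typing import Type

from vllm.attention.backends.abstract import (AttentionBackend,
                                              AttentionMetadata,
                                              AttentionMetadataBuilder)


class DummyMetadataBuilder(AttentionMetadataBuilder[AttentionMetadata]):

    def __init__(self, input_builder) -> None:
        # Capture whatever per-batch state the backend needs from the
        # model-input builder (block size, sliding window, ...).
        self.input_builder = input_builder

    def add_prefill_seq_group(self, *args, **kwargs) -> None:
        pass  # accumulate per-sequence prefill metadata

    def add_decode_seq_group(self, *args, **kwargs) -> None:
        pass  # accumulate per-sequence decode metadata

    def build(self, model_config, parallel_config, kv_cache_dtype, seq_lens,
              query_lens, decode_seq_lens, use_captured_graph,
              cuda_graph_pad_size, graph_block_tables, batch_size,
              device) -> AttentionMetadata:
        # Turn the accumulated Python lists into tensors and return the
        # backend-specific AttentionMetadata instance.
        raise NotImplementedError


class DummyBackend(AttentionBackend):
    # Remaining AttentionBackend abstract methods omitted for brevity.

    @staticmethod
    def get_builder_cls() -> Type["DummyMetadataBuilder"]:
        return DummyMetadataBuilder


# The model runner can then construct the builder without knowing its
# concrete type:
# builder = DummyBackend.make_metadata_builder(input_builder)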
198 changes: 196 additions & 2 deletions vllm/attention/backends/flash_attn.py
@@ -1,13 +1,22 @@
"""Attention layer with FlashAttention."""
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type

import numpy as np
import torch
from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache

from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata)
AttentionMetadata,
AttentionMetadataBuilder)
from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
is_block_tables_empty)
from vllm.sequence import SequenceGroupMetadata
from vllm.utils import make_tensor_with_pad

if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUBuilder


class FlashAttentionBackend(AttentionBackend):
@@ -28,6 +37,10 @@ def get_impl_cls() -> Type["FlashAttentionImpl"]:
def get_metadata_cls() -> Type["AttentionMetadata"]:
return FlashAttentionMetadata

@staticmethod
def get_builder_cls() -> Type["FlashAttentionMetadataBuilder"]:
return FlashAttentionMetadataBuilder

@staticmethod
def get_kv_cache_shape(
num_blocks: int,
@@ -184,6 +197,187 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]:
return self._cached_decode_metadata


class FlashAttentionMetadataBuilder(
AttentionMetadataBuilder[FlashAttentionMetadata]):

def __init__(self, input_builder: "ModelInputForGPUBuilder"):
self.slot_mapping: List[int] = []
self.prefill_seq_lens: List[int] = []
self.context_lens: List[int] = []
self.block_tables: List[List[int]] = []
self.num_prefills = 0
self.num_prefill_tokens = 0
self.num_decode_tokens = 0

self.sliding_window = input_builder.sliding_window
self.block_size = input_builder.block_size
self.use_v2_block_manager = (
input_builder.scheduler_config.use_v2_block_manager)

def add_prefill_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
tokens: List[int], seq_id: int, seq_len: int,
query_len: int, context_len: int,
prefix_cache_hit, chunked_prefill_enabled,
computed_block_nums,
curr_sliding_window_blocks) -> None:

# Compute block table.
# TODO(sang): Combine chunked prefill and prefix caching by
# only allowing multiple of block_size chunk size.
# NOTE: This only works for oooooooxxx style attention.
if prefix_cache_hit:
assert computed_block_nums is not None
assert self.sliding_window is None
# NOTE(woosuk): For flash-attn, the block table should
# include the entries for the incoming prefill tokens.
block_table = seq_group_metadata.block_tables[seq_id]
elif (chunked_prefill_enabled
and seq_group_metadata.block_tables is not None):
block_table = seq_group_metadata.block_tables[seq_id]
if curr_sliding_window_blocks is not None:
block_table = block_table[-curr_sliding_window_blocks:]
else:
# Prefill without chunked prefill or memory profiling.
block_table = []

self.block_tables.append(block_table)
self.context_lens.append(context_len)

self.num_prefills += 1
self.num_prefill_tokens += len(tokens)
self.prefill_seq_lens.append(seq_len)

# Compute slot mapping.
block_table = None
is_profile_run = is_block_tables_empty(seq_group_metadata.block_tables)
if not is_profile_run:
block_table = seq_group_metadata.block_tables[seq_id]

start_idx = 0
if self.sliding_window is not None:
assert self.use_v2_block_manager \
or context_len == 0, (
"Prefix caching is currently not supported with "
"sliding window attention in V1 block manager")
# When prefill, we use it to not write slots to kv cache
# to save memory.
start_idx = max(0, query_len - self.sliding_window)

compute_slot_mapping(self.slot_mapping, seq_len, context_len,
start_idx, self.block_size, block_table)
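        # Illustration (assuming the usual paged-KV slot layout, i.e.
        # slot = physical_block_number * block_size + offset within the block):
        # with block_size=16 and block_table=[7, 9], the prefill token at
        # position 18 falls in logical block 1 -> physical block 9, offset 2,
        # so it maps to slot 9 * 16 + 2 = 146.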

def add_decode_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
seq_id, seq_len, query_len, context_len,
curr_sliding_window_blocks, sliding_seq_len,
sliding_context_len):

# Compute block table.
if seq_group_metadata.block_tables is not None:
block_table = seq_group_metadata.block_tables[seq_id]
if curr_sliding_window_blocks is not None:
# NOTE(woosuk): For flash-attn, the block table should
# include the entries for the incoming prefill tokens.
block_table = block_table[-curr_sliding_window_blocks:]
else:
# Only happens when memory profiling runs.
block_table = []

self.block_tables.append(block_table)
self.context_lens.append(sliding_context_len)

assert query_len == 1, (
"seq_len: {}, context_len: {}, query_len: {}".format(
seq_len, context_len, query_len))
self.num_decode_tokens += query_len

# Compute the slot mapping.
block_table = None
is_profile_run = is_block_tables_empty(seq_group_metadata.block_tables)
if not is_profile_run:
block_table = seq_group_metadata.block_tables[seq_id]

compute_slot_mapping(self.slot_mapping, seq_len, context_len, 0,
self.block_size, block_table)

def build(self, model_config, parallel_config, kv_cache_dtype, seq_lens,
query_lens, decode_seq_lens, use_captured_graph: bool,
cuda_graph_pad_size: int, graph_block_tables: np.ndarray,
batch_size: int, device):
max_query_len = max(query_lens)
max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
max_decode_seq_len = max(decode_seq_lens, default=0)
num_decode_tokens = self.num_decode_tokens

if use_captured_graph:
self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
self.block_tables.extend([] * cuda_graph_pad_size)
num_decode_tokens = batch_size + cuda_graph_pad_size

# The shape of graph_block_tables is
# [max batch size, max context len // block size].
input_block_tables = graph_block_tables[:batch_size]
for i, block_table in enumerate(self.block_tables):
if block_table:
input_block_tables[i, :len(block_table)] = block_table
block_tables = torch.tensor(input_block_tables, device=device)
else:
max_block_table_len = max(
len(block_table) for block_table in self.block_tables)
block_tables = make_tensor_with_pad(
self.block_tables,
max_len=max_block_table_len,
pad=0,
dtype=torch.int,
device=device,
)
assert max_query_len > 0, ("query_lens: {}".format(query_lens))

context_lens_tensor = torch.tensor(self.context_lens,
dtype=torch.int,
device=device)
seq_lens_tensor = torch.tensor(seq_lens,
dtype=torch.int,
device=device)
query_lens_tensor = torch.tensor(query_lens,
dtype=torch.long,
device=device)
query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=device)
seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=device)
torch.cumsum(seq_lens_tensor,
dim=0,
dtype=seq_start_loc.dtype,
out=seq_start_loc[1:])
torch.cumsum(query_lens_tensor,
dim=0,
dtype=query_start_loc.dtype,
out=query_start_loc[1:])
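        # For example, query_lens=[4, 2, 1] yields query_start_loc
        # [0, 4, 6, 7] and seq_lens=[12, 8, 5] yields seq_start_loc
        # [0, 12, 20, 25]: a leading zero followed by the running sums, so
        # sequence i occupies [start_loc[i], start_loc[i + 1]) of the
        # flattened batch.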

slot_mapping_tensor = torch.tensor(self.slot_mapping,
dtype=torch.long,
device=device)

return FlashAttentionMetadata(
num_prefills=self.num_prefills,
slot_mapping=slot_mapping_tensor,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_query_len=max_query_len,
max_prefill_seq_len=max_prefill_seq_len,
max_decode_seq_len=max_decode_seq_len,
query_start_loc=query_start_loc,
seq_start_loc=seq_start_loc,
context_lens_tensor=context_lens_tensor,
block_tables=block_tables,
use_cuda_graph=use_captured_graph,
)
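# Rough sketch of how the GPU model-input builder is expected to drive this
# class; everything below is illustrative, and the local variable names are
# assumptions rather than code from this PR:
#
#   builder = FlashAttentionBackend.make_metadata_builder(input_builder)
#   for seq_group_metadata in seq_group_metadata_list:
#       if seq_group_metadata.is_prompt:
#           builder.add_prefill_seq_group(seq_group_metadata, tokens, seq_id,
#                                         seq_len, query_len, context_len,
#                                         prefix_cache_hit,
#                                         chunked_prefill_enabled,
#                                         computed_block_nums,
#                                         curr_sliding_window_blocks)
#       else:
#           builder.add_decode_seq_group(seq_group_metadata, seq_id, seq_len,
#                                        query_len, context_len,
#                                        curr_sliding_window_blocks,
#                                        sliding_seq_len, sliding_context_len)
#   attn_metadata = builder.build(model_config, parallel_config,
#                                 kv_cache_dtype, seq_lens, query_lens,
#                                 decode_seq_lens, use_captured_graph,
#                                 cuda_graph_pad_size, graph_block_tables,
#                                 batch_size, device)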


class FlashAttentionImpl(AttentionImpl):
"""
If the input tensors contain prompt tokens, the layout is as follows: