[Core] Refactor _prepare_model_input_tensors - take 2 #6164

Merged · 4 commits · Jul 17, 2024
6 changes: 5 additions & 1 deletion tests/worker/test_model_input.py
@@ -3,7 +3,7 @@

import torch

from vllm.attention import AttentionMetadata
from vllm.attention import AttentionMetadata, AttentionMetadataBuilder
from vllm.attention.backends.abstract import AttentionBackend
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.pooling_metadata import PoolingMetadata
@@ -26,6 +26,10 @@ def get_impl_cls():
def get_metadata_cls() -> Type["AttentionMetadata"]:
return AttentionMetadata

@staticmethod
def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
        return AttentionMetadataBuilder

@staticmethod
def get_kv_cache_shape(
num_blocks: int,
4 changes: 3 additions & 1 deletion vllm/attention/__init__.py
@@ -1,12 +1,14 @@
from vllm.attention.backends.abstract import (AttentionBackend,
                                              AttentionMetadata,
                                              AttentionMetadataBuilder)
from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend

__all__ = [
    "Attention",
    "AttentionBackend",
    "AttentionMetadata",
    "AttentionMetadataBuilder",
    "get_attn_backend",
]
45 changes: 43 additions & 2 deletions vllm/attention/backends/abstract.py
@@ -1,11 +1,15 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, fields
from enum import Enum, auto
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set,
                    Tuple, Type, TypeVar)

import torch

if TYPE_CHECKING:
from vllm.sequence import SequenceGroupMetadata
from vllm.worker.model_runner_base import ModelRunnerInputBuilderBase


class AttentionType(Enum):
DECODER = auto() # Decoder attention between previous layer Q/K/V
@@ -35,6 +39,16 @@ def get_metadata_cls() -> Type["AttentionMetadata"]:
def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
return cls.get_metadata_cls()(*args, **kwargs)

@staticmethod
@abstractmethod
def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
raise NotImplementedError

@classmethod
def make_metadata_builder(cls, *args,
**kwargs) -> "AttentionMetadataBuilder":
return cls.get_builder_cls()(*args, **kwargs)

@staticmethod
@abstractmethod
def get_kv_cache_shape(
@@ -110,6 +124,33 @@ def asdict_zerocopy(self,
T = TypeVar("T", bound=AttentionMetadata)


class AttentionMetadataBuilder(ABC, Generic[T]):
"""Abstract class for attention metadata builders."""

@abstractmethod
def __init__(self, input_builder) -> None:
raise NotImplementedError

@abstractmethod
def add_seq_group(self, seq_group_metadata: "SequenceGroupMetadata",
token_lens: List[int], seq_lens: List[int],
curr_seq_lens: List[int], query_lens: List[int],
context_lens: List[int],
curr_sliding_window_blocks: List[int],
prefix_cache_hit: bool, chunked_prefill_enabled: bool):
"""Add a sequence group to the metadata and update
corresponding fields (in Python objects).
"""
raise NotImplementedError

@abstractmethod
def build(self, runner: "ModelRunnerInputBuilderBase", seq_lens: List[int],
query_lens: List[int], cuda_graph_pad_size: int,
batch_size: int) -> T:
"""Build attention metadata with on-device tensors."""
raise NotImplementedError


class AttentionImpl(ABC, Generic[T]):

@abstractmethod
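To make the new contract concrete, here is a minimal sketch of what a backend now provides. `MyBackend` and `MyMetadataBuilder` are hypothetical stand-ins; only the `get_builder_cls` and `make_metadata_builder` hooks come from the diff above, so treat this as an illustration rather than vLLM's actual code:

```python
from typing import Type

from vllm.attention.backends.abstract import (AttentionBackend,
                                              AttentionMetadata,
                                              AttentionMetadataBuilder)


class MyMetadataBuilder(AttentionMetadataBuilder[AttentionMetadata]):
    """Hypothetical builder, shown only to illustrate the interface."""

    def __init__(self, input_builder) -> None:
        # Cache per-run config from the model-input builder.
        self.input_builder = input_builder

    def add_seq_group(self, seq_group_metadata, token_lens, seq_lens,
                      curr_seq_lens, query_lens, context_lens,
                      curr_sliding_window_blocks, prefix_cache_hit,
                      chunked_prefill_enabled):
        pass  # accumulate Python-side state for one sequence group

    def build(self, runner, seq_lens, query_lens, cuda_graph_pad_size,
              batch_size) -> AttentionMetadata:
        raise NotImplementedError  # materialize on-device tensors here


class MyBackend(AttentionBackend):
    # The other abstract hooks (get_impl_cls, get_metadata_cls, ...) are
    # omitted here for brevity.

    @staticmethod
    def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
        return MyMetadataBuilder


# Model-runner code never names the builder class directly; it goes
# through the backend: builder = MyBackend.make_metadata_builder(input_builder)
```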
11 changes: 11 additions & 0 deletions vllm/attention/backends/blocksparse_attn.py
@@ -5,6 +5,7 @@

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import CommonMetadataBuilder
from vllm.attention.ops.blocksparse_attention.interface import (
LocalStridedBlockSparseAttn, get_head_sliding_step)
from vllm.attention.ops.paged_attn import PagedAttention
@@ -93,6 +94,10 @@ def get_impl_cls() -> Type["BlocksparseFlashAttentionImpl"]:
def get_metadata_cls() -> Type["AttentionMetadata"]:
return BlocksparseFlashAttentionMetadata

@staticmethod
def get_builder_cls() -> Type["BlocksparseFlashAttentionMetadataBuilder"]:
return BlocksparseFlashAttentionMetadataBuilder

@staticmethod
def get_kv_cache_shape(
num_blocks: int,
@@ -244,6 +249,12 @@ def decode_metadata(self) -> Optional["BlocksparseFlashAttentionMetadata"]:
return self._cached_decode_metadata


class BlocksparseFlashAttentionMetadataBuilder(
CommonMetadataBuilder[BlocksparseFlashAttentionMetadata]):

_metadata_cls = BlocksparseFlashAttentionMetadata


class BlocksparseFlashAttentionImpl(AttentionImpl):
"""
If the input tensors contain prompt tokens, the layout is as follows:
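The two-line builder above illustrates the shared-utility path: when the generic bookkeeping in `CommonMetadataBuilder` (from `vllm.attention.backends.utils`) is sufficient, a backend only pins `_metadata_cls`. A sketch of the pattern, where `MyMetadata` is a hypothetical `AttentionMetadata` subclass, not a class from this PR:

```python
from vllm.attention.backends.utils import CommonMetadataBuilder

from my_backend import MyMetadata  # hypothetical AttentionMetadata subclass


class MyMetadataBuilder(CommonMetadataBuilder[MyMetadata]):
    # CommonMetadataBuilder implements add_seq_group()/build() generically
    # and instantiates whatever class _metadata_cls points at.
    _metadata_cls = MyMetadata
```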
183 changes: 181 additions & 2 deletions vllm/attention/backends/flash_attn.py
@@ -1,13 +1,24 @@
"""Attention layer with FlashAttention."""
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type

import torch
from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache

from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionMetadata,
                                              AttentionMetadataBuilder,
                                              AttentionType)
from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
compute_slot_mapping_start_idx,
is_block_tables_empty)
from vllm.sequence import SequenceGroupMetadata
from vllm.utils import make_tensor_with_pad

if TYPE_CHECKING:
from vllm.worker.model_runner import (GPUModelRunnerBase,
ModelInputForGPUBuilder)


class FlashAttentionBackend(AttentionBackend):
@@ -28,6 +39,10 @@ def get_impl_cls() -> Type["FlashAttentionImpl"]:
def get_metadata_cls() -> Type["AttentionMetadata"]:
return FlashAttentionMetadata

@staticmethod
def get_builder_cls() -> Type["FlashAttentionMetadataBuilder"]:
return FlashAttentionMetadataBuilder

@staticmethod
def get_kv_cache_shape(
num_blocks: int,
@@ -184,6 +199,170 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]:
return self._cached_decode_metadata


class FlashAttentionMetadataBuilder(
AttentionMetadataBuilder[FlashAttentionMetadata]):

def __init__(self, input_builder: "ModelInputForGPUBuilder"):
self.slot_mapping: List[int] = []
self.prefill_seq_lens: List[int] = []
self.context_lens: List[int] = []
self.block_tables: List[List[int]] = []
self.curr_seq_lens: List[int] = []
self.num_prefills = 0
self.num_prefill_tokens = 0
self.num_decode_tokens = 0

self.sliding_window = input_builder.sliding_window
self.block_size = input_builder.block_size
self.use_v2_block_manager = (
input_builder.scheduler_config.use_v2_block_manager)

def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
token_lens: List[int], seq_lens: List[int],
curr_seq_lens: List[int], query_lens: List[int],
context_lens: List[int],
curr_sliding_window_blocks: List[int],
prefix_cache_hit: bool, chunked_prefill_enabled: bool):
"""Add a sequence group to the metadata. Specifically update/append
1. context length.
2. block table.
3. slot mapping.
"""
is_prompt = seq_group_metadata.is_prompt
block_tables = seq_group_metadata.block_tables

for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
curr_sliding_window_block) in zip(
seq_group_metadata.seq_data.keys(), token_lens, seq_lens,
curr_seq_lens, query_lens, context_lens,
curr_sliding_window_blocks):
self.context_lens.append(context_len)

if is_prompt:
self.num_prefills += 1
self.num_prefill_tokens += token_len
self.prefill_seq_lens.append(seq_len)
else:
assert query_len == 1, (
"seq_len: {}, context_len: {}, query_len: {}".format(
seq_len, context_len, query_len))
self.num_decode_tokens += query_len
self.curr_seq_lens.append(curr_seq_len)

# Compute block table.
# TODO(sang): Combine chunked prefill and prefix caching by
# only allowing multiple of block_size chunk size.
# NOTE: This only works for oooooooxxx style attention.
block_table = []
if prefix_cache_hit:
# NOTE(woosuk): For flash-attn, the block table should
# include the entries for the incoming prefill tokens.
block_table = block_tables[seq_id]
elif ((chunked_prefill_enabled or not is_prompt)
and block_tables is not None):
block_table = block_tables[seq_id][-curr_sliding_window_block:]
self.block_tables.append(block_table)

# Compute slot mapping.
is_profile_run = is_block_tables_empty(block_tables)
            start_idx = compute_slot_mapping_start_idx(
                is_prompt, query_len, context_len, self.sliding_window,
                self.use_v2_block_manager)
            compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
                                 seq_len, context_len, start_idx,
                                 self.block_size,
                                 seq_group_metadata.block_tables)

[Resolved review thread on compute_slot_mapping_start_idx]

Reviewer (Collaborator): Why is compute_slot_mapping_start_idx split from compute_slot_mapping? It seems like they are used together.

Author (Collaborator): Yeah, the only reason is that combining the two would make the argument list uglier:

    compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
                         seq_len, context_len,
                         self.block_size,
                         seq_group_metadata.block_tables,
                         is_prompt, query_len, context_len, self.sliding_window,
                         self.use_v2_block_manager)

I don't have a preference, though, so if you prefer to combine them I can do that. Please let me know.

Reviewer (Collaborator): I'd prefer to combine them if you don't mind! Agreed there are too many args, but I think that is better than having two calls that depend on each other (IMO it is harder to use). But it is a nit!

Author (Collaborator): Sure. I'll merge this one and do it in the next PR.

def build(self, runner: "GPUModelRunnerBase", seq_lens, query_lens,
cuda_graph_pad_size: int, batch_size: int):
"""Build attention metadata with on-device tensors."""
device = runner.device
use_captured_graph = cuda_graph_pad_size != -1

logits_soft_cap = getattr(runner.model_config.hf_config,
"attn_logit_softcapping", None)
if logits_soft_cap is not None:
raise ValueError(
"Please use Flashinfer backend for models with logits_soft_cap"
" (i.e., Gemma-2). Otherwise, the output might be wrong."
" Set Flashinfer backend by "
"export VLLM_ATTENTION_BACKEND=FLASHINFER.")

max_query_len = max(query_lens)
max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
max_decode_seq_len = max(self.curr_seq_lens, default=0)
num_decode_tokens = self.num_decode_tokens

if use_captured_graph:
self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
            self.block_tables.extend([[]] * cuda_graph_pad_size)
num_decode_tokens = batch_size + cuda_graph_pad_size

# The shape of graph_block_tables is
# [max batch size, max context len // block size].
input_block_tables = runner.graph_block_tables[:batch_size]
for i, block_table in enumerate(self.block_tables):
if block_table:
input_block_tables[i, :len(block_table)] = block_table
block_tables = torch.tensor(input_block_tables, device=device)
else:
max_block_table_len = max(
len(block_table) for block_table in self.block_tables)
block_tables = make_tensor_with_pad(
self.block_tables,
max_len=max_block_table_len,
pad=0,
dtype=torch.int,
device=device,
)
assert max_query_len > 0, ("query_lens: {}".format(query_lens))

context_lens_tensor = torch.tensor(self.context_lens,
dtype=torch.int,
device=device)
seq_lens_tensor = torch.tensor(seq_lens,
dtype=torch.int,
device=device)
query_lens_tensor = torch.tensor(query_lens,
dtype=torch.long,
device=device)
query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=device)
seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=device)
torch.cumsum(seq_lens_tensor,
dim=0,
dtype=seq_start_loc.dtype,
out=seq_start_loc[1:])
torch.cumsum(query_lens_tensor,
dim=0,
dtype=query_start_loc.dtype,
out=query_start_loc[1:])

slot_mapping_tensor = torch.tensor(self.slot_mapping,
dtype=torch.long,
device=device)

return FlashAttentionMetadata(
num_prefills=self.num_prefills,
slot_mapping=slot_mapping_tensor,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_query_len=max_query_len,
max_prefill_seq_len=max_prefill_seq_len,
max_decode_seq_len=max_decode_seq_len,
query_start_loc=query_start_loc,
seq_start_loc=seq_start_loc,
context_lens_tensor=context_lens_tensor,
block_tables=block_tables,
use_cuda_graph=use_captured_graph,
)
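For intuition about the `query_start_loc`/`seq_start_loc` tensors built above: both are exclusive prefix sums, so request `i` owns the flattened token slice `[start_loc[i], start_loc[i + 1])`. A standalone illustration of the same cumsum-into-a-slice pattern, in plain PyTorch with arbitrarily chosen lengths:

```python
import torch

# Two prefills (3 and 2 tokens) followed by one decode (1 token).
query_lens_tensor = torch.tensor([3, 2, 1], dtype=torch.long)
query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
                              dtype=torch.int32)
torch.cumsum(query_lens_tensor, dim=0, dtype=query_start_loc.dtype,
             out=query_start_loc[1:])
print(query_start_loc)  # tensor([0, 3, 5, 6], dtype=torch.int32)
```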


class FlashAttentionImpl(AttentionImpl):
"""
If the input tensors contain prompt tokens, the layout is as follows:
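Taken together, the refactor reduces the attention handling in `_prepare_model_input_tensors` to three backend-agnostic calls. A simplified sketch of that flow; the variable plumbing here is illustrative, with the real version living in the GPU model-input builder:

```python
# Illustrative control flow only; the argument values come from the
# scheduler output in the real code path.
builder = attn_backend.make_metadata_builder(input_builder)

for seq_group_metadata in seq_group_metadata_list:
    # Pure-Python bookkeeping per sequence group: context lens, block
    # tables, slot mapping, prefill/decode token counters.
    builder.add_seq_group(seq_group_metadata, token_lens, seq_lens,
                          curr_seq_lens, query_lens, context_lens,
                          curr_sliding_window_blocks, prefix_cache_hit,
                          chunked_prefill_enabled)

# One device transfer at the end: build() creates the on-device tensors
# and returns the backend-specific AttentionMetadata.
attn_metadata = builder.build(runner, seq_lens, query_lens,
                              cuda_graph_pad_size, batch_size)
```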