From 0826a5f5463c2a56ca54d1ef342acdad3d877c00 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 7 Jan 2025 00:11:28 +0800 Subject: [PATCH] [Kernel] Move attn_type to Attention.__init__() (#11690) Signed-off-by: Chen Zhang --- tests/kernels/test_encoder_decoder_attn.py | 100 ++++++++++---------- tests/kernels/utils.py | 12 ++- vllm/attention/backends/abstract.py | 2 +- vllm/attention/backends/blocksparse_attn.py | 14 +-- vllm/attention/backends/flash_attn.py | 4 +- vllm/attention/backends/flashinfer.py | 15 ++- vllm/attention/backends/hpu_attn.py | 13 +-- vllm/attention/backends/ipex_attn.py | 12 +-- vllm/attention/backends/pallas.py | 13 +-- vllm/attention/backends/rocm_flash_attn.py | 14 +-- vllm/attention/backends/torch_sdpa.py | 4 +- vllm/attention/backends/xformers.py | 6 +- vllm/attention/layer.py | 37 ++------ vllm/model_executor/models/bart.py | 44 +++------ vllm/model_executor/models/bert.py | 10 +- vllm/model_executor/models/mllama.py | 11 +-- vllm/model_executor/models/qwen2.py | 35 ++++--- vllm/v1/attention/backends/flash_attn.py | 14 +-- 18 files changed, 159 insertions(+), 201 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index d943b048b7934..614674375786e 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -13,8 +13,7 @@ import torch from tests.kernels.utils import * -from vllm.attention import (Attention, AttentionBackend, AttentionMetadata, - AttentionType) +from vllm.attention import Attention, AttentionMetadata, AttentionType from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP from vllm.attention.selector import (_Backend, _cached_get_attn_backend, global_force_attn_backend_context_manager) @@ -64,6 +63,7 @@ class TestPoint(NamedTuple): max_dec_seq_len: int max_enc_seq_len: int num_blocks: int + attn_type: AttentionType class TestResources(NamedTuple): @@ -96,7 +96,6 @@ class TestResources(NamedTuple): 
''' scale: float - attn_backend: AttentionBackend attn: Attention kv_cache: torch.Tensor @@ -129,16 +128,17 @@ class that Attention will automatically select when it is constructed. ''' scale = float(1.0 / (test_pt.head_size**0.5)) - attn_backend = make_backend(test_pt.backend_name) attn = Attention( test_pt.num_heads, test_pt.head_size, scale=scale, + prefix=f"{test_pt.attn_type}", + attn_type=test_pt.attn_type, ) if test_pt.num_blocks is None or test_pt.num_heads is None: # Caller does not require a KV cache return TestResources( - scale, attn_backend, attn, + scale, attn, torch.tensor([], dtype=torch.float32, device=CUDA_DEVICE)) # Construct KV cache @@ -148,7 +148,7 @@ class that Attention will automatically select when it is constructed. test_pt.block_size, device=CUDA_DEVICE, backend=test_pt.backend_name) - return TestResources(scale, attn_backend, attn, kv_cache) + return TestResources(scale, attn, kv_cache) def _encoder_attn_setup( @@ -193,6 +193,7 @@ def _encoder_attn_setup( _, max_q_seq_len, _, + _, ) = test_pt scale = test_rsrcs.scale @@ -301,6 +302,7 @@ def _decoder_attn_setup( max_q_seq_len, _, _, + _, ) = test_pt scale = test_rsrcs.scale @@ -488,6 +490,7 @@ def _enc_dec_cross_attn_setup_reuses_query( max_decoder_seq_len, max_encoder_seq_len, _, + _, ) = test_pt scale = test_rsrcs.scale @@ -622,7 +625,6 @@ def _run_encoder_attention_test( & attn_metadata ''' assert attn_metadata.num_decode_tokens == 0 - attn_type = AttentionType.ENCODER packed_qkv = encoder_test_params.packed_qkvo.packed_qkv assert packed_qkv is not None with set_forward_context(attn_metadata, vllm_config): @@ -635,14 +637,11 @@ def _run_encoder_attention_test( # is shaped as [num_tokens, hidden_size] and we can skip the reshape. 
reshaped_query = packed_qkv.query.view( -1, test_pt.num_heads * test_pt.head_size) - return attn.forward(reshaped_query, - packed_qkv.key, - packed_qkv.value, - torch.tensor([], - dtype=torch.float32, - device=packed_qkv.query.device), - attn_metadata, - attn_type=attn_type) + return attn.forward( + reshaped_query, packed_qkv.key, packed_qkv.value, + torch.tensor([], + dtype=torch.float32, + device=packed_qkv.query.device), attn_metadata) def _run_decoder_self_attention_test( @@ -675,7 +674,6 @@ def _run_decoder_self_attention_test( * Attention.forward() applied to packed_{query,key,value}, kv_cache & attn_metadata ''' - attn_type = AttentionType.DECODER attn = test_rsrcs.attn kv_cache = test_rsrcs.kv_cache packed_qkv = decoder_test_params.packed_qkvo.packed_qkv @@ -690,12 +688,8 @@ def _run_decoder_self_attention_test( # is shaped as [num_tokens, hidden_size] and we can skip the reshape. reshaped_query = packed_qkv.query.view( -1, test_pt.num_heads * test_pt.head_size) - return attn.forward(reshaped_query, - packed_qkv.key, - packed_qkv.value, - kv_cache, - attn_metadata, - attn_type=attn_type) + return attn.forward(reshaped_query, packed_qkv.key, packed_qkv.value, + kv_cache, attn_metadata) def _run_encoder_decoder_cross_attention_test( @@ -742,7 +736,6 @@ def _run_encoder_decoder_cross_attention_test( ''' assert decoder_test_params.packed_qkvo.packed_qkv is not None - attn_type = AttentionType.ENCODER_DECODER attn = test_rsrcs.attn kv_cache = test_rsrcs.kv_cache if cross_test_params is None: @@ -762,12 +755,8 @@ def _run_encoder_decoder_cross_attention_test( # is shaped as [num_tokens, hidden_size] and we can skip the reshape. 
reshaped_query = decoder_test_params.packed_qkvo.packed_qkv.query.view( -1, test_pt.num_heads * test_pt.head_size) - return attn.forward(reshaped_query, - key, - value, - kv_cache, - attn_metadata, - attn_type=attn_type) + return attn.forward(reshaped_query, key, value, kv_cache, + attn_metadata) @pytest.fixture(autouse=True) @@ -839,7 +828,7 @@ def test_encoder_only( # is not part of this test test_pt = TestPoint(num_heads, head_size, attn_backend.name, batch_size, block_size, max_dec_seq_len, - max_enc_seq_len, 4096) + max_enc_seq_len, 4096, AttentionType.ENCODER) # Attention scale factor, attention backend instance, attention wrapper # instance, KV cache init @@ -855,7 +844,7 @@ def test_encoder_only( # Shared prefill metadata structure prephase_attn_metadata: AttentionMetadata = make_test_metadata( - test_rsrcs.attn_backend, + attn_backend, True, None, decoder_test_params=None, @@ -961,20 +950,29 @@ def test_e2e_enc_dec_attn( # Note: KV cache size of 4096 is arbitrary & chosen intentionally # to be more than necessary, since exceeding the kv cache size # is not part of this test - test_pt = TestPoint(num_heads, head_size, attn_backend.name, - batch_size, block_size, max_dec_seq_len, - max_enc_seq_len, 4096) + enc_test_pt = TestPoint(num_heads, head_size, attn_backend.name, + batch_size, block_size, max_dec_seq_len, + max_enc_seq_len, 4096, AttentionType.ENCODER) + enc_dec_test_pt = TestPoint(num_heads, head_size, attn_backend.name, + batch_size, block_size, max_dec_seq_len, + max_enc_seq_len, 4096, + AttentionType.ENCODER_DECODER) + dec_test_pt = TestPoint(num_heads, head_size, attn_backend.name, + batch_size, block_size, max_dec_seq_len, + max_enc_seq_len, 4096, AttentionType.DECODER) # Attention scale factor, attention backend instance, attention wrapper # instance, KV cache init vllm_config = VllmConfig() with set_current_vllm_config(vllm_config): - test_rsrcs = _make_test_resources(test_pt) + enc_test_rsrcs = _make_test_resources(enc_test_pt) + 
enc_dec_test_rsrcs = _make_test_resources(enc_dec_test_pt) + dec_test_rsrcs = _make_test_resources(dec_test_pt) # Construct encoder attention test params (only used # during prefill) - enc_test_params = _encoder_attn_setup(test_pt, test_rsrcs) + enc_test_params = _encoder_attn_setup(enc_test_pt, enc_test_rsrcs) # Construct Decoder self-attention prefill-phase & decode-phase # test params, including query/key/value tensors, decoder self-attention @@ -987,7 +985,7 @@ def test_e2e_enc_dec_attn( prephase_dec_test_params, decphase_dec_test_params, cross_block_base_addr, - ) = _decoder_attn_setup(test_pt, test_rsrcs) + ) = _decoder_attn_setup(dec_test_pt, dec_test_rsrcs) # Construct encoder/decoder cross-attention prefill-phase # & decode-phase test params, including key/value tensors, @@ -1000,14 +998,14 @@ def test_e2e_enc_dec_attn( dec_qkv, enc_test_params, prephase_dec_test_params, - test_pt, - test_rsrcs, + enc_dec_test_pt, + enc_dec_test_rsrcs, block_base_addr=cross_block_base_addr) # Shared prefill metadata structure assert prephase_dec_test_params.packed_qkvo.packed_qkv is not None prephase_attn_metadata: AttentionMetadata = make_test_metadata( - test_rsrcs.attn_backend, + attn_backend, True, prephase_dec_test_params.packed_qkvo.packed_qkv.q_seq_lens, decoder_test_params=prephase_dec_test_params, @@ -1017,10 +1015,10 @@ def test_e2e_enc_dec_attn( # PREFILL: encoder attention - enc_pckd_act_out = _run_encoder_attention_test(test_rsrcs.attn, + enc_pckd_act_out = _run_encoder_attention_test(enc_test_rsrcs.attn, enc_test_params, prephase_attn_metadata, - test_pt=test_pt, + test_pt=enc_test_pt, vllm_config=vllm_config) # - Is encoder attention result correct? 
@@ -1030,10 +1028,10 @@ def test_e2e_enc_dec_attn( # PREFILL: decoder self-attention test prephase_dec_pckd_act_out = _run_decoder_self_attention_test( - test_rsrcs, + dec_test_rsrcs, prephase_dec_test_params, prephase_attn_metadata, - test_pt=test_pt, + test_pt=dec_test_pt, vllm_config=vllm_config) # - Is prefill decoder self-attention correct? @@ -1044,11 +1042,11 @@ def test_e2e_enc_dec_attn( # PREFILL: encoder/decoder cross-attention test prephase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test( - test_rsrcs, + enc_dec_test_rsrcs, prephase_dec_test_params, prephase_cross_test_params, prephase_attn_metadata, - test_pt=test_pt, + test_pt=enc_dec_test_pt, vllm_config=vllm_config) # - Is prefill encoder/decoder cross-attention correct? @@ -1059,7 +1057,7 @@ def test_e2e_enc_dec_attn( # DECODE: build decode-phase attention metadata decphase_attn_metadata: AttentionMetadata = make_test_metadata( - test_rsrcs.attn_backend, + attn_backend, False, dec_qkv.q_seq_lens, decoder_test_params=decphase_dec_test_params, @@ -1070,10 +1068,10 @@ def test_e2e_enc_dec_attn( # DECODE: decoder self-attention test decphase_dec_pckd_act_out = _run_decoder_self_attention_test( - test_rsrcs, + dec_test_rsrcs, decphase_dec_test_params, decphase_attn_metadata, - test_pt=test_pt, + test_pt=dec_test_pt, vllm_config=vllm_config) # - Is decode-phase decoder self-attention correct? @@ -1084,11 +1082,11 @@ def test_e2e_enc_dec_attn( # DECODE: encoder/decoder cross-attention test decphase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test( - test_rsrcs, + enc_dec_test_rsrcs, decphase_dec_test_params, None, decphase_attn_metadata, - test_pt=test_pt, + test_pt=enc_dec_test_pt, vllm_config=vllm_config) # - Is decode-phase encoder/decoder cross-attention correct? 
diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index e7865fb2500ef..848eea7f54cab 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -13,6 +13,7 @@ from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType from vllm.model_executor.layers.activation import SiluAndMul +from vllm.platforms.interface import _Backend from vllm.utils import (STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL, STR_XFORMERS_ATTN_VAL, make_tensor_with_pad) @@ -790,7 +791,7 @@ def make_block_tables_slot_mapping( def make_test_metadata( - attn_backend: AttentionBackend, + attn_backend: _Backend, is_prompt: bool, seq_lens: Optional[List[int]], decoder_test_params: Optional[PhaseTestParameters], @@ -815,7 +816,7 @@ def make_test_metadata( Arguments: - * attn_backend: Backend for sourcing attention kernels + * attn_backend: Backend (_Backend) for sourcing attention kernels * is_prompt: prefill if True, o/w decode * seq_lens: list of token counts for each sequence * decoder_test_params: decoder self-attention test params; @@ -882,6 +883,8 @@ def make_test_metadata( # (kv_mmap) cross_kv_mmap = cross_test_params.kv_mmap + attn_backend_obj = make_backend(attn_backend.name) + if is_prompt: # Prefill-phase scenario @@ -902,8 +905,7 @@ def make_test_metadata( context_lens, encoder_seq_lens, device=device) - - return attn_backend.make_metadata( + return attn_backend_obj.make_metadata( num_prefills=num_prefills, slot_mapping=(None if kv_mmap is None else kv_mmap.slot_mapping), multi_modal_placeholder_index_maps=None, @@ -952,7 +954,7 @@ def make_test_metadata( encoder_seq_lens, device=device) - return attn_backend.make_metadata( + return attn_backend_obj.make_metadata( num_prefills=num_prefills, slot_mapping=kv_mmap.slot_mapping, multi_modal_placeholder_index_maps=None, diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index aed04361e5fb4..f5dcaea79af93 100644 --- a/vllm/attention/backends/abstract.py +++ 
b/vllm/attention/backends/abstract.py @@ -233,6 +233,7 @@ def __init__( kv_cache_dtype: str = "auto", blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, ) -> None: raise NotImplementedError @@ -246,7 +247,6 @@ def forward( attn_metadata: T, k_scale: float = 1.0, v_scale: float = 1.0, - attn_type: str = AttentionType.DECODER, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: raise NotImplementedError diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py index 99cb84346d84e..7089d59392c36 100644 --- a/vllm/attention/backends/blocksparse_attn.py +++ b/vllm/attention/backends/blocksparse_attn.py @@ -300,6 +300,7 @@ def __init__( kv_cache_dtype: str, blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, ) -> None: assert blocksparse_params is not None assert alibi_slopes is None, ValueError( @@ -350,6 +351,12 @@ def __init__( active_head_range=self.blocksparse_params.active_head_range, ) + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "BlocksparseFlashAttentionImpl") + def forward( self, query: torch.Tensor, @@ -359,7 +366,6 @@ def forward( attn_metadata: BlocksparseFlashAttentionMetadata, k_scale: float = 1.0, v_scale: float = 1.0, - attn_type: str = AttentionType.DECODER, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. 
@@ -375,12 +381,6 @@ def forward( Returns: shape = [num_tokens, num_heads * head_size] """ - if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "BlocksparseFlashAttentionImpl") - num_tokens, hidden_size = query.shape # Reshape the query, key, and value tensors. query = query.view(-1, self.num_heads, self.head_size) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index c69e12ad78c44..23ea244f07dfe 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -600,6 +600,7 @@ def __init__( kv_cache_dtype: str, blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, ) -> None: if blocksparse_params is not None: raise ValueError( @@ -627,6 +628,7 @@ def __init__( raise ValueError( f"Head size {head_size} is not supported by FlashAttention. " f"Supported head sizes are: {support_head_sizes}.") + self.attn_type = attn_type def forward( self, @@ -637,7 +639,6 @@ def forward( attn_metadata: FlashAttentionMetadata, k_scale: float = 1.0, v_scale: float = 1.0, - attn_type: str = AttentionType.DECODER, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FlashAttention. @@ -659,6 +660,7 @@ def forward( assert output is not None, "Output tensor must be provided." 
+ attn_type = self.attn_type if (attn_type == AttentionType.ENCODER and (not attn_metadata.is_all_encoder_attn_metadata_set)): raise AttributeError("Encoder attention requires setting " diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index e367468d05d26..a11462b2068a5 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -748,6 +748,7 @@ def __init__( kv_cache_dtype: str, blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, ) -> None: self.num_heads = num_heads self.head_size = head_size @@ -764,6 +765,12 @@ def __init__( assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "FlashInferImpl") + def forward( self, query: torch.Tensor, @@ -773,18 +780,10 @@ def forward( attn_metadata: FlashInferMetadata, k_scale: float = 1.0, v_scale: float = 1.0, - attn_type: str = AttentionType.DECODER, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: # TODO: directly write to output tensor - - if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "FlashInferImpl") - num_heads: int = self.num_heads head_size: int = self.head_size num_kv_heads: int = self.num_kv_heads diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py index f90d15d4207e7..94a461e0c8c29 100644 --- a/vllm/attention/backends/hpu_attn.py +++ b/vllm/attention/backends/hpu_attn.py @@ -102,6 +102,7 @@ def __init__( kv_cache_dtype: str, blocksparse_params: Optional[Dict[str, Any]] = None, max_seq_len: int = 4096, + attn_type: str = AttentionType.DECODER, ) -> None: super(AttentionImpl, 
self).__init__() self.kv_cache_dtype = kv_cache_dtype @@ -143,6 +144,12 @@ def __init__( f"Head size {head_size} is not supported by PagedAttention. " f"Supported head sizes are: {suppored_head_sizes}.") + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "HPUAttentionImpl") + def forward( self, query: torch.Tensor, @@ -152,7 +159,6 @@ def forward( attn_metadata: HPUAttentionMetadata, k_scale: float = 1.0, v_scale: float = 1.0, - attn_type: str = AttentionType.DECODER, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with xFormers and PagedAttention. @@ -166,11 +172,6 @@ def forward( Returns: shape = [num_tokens, num_heads * head_size] """ - if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "HPUAttentionImpl") batch_size, seq_len, hidden_size = query.shape _, seq_len_kv, _ = key.shape diff --git a/vllm/attention/backends/ipex_attn.py b/vllm/attention/backends/ipex_attn.py index 21949874bea47..da1d307daa517 100644 --- a/vllm/attention/backends/ipex_attn.py +++ b/vllm/attention/backends/ipex_attn.py @@ -115,6 +115,7 @@ def __init__( kv_cache_dtype: str, blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, ) -> None: if blocksparse_params is not None: raise ValueError( @@ -146,6 +147,11 @@ def __init__( raise NotImplementedError( "IPEX backend does not support FP8 KV cache. 
" "Please use xFormers backend instead.") + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "IpexAttnBackendImpl") def split_kv_cache( self, @@ -172,7 +178,6 @@ def forward( attn_metadata: IpexAttnMetadata, # type: ignore k_scale: float = 1.0, v_scale: float = 1.0, - attn_type: str = AttentionType.DECODER, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with IPEX varlen_attention and PagedAttention. @@ -189,11 +194,6 @@ def forward( shape = [num_tokens, num_heads * head_size] """ assert k_scale == 1.0 and v_scale == 1.0 - if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "IpexAttnBackendImpl") num_tokens, hidden_size = query.shape # Reshape the query, key, and value tensors. query = query.view(-1, self.num_heads, self.head_size) diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py index 9809aed0e66f9..2ac492dd8ae54 100644 --- a/vllm/attention/backends/pallas.py +++ b/vllm/attention/backends/pallas.py @@ -100,6 +100,7 @@ def __init__( kv_cache_dtype: str, blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, ) -> None: self.num_heads = num_heads self.head_size = head_size @@ -141,6 +142,12 @@ def __init__( # megacore mode will be None. 
self.megacore_mode = "batch" + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "PallasAttentionBackendImpl") + def forward( self, query: torch.Tensor, @@ -150,7 +157,6 @@ def forward( attn_metadata: PallasMetadata, k_scale: float = 1.0, v_scale: float = 1.0, - attn_type: str = AttentionType.DECODER, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with Pallas attention. @@ -168,11 +174,6 @@ def forward( shape = [batch_size, seq_len, num_heads * head_size] """ assert k_scale == 1.0 and v_scale == 1.0 - if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "PallasAttentionBackendImpl") batch_size, seq_len, hidden_size = query.shape query = query.view(batch_size, seq_len, self.num_heads, self.head_size) key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index d43c15b661ef7..a91a5af5c3d58 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -338,6 +338,7 @@ def __init__( kv_cache_dtype: str, blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, ) -> None: if blocksparse_params is not None: raise ValueError( @@ -397,6 +398,12 @@ def __init__( self.attn_func = _sdpa_attention logger.debug("Using naive attention in ROCmBackend") + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "ROCmFlashAttentionImpl") + def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor: """torch.repeat_interleave(x, dim=1, repeats=n_rep)""" tokens, n_kv_heads, head_dim = 
x.shape @@ -414,7 +421,6 @@ def forward( attn_metadata: ROCmFlashAttentionMetadata, k_scale: float = 1.0, v_scale: float = 1.0, - attn_type: str = AttentionType.DECODER, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. @@ -432,12 +438,6 @@ def forward( """ # Reminder: Please update docs/source/features/compatibility_matrix.md # If the feature combo become valid - if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "ROCmFlashAttentionImpl") - num_tokens, hidden_size = query.shape # Reshape the query, key, and value tensors. query = query.view(-1, self.num_heads, self.head_size) diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index 0cff6f5952aba..c14f7754596dd 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -390,6 +390,7 @@ def __init__( kv_cache_dtype: str, blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, ) -> None: if blocksparse_params is not None: raise ValueError( @@ -421,6 +422,7 @@ def __init__( raise NotImplementedError( "Torch SDPA backend does not support FP8 KV cache. " "Please use xFormers backend instead.") + self.attn_type = attn_type def forward( self, @@ -431,7 +433,6 @@ def forward( attn_metadata: TorchSDPAMetadata, # type: ignore k_scale: float = 1.0, v_scale: float = 1.0, - attn_type: str = AttentionType.DECODER, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with torch SDPA and PagedAttention. 
@@ -448,6 +449,7 @@ def forward( shape = [num_tokens, num_heads * head_size] """ assert k_scale == 1.0 and v_scale == 1.0 + attn_type = self.attn_type if (attn_type == AttentionType.ENCODER and (not attn_metadata.is_all_encoder_attn_metadata_set)): raise AttributeError("Encoder attention requires setting " diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 3e59b3603d2c6..694c7cc1bc36a 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -379,6 +379,7 @@ def __init__( kv_cache_dtype: str, blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, ) -> None: if blocksparse_params is not None: raise ValueError( @@ -405,6 +406,8 @@ def __init__( f"Head size {head_size} is not supported by PagedAttention. " f"Supported head sizes are: {suppored_head_sizes}.") + self.attn_type = attn_type + def forward( self, query: torch.Tensor, @@ -414,7 +417,6 @@ def forward( attn_metadata: "XFormersMetadata", k_scale: float = 1.0, v_scale: float = 1.0, - attn_type: str = AttentionType.DECODER, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with xFormers and PagedAttention. 
@@ -468,7 +470,7 @@ def forward( Returns: shape = [num_tokens, num_heads * head_size] """ - + attn_type = self.attn_type # Check that appropriate attention metadata attributes are # selected for the desired attention type if (attn_type == AttentionType.ENCODER diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 69b6d1e4648df..f1b3598e60b54 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -41,6 +41,7 @@ def __init__( logits_soft_cap: Optional[float] = None, per_layer_sliding_window: Optional[int] = None, prefix: str = "", + attn_type: str = AttentionType.DECODER, ) -> None: super().__init__() if per_layer_sliding_window is not None: @@ -96,7 +97,7 @@ def __init__( impl_cls = attn_backend.get_impl_cls() self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads, alibi_slopes, sliding_window, kv_cache_dtype, - blocksparse_params, logits_soft_cap) + blocksparse_params, logits_soft_cap, attn_type) self.num_heads = num_heads self.head_size = head_size self.num_kv_heads = num_kv_heads @@ -119,6 +120,7 @@ def __init__( raise ValueError(f"Duplicate layer name: {prefix}") compilation_config.static_forward_context[prefix] = self self.layer_name = prefix + self.attn_type = attn_type def forward( self, @@ -127,18 +129,12 @@ def forward( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, - attn_type: str = AttentionType.DECODER, ) -> torch.Tensor: if self.use_direct_call: - return self.impl.forward(query, - key, - value, - kv_cache, - attn_metadata, - self._k_scale, - self._v_scale, - attn_type=attn_type) + return self.impl.forward(query, key, value, kv_cache, + attn_metadata, self._k_scale, + self._v_scale) elif self.use_output: output = torch.empty_like(query) hidden_size = query.size(-1) @@ -152,13 +148,11 @@ def forward( if value is not None: value = value.view(-1, self.num_kv_heads, self.head_size) torch.ops.vllm.unified_attention_with_output( - query, key, value, output, kv_cache, attn_type, - 
self.layer_name) + query, key, value, output, kv_cache, self.layer_name) return output.view(-1, hidden_size) else: return torch.ops.vllm.unified_attention(query, key, value, - kv_cache, attn_type, - self.layer_name) + kv_cache, self.layer_name) def extra_repr(self) -> str: s = f"head_size={self.impl.head_size}" # type: ignore @@ -237,20 +231,13 @@ def unified_attention( key: torch.Tensor, value: torch.Tensor, kv_cache: torch.Tensor, - attn_type: str, layer_name: str, ) -> torch.Tensor: forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.dynamic_forward_context self = forward_context.static_forward_context[layer_name] - return self.impl.forward(query, - key, - value, - kv_cache, - attn_metadata, - self._k_scale, - self._v_scale, - attn_type=attn_type) + return self.impl.forward(query, key, value, kv_cache, attn_metadata, + self._k_scale, self._v_scale) def unified_attention_fake( @@ -258,7 +245,6 @@ def unified_attention_fake( key: torch.Tensor, value: torch.Tensor, kv_cache: torch.Tensor, - attn_type: str, layer_name: str, ) -> torch.Tensor: return torch.empty_like(query).contiguous() @@ -279,7 +265,6 @@ def unified_attention_with_output( value: torch.Tensor, output: torch.Tensor, kv_cache: torch.Tensor, - attn_type: str, layer_name: str, ) -> None: forward_context: ForwardContext = get_forward_context() @@ -292,7 +277,6 @@ def unified_attention_with_output( attn_metadata, self._k_scale, self._v_scale, - attn_type=attn_type, output=output) @@ -302,7 +286,6 @@ def unified_attention_with_output_fake( value: torch.Tensor, output: torch.Tensor, kv_cache: torch.Tensor, - attn_type: str, layer_name: str, ) -> None: return diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index 3776490cb3465..57eb5adc82d5b 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -71,12 +71,8 @@ def __init__(self, num_embeddings: int, embedding_dim: int): def forward( self, positions: 
torch.Tensor, - attn_type: AttentionType, ) -> torch.Tensor: """`input_ids' shape is expected to be [bsz x seqlen].""" - - assert attn_type != AttentionType.ENCODER_DECODER - return super().forward(positions + self.offset) @@ -180,7 +176,8 @@ def __init__( num_kv_heads=self.num_kv_heads, cache_config=cache_config, quant_config=quant_config, - prefix=f"{prefix}.attn") + prefix=f"{prefix}.attn", + attn_type=AttentionType.ENCODER) def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata) -> torch.Tensor: @@ -189,12 +186,7 @@ def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor, qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - attn_output = self.attn(q, - k, - v, - kv_cache, - attn_metadata, - attn_type=AttentionType.ENCODER) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) output, _ = self.out_proj(attn_output) return output @@ -264,7 +256,8 @@ def __init__( num_kv_heads=self.num_kv_heads, cache_config=cache_config, quant_config=quant_config, - prefix=f"{prefix}.attn") + prefix=f"{prefix}.attn", + attn_type=AttentionType.DECODER) def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata) -> torch.Tensor: @@ -273,12 +266,7 @@ def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor, qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - attn_output = self.attn(q, - k, - v, - kv_cache, - attn_metadata, - attn_type=AttentionType.DECODER) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) output, _ = self.out_proj(attn_output) return output @@ -348,7 +336,8 @@ def __init__( num_kv_heads=self.num_kv_heads, cache_config=cache_config, quant_config=quant_config, - prefix=f"{prefix}.attn") + prefix=f"{prefix}.attn", + attn_type=AttentionType.ENCODER_DECODER) def forward( self, @@ -372,12 +361,7 @@ def forward( _, k, v = 
qkv_enc.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - attn_output = self.attn(q, - k, - v, - kv_cache, - attn_metadata, - attn_type=AttentionType.ENCODER_DECODER) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) output, _ = self.out_proj(attn_output) return output @@ -644,10 +628,7 @@ def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, # retrieve input_ids and inputs_embeds inputs_embeds = self.embed_tokens(input_ids) - embed_pos = self.embed_positions( - positions, - AttentionType.ENCODER, - ) + embed_pos = self.embed_positions(positions) embed_pos = embed_pos.to(inputs_embeds.device) hidden_states = inputs_embeds + embed_pos @@ -734,10 +715,7 @@ def forward(self, decoder_input_ids: torch.Tensor, inputs_embeds = self.embed_tokens(decoder_input_ids) # embed positions - embed_pos = self.embed_positions( - decoder_positions, - AttentionType.DECODER, - ) + embed_pos = self.embed_positions(decoder_positions) embed_pos = embed_pos.to(inputs_embeds.device) hidden_states = inputs_embeds + embed_pos diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index c1d47b1bc9bcd..4be136543de15 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -238,7 +238,8 @@ def __init__( num_kv_heads=self.num_kv_heads, cache_config=cache_config, quant_config=quant_config, - prefix=f"{prefix}.attn") + prefix=f"{prefix}.attn", + attn_type=AttentionType.ENCODER_ONLY) def forward( self, @@ -248,12 +249,7 @@ def forward( ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - output = self.attn(q, - k, - v, - kv_cache, - attn_metadata, - attn_type=AttentionType.ENCODER_ONLY) + output = self.attn(q, k, v, kv_cache, attn_metadata) return output diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index 6536f9807730c..c5046e06edecb 100644 --- a/vllm/model_executor/models/mllama.py 
+++ b/vllm/model_executor/models/mllama.py @@ -770,6 +770,7 @@ def __init__( self.scaling, self.num_local_key_value_heads, prefix=f"{prefix}.attn", + attn_type=AttentionType.ENCODER_DECODER, ) def forward( @@ -805,13 +806,9 @@ def forward( kv_range_for_decode, attn_metadata) else: - output = self.attn(q.view(-1, - self.num_local_heads * self.head_dim), - k, - v, - kv_cache, - attn_metadata, - attn_type=AttentionType.ENCODER_DECODER) + output = self.attn( + q.view(-1, self.num_local_heads * self.head_dim), k, v, + kv_cache, attn_metadata) out, _ = self.o_proj(output) return out diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 88f4ea4352726..01745b5fd53e1 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -107,7 +107,8 @@ def __init__(self, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, rope_scaling: Optional[Tuple] = None, - prefix: str = "") -> None: + prefix: str = "", + attn_type: str = AttentionType.DECODER) -> None: super().__init__() self.hidden_size = hidden_size tp_size = get_tensor_model_parallel_world_size() @@ -160,7 +161,8 @@ def __init__(self, num_kv_heads=self.num_kv_heads, cache_config=cache_config, quant_config=quant_config, - prefix=f"{prefix}.attn") + prefix=f"{prefix}.attn", + attn_type=attn_type) def forward( self, @@ -168,17 +170,11 @@ def forward( hidden_states: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, - attn_type: str = AttentionType.DECODER, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, - k, - v, - kv_cache, - attn_metadata, - attn_type=attn_type) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) output, _ = self.o_proj(attn_output) return output @@ -197,6 +193,16 @@ def __init__( # Requires transformers > 4.32.0 
rope_theta = getattr(config, "rope_theta", 1000000) rope_scaling = getattr(config, "rope_scaling", None) + + # By default, Qwen2 uses causal attention as it is a decoder-only model. + # You can override the HF config with `is_causal=False` to enable + # bidirectional attention, which is used in some embedding models + # (e.g. Alibaba-NLP/gte-Qwen2-7B-instruct) + if getattr(config, "is_causal", True): + attn_type = AttentionType.DECODER + else: + attn_type = AttentionType.ENCODER_ONLY + self.self_attn = Qwen2Attention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, @@ -207,6 +213,7 @@ def __init__( quant_config=quant_config, rope_scaling=rope_scaling, prefix=f"{prefix}.self_attn", + attn_type=attn_type, ) self.mlp = Qwen2MLP( hidden_size=self.hidden_size, @@ -220,15 +227,6 @@ def __init__( self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - # By default, Qwen2 uses causal attention as it is a decoder-only model. - # You can override the HF config with `is_causal=False` to enable - # bidirectional attention, which is used in some embedding models - # (e.g. 
Alibaba-NLP/gte-Qwen2-7B-instruct) - if getattr(config, "is_causal", True): - self._attn_type = AttentionType.DECODER - else: - self._attn_type = AttentionType.ENCODER_ONLY - def forward( self, positions: torch.Tensor, @@ -249,7 +247,6 @@ def forward( hidden_states=hidden_states, kv_cache=kv_cache, attn_metadata=attn_metadata, - attn_type=self._attn_type, ) # Fully Connected diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 65002f1ad70c7..b02bc9ffde538 100644 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -89,6 +89,7 @@ def __init__( kv_cache_dtype: str, blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, + attn_type: AttentionType = AttentionType.DECODER, ) -> None: if blocksparse_params is not None: raise ValueError( @@ -119,6 +120,12 @@ def __init__( f"Head size {head_size} is not supported by FlashAttention. " f"Supported head sizes are: {support_head_sizes}.") + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "FlashAttentionImpl") + def forward( self, query: torch.Tensor, @@ -128,7 +135,6 @@ def forward( attn_metadata: FlashAttentionMetadata, k_scale: float = 1.0, v_scale: float = 1.0, - attn_type: AttentionType = AttentionType.DECODER, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FlashAttention. @@ -142,12 +148,6 @@ def forward( Returns: shape = [num_tokens, num_heads * head_size] """ - if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "FlashAttentionImpl") - # NOTE(woosuk): FlashAttention does not support FP8 KV cache. assert k_scale == 1.0 and v_scale == 1.0, ( "key/v_scale is not supported in FlashAttention.")