[ROCm][AMD][Model]Adding alibi slopes support in ROCm triton flash attention and naive flash attention (vllm-project#6043)

Co-authored-by: Hongxia Yang <[email protected]>
gshtras and hongxiayang authored Jul 4, 2024
1 parent 3dd5070 commit 56b325e
Showing 1 changed file with 51 additions and 2 deletions.
53 changes: 51 additions & 2 deletions vllm/attention/backends/rocm_flash_attn.py
@@ -166,6 +166,37 @@ def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
         return self._cached_decode_metadata


+def _make_alibi_bias(alibi_slopes: torch.Tensor,
+                     dtype: torch.dtype,
+                     seq_lens: Optional[List[int]],
+                     make_attn_mask: bool = True) -> List[torch.Tensor]:
+    attn_biases = []
+    if seq_lens:
+        for seq_len in seq_lens:
+            bias = torch.arange(seq_len, dtype=dtype)
+            # NOTE(zhuohan): HF uses
+            #     `bias = bias[None, :].repeat(seq_len, 1)`
+            # here. We find that both biases give the same results, but
+            # the bias below more accurately follows the original ALiBi
+            # paper.
+            bias = bias[None, :] - bias[:, None]
+
+            num_heads = alibi_slopes.shape[0]
+            bias = bias[None, :].repeat(
+                (num_heads, 1, 1)).to(alibi_slopes.device)
+            bias.mul_(alibi_slopes[:, None, None])
+            if make_attn_mask:
+                inf_mask = torch.empty(
+                    (1, seq_len, seq_len),
+                    dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1).to(
+                        alibi_slopes.device)
+                attn_biases.append((bias + inf_mask).to(dtype))
+            else:
+                attn_biases.append(bias.to(dtype))
+
+    return attn_biases
+
+
 class ROCmFlashAttentionImpl(AttentionImpl):
     """
     If the input tensors contain prompt tokens, the layout is as follows:
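
For reference, a minimal sketch (not part of the diff) of what the new _make_alibi_bias helper returns; the slope values, head count, and sequence length below are illustrative assumptions only:

import torch

# Hypothetical inputs: 2 heads with made-up ALiBi slopes, one 4-token prompt.
slopes = torch.tensor([0.5, 0.25])
biases = _make_alibi_bias(slopes, torch.float32, seq_lens=[4],
                          make_attn_mask=True)
print(biases[0].shape)  # torch.Size([2, 4, 4]): (num_heads, seq_len, seq_len)
# Entry (h, i, j) is slopes[h] * (j - i) for j <= i and -inf for j > i,
# i.e. the linear distance penalty from the ALiBi paper plus a causal mask.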
@@ -324,7 +355,14 @@ def forward(
                 # triton attention
                 # When block_tables are not filled, it means q and k are the
                 # prompt, and they have the same length.
+                attn_masks = None
                 if self.use_triton_flash_attn:
+                    if self.alibi_slopes is not None:
+                        attn_masks = _make_alibi_bias(
+                            self.alibi_slopes,
+                            query.dtype,
+                            attn_metadata.seq_lens,
+                            make_attn_mask=False)  # type: ignore
                     out, _ = self.attn_func(
                         query,
                         key,
@@ -336,12 +374,20 @@ def forward(
                         prefill_meta.max_prefill_seq_len,
                         True,
                         self.scale,
+                        attn_masks[0][None]
+                        if attn_masks is not None else None,
                     )
                 elif self.use_naive_attn:
                     if self.num_kv_heads != self.num_heads:
                         # Interleave for MQA workaround.
                         key = self.repeat_kv(key, self.num_queries_per_kv)
                         value = self.repeat_kv(value, self.num_queries_per_kv)
+                    if self.alibi_slopes is not None:
+                        attn_masks = _make_alibi_bias(
+                            self.alibi_slopes,
+                            query.dtype,
+                            attn_metadata.seq_lens,
+                            make_attn_mask=True)  # type: ignore
                     query = query.movedim(0, query.dim() - 2)
                     key = key.movedim(0, key.dim() - 2)
                     value = value.movedim(0, value.dim() - 2)
@@ -355,6 +401,7 @@ def forward(
                         self.num_heads,
                         self.head_size,
                         self.scale,
+                        attn_masks,
                     )
                 else:
                     out = self.attn_func(
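
To make the two prefill paths above concrete: the triton path builds the bias with make_attn_mask=False, presumably because the bare True argument is the kernel's causal flag and future tokens are already masked inside the kernel, so only the per-head linear penalties are passed as attn_masks[0][None]; the naive path builds full masks (bias plus the -inf upper triangle) and passes the whole list through to the _sdpa_attention helper updated below. A rough shape sketch under assumed sizes (8 heads, one 16-token prompt):

import torch

# Assumed sizes for illustration: 8 heads, a single 16-token prefill sequence.
slopes = torch.rand(8)

triton_masks = _make_alibi_bias(slopes, torch.float32, [16],
                                make_attn_mask=False)
print(triton_masks[0][None].shape)  # torch.Size([1, 8, 16, 16]), bias for the kernel

naive_masks = _make_alibi_bias(slopes, torch.float32, [16],
                               make_attn_mask=True)
print(naive_masks[0].shape)         # torch.Size([8, 16, 16]), one mask per sequence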
@@ -418,13 +465,14 @@ def _sdpa_attention(
     num_heads: int,
     head_size: int,
     scale: float,
+    attn_masks: Optional[List[torch.Tensor]] = None,
 ) -> torch.Tensor:
     start = 0
     output = torch.empty((num_tokens, num_heads, head_size),
                          dtype=query.dtype,
                          device=query.device)

-    for seq_len in seq_lens:
+    for i, seq_len in enumerate(seq_lens):
         end = start + seq_len
         with torch.backends.cuda.sdp_kernel(enable_math=True,
                                             enable_flash=False,
@@ -434,7 +482,8 @@ def _sdpa_attention(
                 key[:, start:end, :],
                 value[:, start:end, :],
                 dropout_p=0.0,
-                is_causal=True,
+                is_causal=attn_masks is None,
+                attn_mask=attn_masks[i] if attn_masks else None,
                 scale=scale).movedim(query.dim() - 2, 0)
             output[start:end, :, :] = sub_out
             start = end
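
A small sanity-check sketch for the _sdpa_attention change, using assumed shapes (2 heads, one 4-token sequence, head size 8, matching the (num_heads, num_tokens, head_size) layout produced by the movedim calls above): with zero slopes the helper degenerates to a plain causal mask, so passing it via attn_mask should reproduce the previous is_causal=True behaviour, while non-zero slopes add the per-head distance penalty on top.

import torch
import torch.nn.functional as F

# Assumed shapes: (num_heads, seq_len, head_size) for a single sequence.
q = torch.randn(2, 4, 8)
k = torch.randn(2, 4, 8)
v = torch.randn(2, 4, 8)

# Zero slopes -> the bias term vanishes, leaving only the -inf causal triangle.
causal_mask = _make_alibi_bias(torch.zeros(2), q.dtype, [4],
                               make_attn_mask=True)[0]
masked = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask)
causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(torch.allclose(masked, causal, atol=1e-6))  # expected: True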