Skip to content

Commit

Permalink
[Misc] Add numpy implementation of compute_slot_mapping (vllm-proje…
Browse files Browse the repository at this point in the history
  • Loading branch information
Yard1 authored and fialhocoelho committed Aug 22, 2024
1 parent b8c3524 commit daec291
Showing 1 changed file with 41 additions and 12 deletions.
53 changes: 41 additions & 12 deletions vllm/attention/backends/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Attention backend utils"""
from typing import TYPE_CHECKING, Dict, List, Type, TypeVar, Union

import numpy as np
import torch

from vllm.attention import AttentionMetadata, AttentionMetadataBuilder
Expand All @@ -13,6 +14,10 @@

PAD_SLOT_ID = -1

# Switch to numpy implementation of compute_slot_mapping
# if we have at least this many elements. Could be tuned further.
_COMPUTE_SLOT_MAPPING_NUMPY_NUMEL = 256

if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUBuilder

Expand Down Expand Up @@ -46,6 +51,29 @@ def compute_slot_mapping_start_idx(is_prompt: bool, query_len: int,
return start_idx


def _compute_slot_mapping_python(slot_mapping: List[int],
block_table: List[int], range_start: int,
range_end: int, block_size: int):
for i in range(range_start, range_end):
block_number = block_table[i // block_size]
block_offset = i % block_size
slot = block_number * block_size + block_offset
slot_mapping.append(slot)


def _compute_slot_mapping_numpy(slot_mapping: List[int],
block_table: List[int], range_start: int,
range_end: int, block_size: int):
block_table_array = np.array(block_table)
idx = np.arange(range_start, range_end)
block_offset = idx % block_size
idx //= block_size
seq_slot_mapping_array = block_table_array[idx]
seq_slot_mapping_array *= block_size
seq_slot_mapping_array += block_offset
slot_mapping.extend(seq_slot_mapping_array)


def compute_slot_mapping(is_profile_run: bool, slot_mapping: List[int],
seq_id: int, seq_len: int, context_len: int,
start_idx: int, block_size: int,
Expand All @@ -67,21 +95,22 @@ def compute_slot_mapping(is_profile_run: bool, slot_mapping: List[int],
# sliding window is 8, and block size is 4, the first two
# tokens are masked and the slot mapping will be
# [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
block_table = block_tables[seq_id]
padding_mask_len = max(0, start_idx - context_len)
slot_mapping.extend([PAD_SLOT_ID] * padding_mask_len)

def add_slot(i):
block_number = block_table[i // block_size]
block_offset = i % block_size
slot = block_number * block_size + block_offset
slot_mapping.append(slot)
range_start = max(start_idx, context_len)
range_end = seq_len
numel = range_end - range_start
block_table = block_tables[seq_id]

if start_idx == 0 and (seq_len - context_len) == 1:
# Optimization for common-case of decoding next token
add_slot(seq_len - 1)
# numpy implementation will be faster than python if we have
# many elements, otherwise it will be slower.
if numel < _COMPUTE_SLOT_MAPPING_NUMPY_NUMEL:
_compute_slot_mapping_python(slot_mapping, block_table, range_start,
range_end, block_size)
else:
slot_mapping.extend([PAD_SLOT_ID] * max(0, start_idx - context_len))
for i in range(max(start_idx, context_len), seq_len):
add_slot(i)
_compute_slot_mapping_numpy(slot_mapping, block_table, range_start,
range_end, block_size)


TAttentionMetadata = TypeVar("TAttentionMetadata", bound='AttentionMetadata')
Expand Down

0 comments on commit daec291

Please sign in to comment.