diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 23ea244f07dfe..94e84dacc3a5b 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -400,6 +400,7 @@ def _add_seq_group( """ is_prompt = inter_data.is_prompt block_tables = inter_data.block_tables + slot_mappings = inter_data.slot_mappings for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, curr_sliding_window_block) in zip( @@ -444,11 +445,11 @@ def _add_seq_group( # Compute slot mapping. is_profile_run = is_block_tables_empty(block_tables) start_idx = compute_slot_mapping_start_idx(is_prompt, query_len, - context_len, + self.block_size, self.sliding_window) compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, seq_len, context_len, start_idx, - self.block_size, inter_data.block_tables) + self.block_size, slot_mappings) def _get_graph_runner_block_tables( self, num_seqs: int, diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index a11462b2068a5..0c6d68cd09445 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -515,6 +515,7 @@ def _add_seq_group( """ is_prompt = inter_data.is_prompt block_tables = inter_data.block_tables + slot_mappings = inter_data.slot_mappings computed_block_nums = inter_data.computed_block_nums for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, @@ -556,11 +557,11 @@ def _add_seq_group( # Compute slot mapping. start_idx = compute_slot_mapping_start_idx(is_prompt, query_len, - context_len, + self.block_size, self.sliding_window) compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, seq_len, context_len, start_idx, - self.block_size, inter_data.block_tables) + self.block_size, slot_mappings) # It is not necessary to add paged_kv_indices, paged_kv_indptr, # and paged_kv_last_page_len for profile run because we will diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 56cc43430301f..1534d192f7b2a 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -4,7 +4,6 @@ from itertools import accumulate from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union -import numpy as np import torch from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder, @@ -42,43 +41,21 @@ def is_block_tables_empty(block_tables: Union[None, Dict]): def compute_slot_mapping_start_idx(is_prompt: bool, query_len: int, - context_len: int, sliding_window: int): + block_size: int, sliding_window: int): """ Compute the start index of slot mapping. 
""" start_idx = 0 if is_prompt and sliding_window is not None: - start_idx = max(0, query_len - sliding_window) + num_blocks = (sliding_window + block_size - 1) // block_size + start_idx = max(0, query_len - num_blocks * block_size) return start_idx -def _compute_slot_mapping_python(slot_mapping: List[int], - block_table: List[int], range_start: int, - range_end: int, block_size: int): - for i in range(range_start, range_end): - block_number = block_table[i // block_size] - block_offset = i % block_size - slot = block_number * block_size + block_offset - slot_mapping.append(slot) - - -def _compute_slot_mapping_numpy(slot_mapping: List[int], - block_table: List[int], range_start: int, - range_end: int, block_size: int): - block_table_array = np.array(block_table) - idx = np.arange(range_start, range_end) - block_offset = idx % block_size - idx //= block_size - seq_slot_mapping_array = block_table_array[idx] - seq_slot_mapping_array *= block_size - seq_slot_mapping_array += block_offset - slot_mapping.extend(seq_slot_mapping_array) - - def compute_slot_mapping(is_profile_run: bool, slot_mapping: List[int], seq_id: int, seq_len: int, context_len: int, start_idx: int, block_size: int, - block_tables: Dict[int, List[int]]): + slot_mappings: Dict[int, List[int]]): """ Compute slot mapping. """ @@ -95,23 +72,14 @@ def compute_slot_mapping(is_profile_run: bool, slot_mapping: List[int], # sliding_window). For example, if the prompt len is 10, # sliding window is 8, and block size is 4, the first two # tokens are masked and the slot mapping will be - # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. + # [-1, -1, 2, 3, 4, 5, 6, 7, 8, 9]. padding_mask_len = max(0, start_idx - context_len) slot_mapping.extend([PAD_SLOT_ID] * padding_mask_len) range_start = max(start_idx, context_len) range_end = seq_len - numel = range_end - range_start - block_table = block_tables[seq_id] - - # numpy implementation will be faster than python if we have - # many elements, otherwise it will be slower. - if numel < _COMPUTE_SLOT_MAPPING_NUMPY_NUMEL: - _compute_slot_mapping_python(slot_mapping, block_table, range_start, - range_end, block_size) - else: - _compute_slot_mapping_numpy(slot_mapping, block_table, range_start, - range_end, block_size) + seq_slot_mapping = slot_mappings[seq_id] + slot_mapping.extend(seq_slot_mapping[range_start:range_end]) TAttentionMetadata = TypeVar("TAttentionMetadata", bound='AttentionMetadata') @@ -145,6 +113,7 @@ def _add_seq_group( chunked_prefill_enabled: bool): is_prompt = inter_data.is_prompt block_tables = inter_data.block_tables + slot_mappings = inter_data.slot_mappings for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, curr_sliding_window_block) in zip( @@ -189,11 +158,11 @@ def _add_seq_group( # Compute slot mapping. 
is_profile_run = is_block_tables_empty(block_tables) start_idx = compute_slot_mapping_start_idx(is_prompt, query_len, - context_len, + self.block_size, self.sliding_window) compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, seq_len, context_len, start_idx, - self.block_size, inter_data.block_tables) + self.block_size, slot_mappings) def build(self, seq_lens: List[int], query_lens: List[int], cuda_graph_pad_size: int, batch_size: int): diff --git a/vllm/core/block/block_table.py b/vllm/core/block/block_table.py deleted file mode 100644 index 90c1438efbd08..0000000000000 --- a/vllm/core/block/block_table.py +++ /dev/null @@ -1,396 +0,0 @@ -import math -from typing import List, Optional - -from vllm.core.block.common import BlockList -from vllm.core.block.interfaces import Block, DeviceAwareBlockAllocator -from vllm.utils import Device, cdiv, chunk_list - - -class BlockTable: - """A class to manage blocks for a specific sequence. - - The BlockTable maps a sequence of tokens to a list of blocks, where each - block represents a contiguous memory allocation for a portion of the - sequence. The blocks are managed by a DeviceAwareBlockAllocator, which is - responsible for allocating and freeing memory for the blocks. - - Args: - block_size (int): The maximum number of tokens that can be stored in a - single block. - block_allocator (DeviceAwareBlockAllocator): The block allocator used to - manage memory for the blocks. - _blocks (Optional[List[Block]], optional): An optional list of existing - blocks to initialize the BlockTable with. If not provided, an empty - BlockTable is created. - max_block_sliding_window (Optional[int], optional): The number of - blocks to keep around for each sequence. If None, all blocks - are kept (eg., when sliding window is not used). - It should at least fit the sliding window size of the model. - - Attributes: - _block_size (int): The maximum number of tokens that can be stored in a - single block. - _allocator (DeviceAwareBlockAllocator): The block allocator used to - manage memory for the blocks. - _blocks (Optional[List[Block]]): The list of blocks managed by this - BlockTable. - _num_full_slots (int): The number of tokens currently stored in the - blocks. - """ - - def __init__( - self, - block_size: int, - block_allocator: DeviceAwareBlockAllocator, - _blocks: Optional[List[Block]] = None, - max_block_sliding_window: Optional[int] = None, - ): - self._block_size = block_size - self._allocator = block_allocator - if _blocks is None: - _blocks = [] - self._blocks: BlockList = BlockList(_blocks) - - self._max_block_sliding_window = max_block_sliding_window - self._num_full_slots = self._get_num_token_ids() - - @staticmethod - def get_num_required_blocks(token_ids: List[int], - block_size: int, - num_lookahead_slots: int = 0) -> int: - """Calculates the minimum number of blocks required to store a given - sequence of token IDs along with any look-ahead slots that may be - required (like in multi-step + chunked-prefill). - - This assumes worst-case scenario, where every block requires a new - allocation (e.g. ignoring prefix caching). - - Args: - token_ids (List[int]): The sequence of token IDs to be stored. - block_size (int): The maximum number of tokens that can be stored in - a single block. - num_lookahead_slots (int): look-ahead slots that the sequence may - require. - - Returns: - int: The minimum number of blocks required to store the given - sequence of token IDs along with any required look-ahead slots. 
- """ - return cdiv(len(token_ids) + num_lookahead_slots, block_size) - - def allocate(self, - token_ids: List[int], - device: Device = Device.GPU, - extra_hash: Optional[int] = None) -> None: - """Allocates memory blocks for storing the given sequence of token IDs. - - This method allocates the required number of blocks to store the given - sequence of token IDs. - - Args: - token_ids (List[int]): The sequence of token IDs to be stored. - device (Device, optional): The device on which the blocks should be - allocated. Defaults to Device.GPU. - extra_hash (Optional[int]): The hash value of additional - factors, such as adapters, that influence the block hash - in the prefixcaching block. - """ - assert not self._is_allocated - assert token_ids - blocks = self._allocate_blocks_for_token_ids(prev_block=None, - token_ids=token_ids, - device=device, - extra_hash=extra_hash) - self.update(blocks) - self._num_full_slots = len(token_ids) - - def update(self, blocks: List[Block]) -> None: - """Resets the table to the newly provided blocks - (with their corresponding block ids) - """ - self._blocks.update(blocks) - - def append_token_ids(self, - token_ids: List[int], - num_lookahead_slots: int = 0, - num_computed_slots: Optional[int] = None, - extra_hash: Optional[int] = None) -> None: - """Appends a sequence of token IDs to the existing blocks in the - BlockTable. - - This method appends the given sequence of token IDs to the existing - blocks in the BlockTable. If there is not enough space in the existing - blocks, new blocks are allocated using the `ensure_num_empty_slots` - method to accommodate the additional tokens. - - The token IDs are divided into chunks of size `block_size` (except for - the first chunk, which may be smaller), and each chunk is appended to a - separate block. - - Args: - token_ids (List[int]): The sequence of token IDs to be appended. - num_computed_slots (Optional[int]): The number of KV cache slots - that are already filled (computed). - When sliding window is enabled, this is used to compute how many - blocks to drop at the front of the sequence. - Without sliding window, None can be passed. - Without chunked prefill, it should be the same as - _num_full_slots. - extra_hash (Optional[int]): The hash value of additional - factors such as adapters that influence the block, apart - from the token_ids. 
- """ - assert self._is_allocated, "no blocks have been allocated" - assert len(self._blocks) > 0 - - # Drop blocks that are no longer needed due to sliding window - if self._max_block_sliding_window is not None: - null_block = self._allocator.allocate_or_get_null_block() - assert num_computed_slots is not None - end_block_idx = (num_computed_slots // - self._block_size) - self._max_block_sliding_window - for idx in range(0, end_block_idx): - b = self._blocks[idx] - if b is not null_block: - self._allocator.free(b) - self._blocks[idx] = null_block - - # Ensure there are enough empty slots for the new tokens plus - # lookahead slots - self.ensure_num_empty_slots(num_empty_slots=len(token_ids) + - num_lookahead_slots, - extra_hash=extra_hash) - - # Update the blocks with the new tokens - first_block_idx = self._num_full_slots // self._block_size - token_blocks = self._chunk_token_blocks_for_append(token_ids) - - for i, token_block in enumerate(token_blocks): - self._blocks.append_token_ids(first_block_idx + i, token_block) - - self._num_full_slots += len(token_ids) - - def ensure_num_empty_slots(self, - num_empty_slots: int, - extra_hash: Optional[int] = None) -> None: - """Ensures that the BlockTable has at least the specified number of - empty slots available. - - This method checks if the BlockTable has enough empty slots (i.e., - available space) to accommodate the requested number of tokens. If not, - it allocates additional blocks on the GPU to ensure that the required - number of empty slots is available. - - Args: - num_empty_slots (int): The minimum number of empty slots required. - extra_hash (Optional[int]): The hash value of additional - factors such as adapters that influence the block, apart - from the token_ids. - """ - # Currently the block table only supports - # appending tokens to GPU blocks. - device = Device.GPU - assert self._is_allocated - - if self._num_empty_slots >= num_empty_slots: - return - - slots_to_allocate = num_empty_slots - self._num_empty_slots - blocks_to_allocate = cdiv(slots_to_allocate, self._block_size) - - for _ in range(blocks_to_allocate): - assert len(self._blocks) > 0 - self._blocks.append( - self._allocator.allocate_mutable_block( - prev_block=self._blocks[-1], - device=device, - extra_hash=extra_hash)) - - def fork(self) -> "BlockTable": - """Creates a new BlockTable instance with a copy of the blocks from the - current instance. - - This method creates a new BlockTable instance with the same block size, - block allocator, and a copy of the blocks from the current instance. The - new BlockTable has its own independent set of blocks, but shares the - same underlying memory allocation with the original BlockTable. - - Returns: - BlockTable: A new BlockTable instance with a copy of the blocks from - the current instance. - """ - assert self._is_allocated - assert len(self._blocks) > 0 - forked_blocks = self._allocator.fork(self._blocks[-1]) - return BlockTable( - block_size=self._block_size, - block_allocator=self._allocator, - _blocks=forked_blocks, - max_block_sliding_window=self._max_block_sliding_window, - ) - - def free(self) -> None: - """Frees the memory occupied by the blocks in the BlockTable. - - This method iterates over all the blocks in the `_blocks` list and calls - the `free` method of the `_allocator` object to release the memory - occupied by each block. After freeing all the blocks, the `_blocks` list - is set to `None`. 
- """ - for block in self.blocks: - self._allocator.free(block) - self._blocks.reset() - - @property - def physical_block_ids(self) -> List[int]: - """Returns a list of physical block indices for the blocks in the - BlockTable. - - This property returns a list of integers, where each integer represents - the physical block index of a corresponding block in the `_blocks` list. - The physical block index is a unique identifier for the memory location - occupied by the block. - - Returns: - List[int]: A list of physical block indices for the blocks in the - BlockTable. - """ - return self._blocks.ids() - - def get_unseen_token_ids(self, sequence_token_ids: List[int]) -> List[int]: - """Get the number of "unseen" tokens in the sequence. - - Unseen tokens are tokens in the sequence corresponding to this block - table, but are not yet appended to this block table. - - Args: - sequence_token_ids (List[int]): The list of token ids in the - sequence. - - Returns: - List[int]: The postfix of sequence_token_ids that has not yet been - appended to the block table. - """ - - # Since the block table is append-only, the unseen token ids are the - # ones after the appended ones. - return sequence_token_ids[self.num_full_slots:] - - def _allocate_blocks_for_token_ids( - self, - prev_block: Optional[Block], - token_ids: List[int], - device: Device, - extra_hash: Optional[int] = None) -> List[Block]: - blocks: List[Block] = [] - - block_token_ids = [] - tail_token_ids = [] - for cur_token_ids in chunk_list(token_ids, self._block_size): - if len(cur_token_ids) == self._block_size: - block_token_ids.append(cur_token_ids) - else: - tail_token_ids.append(cur_token_ids) - - if block_token_ids: - blocks.extend( - self._allocator.allocate_immutable_blocks( - prev_block, - block_token_ids=block_token_ids, - device=device, - extra_hash=extra_hash)) - prev_block = blocks[-1] - - if tail_token_ids: - assert len(tail_token_ids) == 1 - cur_token_ids = tail_token_ids[0] - - block = self._allocator.allocate_mutable_block( - prev_block=prev_block, device=device, extra_hash=extra_hash) - block.append_token_ids(cur_token_ids) - - blocks.append(block) - - return blocks - - def _get_all_token_ids(self) -> List[int]: - # NOTE: This function is O(seq_len); use sparingly. - token_ids: List[int] = [] - - if not self._is_allocated: - return token_ids - - for block in self.blocks: - token_ids.extend(block.token_ids) - - return token_ids - - def _get_num_token_ids(self) -> int: - res = 0 - for block in self.blocks: - res += len(block.token_ids) - - return res - - @property - def _is_allocated(self) -> bool: - return len(self._blocks) > 0 - - @property - def blocks(self) -> List[Block]: - return self._blocks.list() - - @property - def _num_empty_slots(self) -> int: - assert self._is_allocated - return len(self._blocks) * self._block_size - self._num_full_slots - - @property - def num_full_slots(self) -> int: - """Returns the total number of tokens currently stored in the - BlockTable. - - Returns: - int: The total number of tokens currently stored in the BlockTable. - """ - return self._num_full_slots - - def get_num_blocks_touched_by_append_slots( - self, token_ids: List[int], num_lookahead_slots: int) -> int: - """Determine how many blocks will be "touched" by appending the token - ids. - - This is required for the scheduler to determine whether a sequence can - continue generation, or if it must be preempted. 
- """ - # Math below is equivalent to: - # all_token_ids = token_ids + [-1] * num_lookahead_slots - # token_blocks = self._chunk_token_blocks_for_append(all_token_ids) - # return len(token_blocks) - - num_token_ids = len(token_ids) + num_lookahead_slots - first_chunk_size = self._block_size - (self._num_full_slots % - self._block_size) - num_token_blocks = (1 + math.ceil( - (num_token_ids - first_chunk_size) / self._block_size)) - return num_token_blocks - - def _chunk_token_blocks_for_append( - self, token_ids: List[int]) -> List[List[int]]: - """Split the token ids into block-sized chunks so they can be easily - appended to blocks. The first such "token block" may have less token ids - than the block size, since the last allocated block may be partially - full. - - If no token ids are provided, then no chunks are returned. - """ - - if not token_ids: - return [] - - first_chunk_size = self._block_size - (self._num_full_slots % - self._block_size) - token_blocks = [token_ids[:first_chunk_size]] - token_blocks.extend( - chunk_list(token_ids[first_chunk_size:], self._block_size)) - return token_blocks diff --git a/vllm/core/block/cache_policy.py b/vllm/core/block/cache_policy.py new file mode 100644 index 0000000000000..be7f0f51923ec --- /dev/null +++ b/vllm/core/block/cache_policy.py @@ -0,0 +1,531 @@ +import math +from typing import List, Optional + +from vllm.core.block.common import PhysicalBlockTable, VirtualBlockTable +from vllm.core.block.interfaces import (Block, CachePolicy, + DeviceAwareBlockAllocator) +from vllm.utils import Device, cdiv, chunk_list + + +class CachePolicyBase(CachePolicy): + """This cache policy always allocates new blocks to append new tokens. + + Args: + block_size (int): The maximum number of tokens that can be stored in a + single block. + block_allocator (DeviceAwareBlockAllocator): The block allocator used to + manage memory for the blocks. + physical_block_table (Optional[List[Block]], optional): An optional list + of existing blocks to initialize the PhysicalBlockTable with. If not + provided, an empty PhysicalBlockTable is created. + + Attributes: + _block_size (int): The maximum number of tokens that can be stored in a + single block. + _allocator (DeviceAwareBlockAllocator): The block allocator used to + manage memory for the blocks. + _physical_block_table (PhysicalBlockTable): The list of blocks managed + by this PhysicalBlockTable. + """ + + def __init__( + self, + block_size: int, + block_allocator: DeviceAwareBlockAllocator, + physical_block_table: Optional[PhysicalBlockTable] = None, + virtual_block_table: Optional[VirtualBlockTable] = None, + ): + self._block_size = block_size + self._allocator = block_allocator + if physical_block_table is None: + physical_block_table = PhysicalBlockTable() + self._physical_block_table: PhysicalBlockTable = physical_block_table + if virtual_block_table is None: + virtual_block_table = VirtualBlockTable(block_size) + self._virtual_block_table: VirtualBlockTable = virtual_block_table + + def add_tokens_prefill(self, + token_ids: List[int], + device: Device = Device.GPU, + extra_hash: Optional[int] = None) -> None: + """Allocates memory blocks for storing the given sequence of token IDs + in prefill stage only. + + This method allocates the required number of blocks to store the given + sequence of token IDs. + + Args: + token_ids (List[int]): The sequence of token IDs to be stored. + device (Device, optional): The device on which the blocks should be + allocated. Defaults to Device.GPU. 
+ extra_hash (Optional[int]): The hash value of additional + factors, such as adapters, that influence the block hash + in the prefix-caching block. + """ + assert not self._is_allocated + assert token_ids + token_chunks = chunk_list(token_ids, self._block_size) + blocks = self._allocate_blocks_for_token_ids(prev_block=None, + token_chunks=token_chunks, + device=device, + extra_hash=extra_hash) + self.update_blocks(blocks) + self._virtual_block_table.append_tokens(blocks, len(token_ids)) + + def update_blocks(self, blocks: List[Block]) -> None: + """Resets the table to the newly provided blocks + """ + self._physical_block_table.update(blocks) + + def add_tokens_decode(self, + token_ids: List[int], + num_lookahead_slots: int = 0, + extra_hash: Optional[int] = None) -> None: + """Add a sequence of token IDs to the existing blocks in the + PhysicalBlockTable. + + This method appends the given sequence of token IDs to the existing + blocks in the PhysicalBlockTable. If there is not enough space in the + existing blocks, new blocks are allocated using the + `_ensure_num_empty_slots` method to accommodate the additional tokens. + + The token IDs are divided into chunks of size `block_size` (except for + the first chunk, which may be smaller), and each chunk is appended to a + separate block. + + Args: + token_ids (List[int]): The sequence of token IDs to be appended. + num_lookahead_slots (int): The number of lookahead slots to allocate + in speculative decoding or chunked prefill. + extra_hash (Optional[int]): The hash value of additional + factors such as adapters that influence the block, apart + from the token_ids. + """ + assert self._is_allocated, "no blocks have been allocated" + assert self.num_physical_blocks > 0 + + # Ensure there are enough empty slots for the new tokens plus + # lookahead slots + self._ensure_num_empty_slots(num_empty_slots=len(token_ids) + + num_lookahead_slots, + extra_hash=extra_hash) + + # Update the blocks with the new tokens + first_block_idx = self.num_tokens // self._block_size + token_blocks = self._chunk_token_blocks(token_ids) + + for i, token_block in enumerate(token_blocks): + block = self._physical_block_table.append_tokens( + first_block_idx + i, token_block) + self._virtual_block_table.append_tokens([block], len(token_block)) + + def _ensure_num_empty_slots(self, + num_empty_slots: int, + extra_hash: Optional[int] = None) -> None: + """Ensures that the PhysicalBlockTable has at least the specified number + of empty slots available. + + This method checks if the PhysicalBlockTable has enough empty slots + (i.e., available space) to accommodate the requested number of tokens. + If not, it allocates additional blocks on the GPU to ensure that the + required number of empty slots is available. + + Args: + num_empty_slots (int): The minimum number of empty slots required. + extra_hash (Optional[int]): The hash value of additional + factors such as adapters that influence the block, apart + from the token_ids. + """ + # Currently the block table only supports + # appending tokens to GPU blocks. 
+ device = Device.GPU + assert self._is_allocated + + if self._num_empty_slots >= num_empty_slots: + return + + slots_to_allocate = num_empty_slots - self._num_empty_slots + blocks_to_allocate = cdiv(slots_to_allocate, self._block_size) + + for _ in range(blocks_to_allocate): + assert self.num_physical_blocks > 0 + self._physical_block_table.append( + self._allocator.allocate_mutable_block( + prev_block=self._physical_block_table[-1], + device=device, + extra_hash=extra_hash)) + + def fork(self) -> "CachePolicy": + """Creates a new PhysicalBlockTable instance with a copy of the blocks + from the current instance. + + This method creates a new PhysicalBlockTable instance with the same + block size, block allocator, and a copy of the blocks from the current + instance. The new PhysicalBlockTable has its own independent set of + blocks, but shares the same underlying memory allocation with the + original PhysicalBlockTable. + + Returns: + PhysicalBlockTable: A new PhysicalBlockTable instance with a copy + of the blocks from the current instance. + """ + assert self._is_allocated + assert self.num_physical_blocks > 0 + physical_block_table = PhysicalBlockTable( + self._allocator.fork(self._physical_block_table[-1])) + virtual_block_table = self._virtual_block_table.fork() + return CachePolicyFactory.fork( + self, + physical_block_table=physical_block_table, + virtual_block_table=virtual_block_table) + + def free(self) -> None: + """Frees the memory occupied by the blocks in the PhysicalBlockTable. + + This method iterates over all the blocks in the `_physical_block_table` + list and calls the `free` method of the `_allocator` object to release + the memory occupied by each block. After freeing all the blocks, the + `_physical_block_table` list is set to `None`. + """ + for block in self.blocks: + self._allocator.free(block) + self._physical_block_table.reset() + self._virtual_block_table.reset() + + @property + def physical_block_ids(self) -> List[int]: + """Returns a list of physical block indices for the blocks in the + PhysicalBlockTable. + + This property returns a list of integers, where each integer represents + the physical block index of a corresponding block in the + `_physical_block_table` list. The physical block index is a unique + identifier for the memory location occupied by the block. + + Returns: + List[int]: A list of physical block indices for the blocks in the + PhysicalBlockTable. + """ + return self._physical_block_table.ids() + + @property + def num_physical_blocks(self) -> int: + return len(self._physical_block_table) + + @property + def slot_mappings(self) -> List[int]: + return self._virtual_block_table.slot_mappings + + @property + def num_tokens(self) -> int: + return self._virtual_block_table.num_tokens + + def get_unseen_token_ids(self, sequence_token_ids: List[int]) -> List[int]: + """Get the number of "unseen" tokens in the sequence. + + Unseen tokens are tokens in the sequence corresponding to this block + table, but are not yet appended to this block table. + + Args: + sequence_token_ids (List[int]): The list of token ids in the + sequence. + + Returns: + List[int]: The postfix of sequence_token_ids that has not yet been + appended to the block table. + """ + + # The token ids in sequence are append-only, the unseen token ids are + # the ones after the processed ones in block table. 
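_ensure_num_empty_slots sizes its allocation with the same ceiling division; a minimal sketch of that bookkeeping (helper names are mine):

def cdiv(a: int, b: int) -> int:
    return -(a // -b)

def blocks_to_allocate(num_empty_slots_needed: int, num_empty_slots_now: int,
                       block_size: int) -> int:
    if num_empty_slots_now >= num_empty_slots_needed:
        return 0  # enough room already, nothing to allocate
    return cdiv(num_empty_slots_needed - num_empty_slots_now, block_size)

# Room needed for 9 tokens with 1 free slot and 4-token blocks -> 2 blocks.
assert blocks_to_allocate(9, 1, 4) == 2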
+ return sequence_token_ids[self.num_full_slots:] + + def _allocate_blocks_for_token_ids( + self, + prev_block: Optional[Block], + token_chunks: List[List[int]], + device: Device, + extra_hash: Optional[int] = None) -> List[Block]: + blocks: List[Block] = [] + + block_token_ids = [] + tail_token_ids = [] + for cur_token_ids in token_chunks: + if len(cur_token_ids) == self._block_size: + block_token_ids.append(cur_token_ids) + else: + tail_token_ids.append(cur_token_ids) + + if block_token_ids: + blocks.extend( + self._allocator.allocate_immutable_blocks( + prev_block, + block_token_ids=block_token_ids, + device=device, + extra_hash=extra_hash)) + prev_block = blocks[-1] + + if tail_token_ids: + assert len(tail_token_ids) == 1 + cur_token_ids = tail_token_ids[0] + + block = self._allocator.allocate_mutable_block( + prev_block=prev_block, device=device, extra_hash=extra_hash) + block.append_token_ids(cur_token_ids) + + blocks.append(block) + + return blocks + + def _get_all_token_ids(self) -> List[int]: + # NOTE: This function is O(seq_len); use sparingly. + token_ids: List[int] = [] + + if not self._is_allocated: + return token_ids + + for block in self.blocks: + token_ids.extend(block.token_ids) + + return token_ids + + @property + def _is_allocated(self) -> bool: + return len(self._physical_block_table) > 0 + + @property + def blocks(self) -> List[Block]: + return self._physical_block_table.list() + + @property + def _num_empty_slots(self) -> int: + assert self._is_allocated + return len( + self._physical_block_table) * self._block_size - self.num_tokens + + @property + def num_full_slots(self) -> int: + """Returns the total number of tokens currently stored in the + PhysicalBlockTable. + + Returns: + int: The total number of tokens currently stored in the + PhysicalBlockTable. + """ + return self.num_tokens + + def get_num_blocks_touched_by_append_slots( + self, token_ids: List[int], num_lookahead_slots: int) -> int: + """Determine how many blocks will be "touched" by appending the token + ids. + + This is required for the scheduler to determine whether a sequence can + continue generation, or if it must be preempted. + """ + # Math below is equivalent to: + # all_token_ids = token_ids + [-1] * num_lookahead_slots + # token_blocks = self._chunk_token_blocks(all_token_ids) + # return len(token_blocks) + + num_token_ids = len(token_ids) + num_lookahead_slots + first_chunk_size = self._block_size - (self.num_tokens % + self._block_size) + num_token_blocks = (1 + math.ceil( + (num_token_ids - first_chunk_size) / self._block_size)) + return num_token_blocks + + def _chunk_token_blocks(self, token_ids: List[int]) -> List[List[int]]: + """Split the token ids into block-sized chunks so they can be easily + appended to blocks. The first such "token block" may have less token ids + than the block size, since the last allocated block may be partially + full. + + If no token ids are provided, then no chunks are returned. + """ + + if not token_ids: + return [] + + first_chunk_size = self._block_size - (self.num_tokens % + self._block_size) + token_blocks = [token_ids[:first_chunk_size]] + token_blocks.extend( + chunk_list(token_ids[first_chunk_size:], self._block_size)) + return token_blocks + + +class CachePolicySlidingWindow(CachePolicyBase): + """This cache policy has a sliding-window context and a fixed cache space + as a result. + + Args: + num_sliding_window_blocks (int): The number of blocks to keep around + for a sequence. It should at least fit the sliding window size of + the context. 
+ """ + + def __init__( + self, + block_size: int, + block_allocator: DeviceAwareBlockAllocator, + physical_block_table: Optional[PhysicalBlockTable] = None, + virtual_block_table: Optional[VirtualBlockTable] = None, + num_sliding_window_blocks: Optional[int] = None, + ): + super().__init__(block_size, block_allocator, physical_block_table, + virtual_block_table) + assert num_sliding_window_blocks is not None + self._num_sliding_window_blocks = num_sliding_window_blocks + + def add_tokens_prefill(self, + token_ids: List[int], + device: Device = Device.GPU, + extra_hash: Optional[int] = None) -> None: + """Allocate memory blocks for storing the given sequence of token IDs + in prefill stage only. + + This method allocates the required number of blocks to store the given + sequence of token IDs only inside the sliding window, or + _num_sliding_window_blocks to be exact. + """ + assert not self._is_allocated + assert token_ids + + block_start_idx = 0 + token_chunks = list(chunk_list(token_ids, self._block_size)) + num_evicted_chunks = len( + token_chunks) - self._num_sliding_window_blocks + + if num_evicted_chunks > 0: + num_windows = len(token_chunks) // self._num_sliding_window_blocks + last_window_end = num_windows * self._num_sliding_window_blocks + last_window_start = (last_window_end - + self._num_sliding_window_blocks) + last_window_chunks = token_chunks[ + last_window_start:last_window_end] + remainder_chunks = token_chunks[last_window_end:] + + # The remainder chunks cannot fill up a window, we need to rotate + # these chunks back to the front of the last full sliding window. + if len(remainder_chunks) > 0: + chunk_idx = len(remainder_chunks[:-1]) + last_window_chunks[:chunk_idx] = remainder_chunks[:chunk_idx] + last_window_chunks[ + chunk_idx][:len(remainder_chunks[chunk_idx])] = ( + remainder_chunks[chunk_idx]) + block_start_idx = chunk_idx + if len(remainder_chunks[chunk_idx]) == self._block_size: + block_start_idx = chunk_idx + 1 + token_chunks = last_window_chunks + + blocks = self._allocate_blocks_for_token_ids(prev_block=None, + token_chunks=token_chunks, + device=device, + extra_hash=extra_hash) + self.update_blocks(blocks) + + num_evicted_tokens = 0 + if num_evicted_chunks > 0: + # Chronologically, we maintain the chunk order in the sequence to + # be added into the block tables. + blocks = blocks[block_start_idx:] + blocks[:block_start_idx] + + # Allocate null blocks to represent the evicted tokens. + null_block = self._allocator.allocate_or_get_null_block() + evicted_blocks = [null_block] * num_evicted_chunks + num_evicted_tokens = ( + len(token_ids) - + self._num_sliding_window_blocks * self._block_size) + self._virtual_block_table.append_tokens(evicted_blocks, + num_evicted_tokens, + evicted=True) + # Partially filled block actually appears twice due to rotation. + if num_evicted_tokens % self._block_size != 0: + blocks.append(blocks[0]) + + self._virtual_block_table.append_tokens( + blocks, len(token_ids[num_evicted_tokens:])) + + def add_tokens_decode(self, + token_ids: List[int], + num_lookahead_slots: int = 0, + extra_hash: Optional[int] = None) -> None: + """Add a sequence of token IDs to the blocks in the PhysicalBlockTable + by rotating the blocks when appending new tokens when the sliding window + is full. This means the currently oldest tokens are evicted and replaced + with new tokens. 
+ + """ + assert self._is_allocated, "no blocks have been allocated" + assert self.num_physical_blocks > 0 + + # Rotate and reuse blocks beyond sliding window so that no new blocks + # are needed + assert self.num_physical_blocks <= self._num_sliding_window_blocks + if self.num_physical_blocks < self._num_sliding_window_blocks: + # Ensure there are enough empty slots for the new tokens plus + # lookahead slots + self._ensure_num_empty_slots(num_empty_slots=len(token_ids) + + num_lookahead_slots, + extra_hash=extra_hash) + + # Update the blocks with the new tokens + first_block_idx = (self.num_tokens // self._block_size % + self.num_physical_blocks) + token_blocks = self._chunk_token_blocks(token_ids) + + slot_offsets = [0] * len(token_blocks) + slot_offsets[0] = self.num_tokens % self._block_size + for i, (slot_offset, + token_block) in enumerate(zip(slot_offsets, token_blocks), + start=first_block_idx): + i %= self.num_physical_blocks + block = self._physical_block_table.insert_tokens( + i, slot_offset, token_block) + self._virtual_block_table.insert_tokens(block, slot_offset, + len(token_block)) + + +class CachePolicyFactory: + + @staticmethod + def create( + num_sliding_window_blocks: Optional[int], + block_size: int, + block_allocator: DeviceAwareBlockAllocator, + physical_block_table: Optional[PhysicalBlockTable] = None, + virtual_block_table: Optional[VirtualBlockTable] = None, + ) -> "CachePolicy": + if num_sliding_window_blocks is None: + return CachePolicyBase( + block_size=block_size, + block_allocator=block_allocator, + physical_block_table=physical_block_table, + virtual_block_table=virtual_block_table, + ) + else: + return CachePolicySlidingWindow( + block_size=block_size, + block_allocator=block_allocator, + num_sliding_window_blocks=num_sliding_window_blocks, + physical_block_table=physical_block_table, + virtual_block_table=virtual_block_table, + ) + + @staticmethod + def fork( + instance: CachePolicy, + physical_block_table: PhysicalBlockTable, + virtual_block_table: VirtualBlockTable, + ) -> "CachePolicy": + if hasattr(instance, "_num_sliding_window_blocks"): + num_sliding_window_blocks = instance._num_sliding_window_blocks # type: ignore + else: + num_sliding_window_blocks = None + + return CachePolicyFactory.create( + block_size=instance._block_size, # type: ignore + block_allocator=instance._allocator, # type: ignore + num_sliding_window_blocks=num_sliding_window_blocks, + physical_block_table=physical_block_table, + virtual_block_table=virtual_block_table, + ) diff --git a/vllm/core/block/common.py b/vllm/core/block/common.py index c03b5932eafb6..ac45e3509bed1 100644 --- a/vllm/core/block/common.py +++ b/vllm/core/block/common.py @@ -1,11 +1,16 @@ -from collections import deque +import copy +from collections import defaultdict, deque from dataclasses import dataclass -from typing import Deque, Dict, Iterable, List, Optional, Protocol, Tuple +from typing import (DefaultDict, Deque, Dict, Iterable, List, Optional, + Protocol, Tuple) from vllm.core.block.interfaces import Block, BlockAllocator BlockId = int RefCount = int +SlotMappings = List[int] +TokenMappings = DefaultDict[int, List[int]] +EVICTED_SLOT_ID = -2 class RefCounterProtocol(Protocol): @@ -34,9 +39,10 @@ class RefCounter(RefCounterProtocol): def __init__(self, all_block_indices: Iterable[BlockId]): deduped = set(all_block_indices) - self._refcounts: Dict[BlockId, - RefCount] = {index: 0 - for index in deduped} + self._refcounts: Dict[BlockId, RefCount] = { + index: 0 + for index in deduped + } def incr(self, 
block_id: BlockId) -> RefCount: assert block_id in self._refcounts @@ -225,37 +231,133 @@ def free_block(self, block: Block) -> None: self._free_ids.appendleft(block.pool_id) # type: ignore[attr-defined] -class BlockList: - """This class is an optimization to allow fast-access to physical - block ids. It maintains a block id list that is updated with the - block list and this avoids the need to reconstruct the block id - list on every iteration of the block manager +class VirtualBlockTable: + """VirtualBlockTable maintains the mappings between tokens and physical + memory blocks. Both the token mappings and slot mappings are tracked + respectively, one maps physical blocks to tokens, and the other maps + tokens to their memory slots. The slot mappings are updated with the tokens + added and this avoids the need to reconstruct the mappings on every + iteration for kv copy to the cache. """ - def __init__(self, blocks: List[Block]): + def __init__(self, + block_size: int, + slot_mappings: Optional[SlotMappings] = None, + token_mappings: Optional[TokenMappings] = None): + self._block_size = block_size + if slot_mappings is None: + slot_mappings = [] + self._slot_mappings: SlotMappings = slot_mappings + if token_mappings is None: + token_mappings = ( + defaultdict(lambda: [EVICTED_SLOT_ID] * self._block_size)) + self._token_mappings: TokenMappings = token_mappings + + def append_tokens(self, + blocks: List[Block], + num_new_tokens: int, + evicted: bool = False) -> None: + first_chunk_size = self._block_size - (self.num_tokens % + self._block_size) + if first_chunk_size < num_new_tokens: + last_chunk_size = (self.num_tokens + + num_new_tokens) % self._block_size + num_middle_chunks = ((num_new_tokens - first_chunk_size) // + self._block_size) + middle_chunk_sizes = [(0, self._block_size)] * num_middle_chunks + chunk_list = [(self._block_size - first_chunk_size, + first_chunk_size)] + chunk_list.extend(middle_chunk_sizes) + if last_chunk_size > 0: + chunk_list.append((0, last_chunk_size)) + else: + chunk_list = [(self._block_size - first_chunk_size, num_new_tokens) + ] + assert len(chunk_list) == len(blocks) + + for (slot_offset, chunk_size), block in zip(chunk_list, blocks): + if not evicted: + block_token_mappings = self._token_mappings[block.block_id] + block_token_mappings[slot_offset:slot_offset + + chunk_size] = (range( + self.num_tokens, + self.num_tokens + chunk_size)) + + slot_start = block.block_id * self._block_size + slot_offset + slot_end = slot_start + chunk_size + slots = range(slot_start, slot_end) + else: + slots = [EVICTED_SLOT_ID] * chunk_size + self._slot_mappings.extend(slots) + + def insert_tokens(self, block: Block, slot_offset: int, + num_new_tokens: int) -> None: + # If evicting previous tokens from physical blocks, replace them with + # new tokens, and update the slot_mappings + block_slot_start = slot_offset + block_slot_end = slot_offset + num_new_tokens + block_token_mappings = self._token_mappings[block.block_id] + for token_idx in block_token_mappings[block_slot_start:block_slot_end]: + if token_idx != EVICTED_SLOT_ID: + self._slot_mappings[token_idx] = EVICTED_SLOT_ID + + # Populate new tokens in the physical blocks, and update both + # token_mappings and slot_mappings + slot_start = (block.block_id * self._block_size) + slot_offset + slot_end = slot_start + num_new_tokens + self._slot_mappings.extend(range(slot_start, slot_end)) + + token_start = self.num_tokens - num_new_tokens + block_token_mappings[block_slot_start:block_slot_end] = range( + token_start, 
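The slot values VirtualBlockTable records follow the usual paged-KV flattening; a one-line sketch (block ids are allocator-assigned, so the values here are illustrative):

block_size = 4

def slot_of(block_id: int, slot_offset: int) -> int:
    # Flat index of a token's KV slot inside the paged cache.
    return block_id * block_size + slot_offset

# A token at offset 1 of physical block 5 occupies slot 21; evicted tokens
# keep EVICTED_SLOT_ID (-2) in _slot_mappings instead of a real slot.
assert slot_of(5, 1) == 21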
self.num_tokens) + + def reset(self): + self._slot_mappings = [] + self._token_mappings = defaultdict( + lambda: [EVICTED_SLOT_ID] * self._block_size) + + @property + def num_tokens(self) -> int: + return len(self._slot_mappings) + + @property + def slot_mappings(self) -> SlotMappings: + return self._slot_mappings + + @property + def token_mappings(self) -> TokenMappings: + return self._token_mappings + + def fork(self) -> "VirtualBlockTable": + return VirtualBlockTable(self._block_size, self._slot_mappings.copy(), + copy.deepcopy(self._token_mappings)) + + +class PhysicalBlockTable: + """PhysicalBlockTable (formerly BlockList) keeps track of the allocated + cache blocks. It is also an optimization to allow fast-access to physical + block ids. It maintains a block id list that is updated with the block + list and this avoids the need to reconstruct the block id list on every + iteration of the block manager. + """ + + def __init__(self, blocks: Optional[List[Block]] = None): + if blocks is None: + blocks = [] self._blocks: List[Block] = [] self._block_ids: List[int] = [] self.update(blocks) - def _add_block_id(self, block_id: Optional[BlockId]) -> None: - assert block_id is not None - self._block_ids.append(block_id) - - def _update_block_id(self, block_index: int, - new_block_id: Optional[BlockId]) -> None: - assert new_block_id is not None - self._block_ids[block_index] = new_block_id - def update(self, blocks: List[Block]): - self._blocks = blocks - + self._blocks = blocks.copy() # Cache block ids for fast query self._block_ids = [] + for block in self._blocks: self._add_block_id(block.block_id) - def append_token_ids(self, block_index: int, token_ids: List[int]) -> None: + def append_tokens(self, block_index: int, token_ids: List[int]) -> Block: block = self._blocks[block_index] prev_block_id = block.block_id @@ -265,10 +367,35 @@ def append_token_ids(self, block_index: int, token_ids: List[int]) -> None: if prev_block_id != block.block_id: self._update_block_id(block_index, block.block_id) + return block + + def insert_tokens(self, block_index: int, slot_offset: int, + token_ids: List[int]) -> Block: + block = self._blocks[block_index] + prev_block_id = block.block_id + + block.insert_token_ids(slot_offset, token_ids) + + # CoW or promotion may update the internal block_id + if prev_block_id != block.block_id: + self._update_block_id(block_index, block.block_id) + + return block + def append(self, new_block: Block): self._blocks.append(new_block) self._add_block_id(new_block.block_id) + def reset(self): + self._blocks = [] + self._block_ids = [] + + def list(self) -> List[Block]: + return self._blocks + + def ids(self) -> List[int]: + return self._block_ids + def __len__(self) -> int: return len(self._blocks) @@ -279,15 +406,14 @@ def __setitem__(self, block_index: int, new_block: Block) -> None: self._blocks[block_index] = new_block self._update_block_id(block_index, new_block.block_id) - def reset(self): - self._blocks = [] - self._block_ids = [] - - def list(self) -> List[Block]: - return self._blocks + def _add_block_id(self, block_id: BlockId) -> None: + assert block_id is not None + self._block_ids.append(block_id) - def ids(self) -> List[int]: - return self._block_ids + def _update_block_id(self, block_index: int, + new_block_id: Optional[BlockId]) -> None: + assert new_block_id is not None + self._block_ids[block_index] = new_block_id @dataclass diff --git a/vllm/core/block/cpu_gpu_block_allocator.py b/vllm/core/block/cpu_gpu_block_allocator.py index 
3a57487a6cd8a..663e56fdbf585 100644 --- a/vllm/core/block/cpu_gpu_block_allocator.py +++ b/vllm/core/block/cpu_gpu_block_allocator.py @@ -375,6 +375,9 @@ def __init__(self, proxy: Block): def append_token_ids(self, token_ids: List[BlockId]): raise ValueError("null block should not be modified") + def insert_token_ids(self, slot_offset: int, token_ids: List[int]) -> None: + raise ValueError("null block should not be modified") + @property def block_id(self): return self._proxy.block_id diff --git a/vllm/core/block/interfaces.py b/vllm/core/block/interfaces.py index 985a1098b6cd1..2fa2b221bfcb2 100644 --- a/vllm/core/block/interfaces.py +++ b/vllm/core/block/interfaces.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from typing import Dict, FrozenSet, List, Optional, Protocol, Tuple -from vllm.utils import Device +from vllm.utils import Device, cdiv BlockId = int @@ -12,6 +12,10 @@ class Block(ABC): def append_token_ids(self, token_ids: List[int]) -> None: pass + @abstractmethod + def insert_token_ids(self, slot_offset: int, token_ids: List[int]) -> None: + pass + @property @abstractmethod def block_id(self) -> Optional[int]: @@ -304,3 +308,104 @@ def find_cached_blocks_prefix( device: Device = Device.GPU, ) -> List[int]: pass + + +class CachePolicy(ABC): + """A class to manage use of blocks for a sequence + + The PhysicalBlockTable maps a sequence of tokens to a list of blocks, where + each block represents a contiguous memory allocation for a portion of the + sequence. The VirtualBlockTable maps each memory slot in each block to the + tokens and vice versa. Note that it is possible for a token to be evicted + from blocks to save cache space. The implementations bear the responsibility + of managing the blocks and tokens in terms of cache use while maintaining + the PhysicalBlockTable and the VirtualTable. The blocks are managed by a + DeviceAwareBlockAllocator, which is responsible for allocating and freeing + memory for the blocks. + + """ + + @abstractmethod + def add_tokens_prefill(self, + token_ids: List[int], + device: Device = Device.GPU, + extra_hash: Optional[int] = None) -> None: + pass + + @abstractmethod + def update_blocks(self, + blocks: List[Block]) -> None: + pass + + @abstractmethod + def add_tokens_decode(self, + token_ids: List[int], + num_lookahead_slots: int = 0, + extra_hash: Optional[int] = None) -> None: + pass + + @abstractmethod + def fork(self) -> "CachePolicy": + pass + + @abstractmethod + def free(self) -> None: + pass + + @property + @abstractmethod + def physical_block_ids(self) -> List[int]: + pass + + @property + @abstractmethod + def slot_mappings(self) -> List[int]: + pass + + @property + @abstractmethod + def num_tokens(self) -> int: + pass + + @abstractmethod + def get_unseen_token_ids(self, sequence_token_ids: List[int]) -> List[int]: + pass + + @property + @abstractmethod + def blocks(self) -> List[Block]: + pass + + @property + @abstractmethod + def num_full_slots(self) -> int: + pass + + @abstractmethod + def get_num_blocks_touched_by_append_slots( + self, token_ids: List[int], num_lookahead_slots: int) -> int: + pass + + @staticmethod + def get_num_required_blocks(token_ids: List[int], + block_size: int, + num_lookahead_slots: int = 0) -> int: + """Calculates the minimum number of blocks required to store a given + sequence of token IDs along with any look-ahead slots that may be + required (like in multistep + chunked-prefill). + + This assumes worst-case scenario, where every block requires a new + allocation (e.g. ignoring prefix caching). 
+ + Args: + token_ids (List[int]): The sequence of token IDs to be stored. + block_size (int): The maximum number of tokens that can be stored in + a single block. + num_lookahead_slots (int): look-ahead slots that the sequence may + require. + + Returns: + int: The minimum number of blocks required to store the given + sequence of token IDs along with any required look-ahead slots. + """ + return cdiv(len(token_ids) + num_lookahead_slots, block_size) diff --git a/vllm/core/block/naive_block.py b/vllm/core/block/naive_block.py index 9b94918ab38ef..a87a4312b4150 100644 --- a/vllm/core/block/naive_block.py +++ b/vllm/core/block/naive_block.py @@ -396,6 +396,27 @@ def _append_token_ids_no_cow(self, token_ids: List[int]) -> None: self._token_ids.extend(token_ids) + def insert_token_ids(self, slot_start: int, token_ids: List[int]) -> None: + """Inserts the given token IDs to the block and performs a + copy-on-write if necessary. + + Args: + slot_start: (int): The start address of first token to be inserted + token_ids (Optional[List[int]]): The token IDs to be appended + to the block. + """ + if len(token_ids) == 0: + return + + slot_end = slot_start + len(token_ids) + assert slot_end <= self._block_size + + self._token_ids[slot_start:slot_end] = token_ids + + if self._block_id is not None: + self._block_id = (self._allocator.cow_block_if_not_appendable( + self._cow_target)) + @property def computed(self) -> bool: raise NotImplementedError diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py index 1238303234deb..9d5a2ac759a93 100644 --- a/vllm/core/block/prefix_caching_block.py +++ b/vllm/core/block/prefix_caching_block.py @@ -794,6 +794,9 @@ def append_token_ids(self, token_ids: List[int]) -> None: if self.content_hash is not None: self.block_id = self._allocator.promote_to_immutable_block(self) + def insert_token_ids(self, slot_start: int, token_ids: List[int]) -> None: + raise ValueError("Block should not be modified in prefix caching") + @property def block_id(self) -> Optional[int]: return self._block.block_id diff --git a/vllm/core/block/utils.py b/vllm/core/block/utils.py index 1c6578e4cc6ab..73611f790f3e9 100644 --- a/vllm/core/block/utils.py +++ b/vllm/core/block/utils.py @@ -19,7 +19,7 @@ def check_no_caching_or_swa_for_blockmgr_encdec( ''' if seq_group.is_encoder_decoder(): - if block_mgr.max_block_sliding_window is not None: + if block_mgr.num_sliding_window_blocks is not None: raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_SWA) if block_mgr.enable_caching: diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py index b41e848221882..c75d8dca0760e 100644 --- a/vllm/core/block_manager.py +++ b/vllm/core/block_manager.py @@ -3,9 +3,9 @@ from typing import Sequence as GenericSequence from typing import Tuple -from vllm.core.block.block_table import BlockTable +from vllm.core.block.cache_policy import CachePolicyFactory from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator -from vllm.core.block.interfaces import Block +from vllm.core.block.interfaces import Block, CachePolicy from vllm.core.block.prefix_caching_block import (ComputedBlocksTracker, LastAccessBlocksTracker) from vllm.core.block.utils import check_no_caching_or_swa_for_blockmgr_encdec @@ -71,17 +71,12 @@ def __init__( self.num_total_cpu_blocks = num_cpu_blocks self.sliding_window = sliding_window - # max_block_sliding_window is the max number of blocks that need to be - # allocated - self.max_block_sliding_window = None + # num_sliding_window_blocks 
is the max number of blocks that need to be + # allocated if the sliding window is enabled. + self.num_sliding_window_blocks = None if sliding_window is not None: - # +1 here because // rounds down - num_blocks = sliding_window // block_size + 1 - # +1 here because the last block may not be full, - # and so the sequence stretches one more block at the beginning - # For example, if sliding_window is 3 and block_size is 4, - # we may need 2 blocks when the second block only holds 1 token. - self.max_block_sliding_window = num_blocks + 1 + self.num_sliding_window_blocks = ( + (sliding_window + block_size - 1) // block_size) self.watermark = watermark assert watermark >= 0.0 @@ -97,8 +92,8 @@ def __init__( block_size=block_size, ) - self.block_tables: Dict[SeqId, BlockTable] = {} - self.cross_block_tables: Dict[EncoderSeqId, BlockTable] = {} + self.cache_policies: Dict[SeqId, CachePolicy] = {} + self.cross_cache_policies: Dict[EncoderSeqId, CachePolicy] = {} self._computed_blocks_tracker = ComputedBlocksTracker( self.block_allocator, self.block_size, self.enable_caching) @@ -114,7 +109,7 @@ def can_allocate(self, check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group) seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] - num_required_blocks = BlockTable.get_num_required_blocks( + num_required_blocks = CachePolicy.get_num_required_blocks( seq.get_token_ids(), block_size=self.block_size, num_lookahead_slots=num_lookahead_slots, @@ -123,63 +118,62 @@ def can_allocate(self, if seq_group.is_encoder_decoder(): encoder_seq = seq_group.get_encoder_seq() assert encoder_seq is not None - num_required_blocks += BlockTable.get_num_required_blocks( + num_required_blocks += CachePolicy.get_num_required_blocks( encoder_seq.get_token_ids(), block_size=self.block_size, ) - if self.max_block_sliding_window is not None: + if self.num_sliding_window_blocks is not None: num_required_blocks = min(num_required_blocks, - self.max_block_sliding_window) + self.num_sliding_window_blocks) num_free_gpu_blocks = self.block_allocator.get_num_free_blocks( device=Device.GPU) # Use watermark to avoid frequent cache eviction. - if (self.num_total_gpu_blocks - num_required_blocks < - self.watermark_blocks): + if (self.num_total_gpu_blocks - num_required_blocks + < self.watermark_blocks): return AllocStatus.NEVER if num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks: return AllocStatus.OK else: return AllocStatus.LATER - def _allocate_sequence(self, seq: Sequence) -> BlockTable: - block_table = BlockTable( + def _allocate_sequence(self, seq: Sequence) -> "CachePolicy": + cache_policy = CachePolicyFactory.create( block_size=self.block_size, block_allocator=self.block_allocator, - max_block_sliding_window=self.max_block_sliding_window, - ) + num_sliding_window_blocks=self.num_sliding_window_blocks) if seq.get_token_ids(): # NOTE: If there are any factors affecting the block besides # token_ids, they should be added as input to extra_hash. extra_hash = seq.extra_hash() # Add blocks to the block table only if the sequence is non empty. 
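The comment rewrite above changes the per-sequence block bound, not just its name; a sketch contrasting the two formulas (helper names are mine):

def old_max_block_sliding_window(sliding_window: int, block_size: int) -> int:
    # Old bound: window blocks, +1 because // rounds down, +1 for the
    # partially filled block at the head of the sequence.
    return sliding_window // block_size + 1 + 1

def new_num_sliding_window_blocks(sliding_window: int, block_size: int) -> int:
    # New bound: the window rounded up to whole blocks, since decode now
    # rotates within those blocks instead of stretching past the window.
    return (sliding_window + block_size - 1) // block_size

assert old_max_block_sliding_window(8, 4) == 4
assert new_num_sliding_window_blocks(8, 4) == 2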
- block_table.allocate(token_ids=seq.get_token_ids(), - extra_hash=extra_hash) + cache_policy.add_tokens_prefill(token_ids=seq.get_token_ids(), + extra_hash=extra_hash) - return block_table + return cache_policy def allocate(self, seq_group: SequenceGroup) -> None: # Allocate self-attention block tables for decoder sequences waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING) assert not (set(seq.seq_id for seq in waiting_seqs) - & self.block_tables.keys()), "block table already exists" + & self.cache_policies.keys()), "block table already exists" # NOTE: Here we assume that all sequences in the group have the same # prompt. seq = waiting_seqs[0] - block_table: BlockTable = self._allocate_sequence(seq) - self.block_tables[seq.seq_id] = block_table + cache_policy: CachePolicy = self._allocate_sequence(seq) + self.cache_policies[seq.seq_id] = cache_policy # Track seq self._last_access_blocks_tracker.add_seq(seq.seq_id) # Assign the block table for each sequence. for seq in waiting_seqs[1:]: - self.block_tables[seq.seq_id] = block_table.fork() + self.cache_policies[seq.seq_id] = cache_policy.fork() # Track seq self._last_access_blocks_tracker.add_seq(seq.seq_id) @@ -191,7 +185,7 @@ def allocate(self, seq_group: SequenceGroup) -> None: request_id = seq_group.request_id assert (request_id - not in self.cross_block_tables), \ + not in self.cross_cache_policies), \ "block table already exists" check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group) @@ -199,11 +193,11 @@ def allocate(self, seq_group: SequenceGroup) -> None: if seq_group.is_encoder_decoder(): encoder_seq = seq_group.get_encoder_seq() assert encoder_seq is not None - block_table = self._allocate_sequence(encoder_seq) - self.cross_block_tables[request_id] = block_table + cache_policy = self._allocate_sequence(encoder_seq) + self.cross_cache_policies[request_id] = cache_policy - def can_append_slots(self, seq_group: SequenceGroup, - num_lookahead_slots: int) -> bool: + def can_add_slots(self, seq_group: SequenceGroup, + num_lookahead_slots: int) -> bool: """Determine if there is enough space in the GPU KV cache to continue generation of the specified sequence group. @@ -214,15 +208,17 @@ def can_append_slots(self, seq_group: SequenceGroup, "Lookahead slots" are slots that are allocated in addition to the slots for known tokens. The contents of the lookahead slots are not defined. This is used by speculative decoding when speculating future tokens. + + TODO(Shawnd200): different cache policies have different space needs. 
""" num_touched_blocks = 0 for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): - block_table = self.block_tables[seq.seq_id] + cache_policy = self.cache_policies[seq.seq_id] num_touched_blocks += ( - block_table.get_num_blocks_touched_by_append_slots( - token_ids=block_table.get_unseen_token_ids( + cache_policy.get_num_blocks_touched_by_append_slots( + token_ids=cache_policy.get_unseen_token_ids( seq.get_token_ids()), num_lookahead_slots=num_lookahead_slots, )) @@ -231,20 +227,18 @@ def can_append_slots(self, seq_group: SequenceGroup, Device.GPU) return num_touched_blocks <= num_free_gpu_blocks - def append_slots( + def add_slots( self, seq: Sequence, num_lookahead_slots: int, ) -> List[Tuple[int, int]]: - block_table = self.block_tables[seq.seq_id] + cache_policy = self.cache_policies[seq.seq_id] - block_table.append_token_ids( - token_ids=block_table.get_unseen_token_ids(seq.get_token_ids()), + cache_policy.add_tokens_decode( + token_ids=cache_policy.get_unseen_token_ids(seq.get_token_ids()), num_lookahead_slots=num_lookahead_slots, - num_computed_slots=seq.data.get_num_computed_tokens(), - extra_hash=seq.extra_hash(), - ) + extra_hash=seq.extra_hash()) # Return any new copy-on-writes. new_cows = self.block_allocator.clear_copy_on_writes() return new_cows @@ -252,41 +246,51 @@ def append_slots( def free(self, seq: Sequence) -> None: seq_id = seq.seq_id - if seq_id not in self.block_tables: + if seq_id not in self.cache_policies: # Already freed or haven't been scheduled yet. return # Update seq block ids with the latest access time self._last_access_blocks_tracker.update_seq_blocks_last_access( - seq_id, self.block_tables[seq.seq_id].physical_block_ids) + seq_id, self.cache_policies[seq.seq_id].physical_block_ids) # Untrack seq self._last_access_blocks_tracker.remove_seq(seq_id) self._computed_blocks_tracker.remove_seq(seq_id) # Free table/blocks - self.block_tables[seq_id].free() - del self.block_tables[seq_id] + self.cache_policies[seq_id].free() + del self.cache_policies[seq_id] def free_cross(self, seq_group: SequenceGroup) -> None: request_id = seq_group.request_id - if request_id not in self.cross_block_tables: + if request_id not in self.cross_cache_policies: # Already freed or hasn't been scheduled yet. 
             return
-        self.cross_block_tables[request_id].free()
-        del self.cross_block_tables[request_id]
+        self.cross_cache_policies[request_id].free()
+        del self.cross_cache_policies[request_id]
     def get_block_table(self, seq: Sequence) -> List[int]:
-        block_ids = self.block_tables[seq.seq_id].physical_block_ids
+        block_ids = self.cache_policies[seq.seq_id].physical_block_ids
         return block_ids  # type: ignore
+    def get_slot_mapping(self, seq: Sequence) -> List[int]:
+        seq_slot_mappings = self.cache_policies[seq.seq_id].slot_mappings
+        return seq_slot_mappings  # type: ignore
+
     def get_cross_block_table(self, seq_group: SequenceGroup) -> List[int]:
         request_id = seq_group.request_id
-        assert request_id in self.cross_block_tables
-        block_ids = self.cross_block_tables[request_id].physical_block_ids
+        assert request_id in self.cross_cache_policies
+        block_ids = self.cross_cache_policies[request_id].physical_block_ids
         assert all(b is not None for b in block_ids)
         return block_ids  # type: ignore
+    def get_cross_slot_mapping(self, seq_group: SequenceGroup) -> List[int]:
+        request_id = seq_group.request_id
+        assert request_id in self.cross_cache_policies
+        seq_slot_mappings = self.cross_cache_policies[request_id].slot_mappings
+        return seq_slot_mappings  # type: ignore
+
     def access_all_blocks_in_seq(self, seq: Sequence, now: float):
         if self.enable_caching:
             # Record the latest access time for the sequence. The actual update
@@ -318,7 +322,7 @@ def get_common_computed_block_ids(
         """
         computed_seq_block_ids = []
         for seq in seqs:
-            all_blocks = self.block_tables[seq.seq_id].physical_block_ids
+            all_blocks = self.cache_policies[seq.seq_id].physical_block_ids
             num_cached_tokens = (
                 self._computed_blocks_tracker.get_num_cached_tokens(seq))
             assert num_cached_tokens % self.block_size == 0
@@ -331,11 +335,11 @@ def get_common_computed_block_ids(
             computed_seq_block_ids)  # type: ignore
     def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
-        if parent_seq.seq_id not in self.block_tables:
+        if parent_seq.seq_id not in self.cache_policies:
             # Parent sequence has either been freed or never existed.
             return
-        src_block_table = self.block_tables[parent_seq.seq_id]
-        self.block_tables[child_seq.seq_id] = src_block_table.fork()
+        src_cache_policy = self.cache_policies[parent_seq.seq_id]
+        self.cache_policies[child_seq.seq_id] = src_cache_policy.fork()
         # Track child seq
         self._last_access_blocks_tracker.add_seq(child_seq.seq_id)
@@ -369,7 +373,7 @@ def swap_in(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
         """
         physical_block_id_mapping = []
         for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
-            blocks = self.block_tables[seq.seq_id].blocks
+            blocks = self.cache_policies[seq.seq_id].blocks
             if len(blocks) == 0:
                 continue
@@ -378,7 +382,7 @@ def swap_in(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
                 dst_device=Device.GPU)
             # Refresh the block ids of the table (post-swap)
-            self.block_tables[seq.seq_id].update(blocks)
+            self.cache_policies[seq.seq_id].update_blocks(blocks)
             seq_physical_block_id_mapping = {
                 self.block_allocator.get_physical_block_id(
@@ -422,7 +426,7 @@ def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
         """
         physical_block_id_mapping = []
         for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
-            blocks = self.block_tables[seq.seq_id].blocks
+            blocks = self.cache_policies[seq.seq_id].blocks
             if len(blocks) == 0:
                 continue
@@ -431,7 +435,7 @@ def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
                 dst_device=Device.CPU)
             # Refresh the block ids of the table (post-swap)
-            self.block_tables[seq.seq_id].update(blocks)
+            self.cache_policies[seq.seq_id].update_blocks(blocks)
             seq_physical_block_id_mapping = {
                 self.block_allocator.get_physical_block_id(
@@ -481,16 +485,16 @@ def _can_swap(self,
         num_blocks_touched = 0
         blocks: List[Block] = []
         for seq in seq_group.get_seqs(status=status):
-            block_table = self.block_tables[seq.seq_id]
-            if block_table.blocks is not None:
+            cache_policy = self.cache_policies[seq.seq_id]
+            if cache_policy.blocks is not None:
                 # Compute the number blocks to touch for the tokens to be
                 # appended. This does NOT include the full blocks that need
                 # to be touched for the swap.
                 num_blocks_touched += \
-                    block_table.get_num_blocks_touched_by_append_slots(
-                        block_table.get_unseen_token_ids(seq.get_token_ids()),
+                    cache_policy.get_num_blocks_touched_by_append_slots(
+                        cache_policy.get_unseen_token_ids(seq.get_token_ids()),
                         num_lookahead_slots=num_lookahead_slots)
-                blocks.extend(block_table.blocks)
+                blocks.extend(cache_policy.blocks)
         # Compute the number of full blocks to touch and add it to the
         # existing count of blocks to touch.
         num_blocks_touched += self.block_allocator.get_num_full_blocks_touched(
diff --git a/vllm/core/interfaces.py b/vllm/core/interfaces.py
index b10b8d3f4a5bf..50a14cf1fbfb8 100644
--- a/vllm/core/interfaces.py
+++ b/vllm/core/interfaces.py
@@ -50,12 +50,12 @@ def allocate(self, seq_group: SequenceGroup) -> None:
         pass
     @abstractmethod
-    def can_append_slots(self, seq_group: SequenceGroup,
-                         num_lookahead_slots: int) -> bool:
+    def can_add_slots(self, seq_group: SequenceGroup,
+                      num_lookahead_slots: int) -> bool:
         pass
     @abstractmethod
-    def append_slots(
+    def add_slots(
         self,
         seq: Sequence,
         num_lookahead_slots: int,
diff --git a/vllm/core/placeholder_block_space_manager.py b/vllm/core/placeholder_block_space_manager.py
index a47e594518534..dcfae756c5ce5 100644
--- a/vllm/core/placeholder_block_space_manager.py
+++ b/vllm/core/placeholder_block_space_manager.py
@@ -32,11 +32,11 @@ def allocate(self, seq_group: SequenceGroup) -> None:
         # No actual allocation logic needed
         pass
-    def can_append_slots(self, seq_group: SequenceGroup,
-                         num_lookahead_slots: int) -> bool:
+    def can_add_slots(self, seq_group: SequenceGroup,
+                      num_lookahead_slots: int) -> bool:
         return True
-    def append_slots(
+    def add_slots(
         self,
         seq: Sequence,
         num_lookahead_slots: int,
diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py
index b3d396f9cedda..03b557ec107e6 100644
--- a/vllm/core/scheduler.py
+++ b/vllm/core/scheduler.py
@@ -298,7 +298,8 @@ def seq_group_metadata_builder():
                              is_prompt=False,
                              seq_data={},
                              sampling_params=None,
-                             block_tables={})
+                             block_tables={},
+                             slot_mappings={})
 def scheduler_running_outputs_builder():
@@ -595,7 +596,7 @@ def _schedule_running(
             # NOTE(woosuk): Preemption happens only when there is no available
             # slot to keep all the sequence groups in the RUNNING state.
-            while not self._can_append_slots(seq_group, enable_chunking):
+            while not self._can_add_slots(seq_group, enable_chunking):
                 budget.subtract_num_batched_tokens(seq_group.request_id,
                                                    num_running_tokens)
                 num_running_seqs = seq_group.get_max_num_running_seqs()
@@ -645,7 +646,7 @@ def _schedule_running(
                 if not cont_loop:
                     break
             else:
-                self._append_slots(seq_group, blocks_to_copy, enable_chunking)
+                self._add_slots(seq_group, blocks_to_copy, enable_chunking)
                 is_prefill = seq_group.is_prefill()
                 scheduled_seq_group: ScheduledSequenceGroup = \
@@ -764,7 +765,7 @@ def _schedule_swapped(
                 curr_loras.add(lora_int_id)
             swapped_queue.popleft()
             self._swap_in(seq_group, blocks_to_swap_in)
-            self._append_slots(seq_group, blocks_to_copy, enable_chunking)
+            self._add_slots(seq_group, blocks_to_copy, enable_chunking)
             is_prefill = seq_group.is_prefill()
             if is_prefill:
                 prefill_seq_groups.append(
@@ -1008,7 +1009,7 @@ def _schedule_prefills(
                 if enable_chunking and self.scheduler_config.is_multi_step:
                     blocks_to_copy: List[Tuple[int, int]] = []
                     # init_multi_step_from_lookahead_slots happens in append_slots
-                    self._append_slots(seq_group, blocks_to_copy, enable_chunking)
+                    self._add_slots(seq_group, blocks_to_copy, enable_chunking)
                     # This assert will trip when a copy-on-write happens. This is
                     # not a concern as the very first sequence-group block
                     # allocation happens above. Still, we have the assert to
@@ -1250,9 +1251,9 @@ def _schedule(self) -> SchedulerOutputs:
         else:
             return self._schedule_default()
-    def _can_append_slots(self, seq_group: SequenceGroup,
-                          enable_chunking: bool) -> bool:
-        """Determine whether or not we have enough space in the KV cache to
+    def _can_add_slots(self, seq_group: SequenceGroup,
+                       enable_chunking: bool) -> bool:
+        """Determine whether we have enough space in the KV cache to
         continue generation of the sequence group.
         """
         # It is True only for testing case to trigger artificial preemption.
@@ -1267,11 +1268,11 @@ def _can_append_slots(self, seq_group: SequenceGroup,
                                                 is_prefill, enable_chunking)
         if is_prefill and num_lookahead_slots > 0:
-            # Appending prefill slots only happens multi-step and
+            # Appending prefill slots only happens when multi-step and
             # chunked-prefill are enabled together.
             assert self.scheduler_config.is_multi_step and enable_chunking
-        return self.block_manager.can_append_slots(
+        return self.block_manager.can_add_slots(
             seq_group=seq_group, num_lookahead_slots=num_lookahead_slots)
     def _allow_async_output_proc(self, seq_group: SequenceGroup) -> bool:
@@ -1309,11 +1310,14 @@ def schedule(
                     self.cache_id].get_object()
                 seq_group_metadata.seq_data.clear()
                 seq_group_metadata.block_tables.clear()
+                seq_group_metadata.slot_mappings.clear()
             # seq_id -> SequenceData
             seq_data: Dict[int, SequenceData] = {}
             # seq_id -> physical block numbers
             block_tables: Dict[int, List[int]] = {}
+            # seq_id -> token to slot mappings
+            slot_mappings: Dict[int, List[int]] = {}
             if seq_group.is_encoder_decoder():
                 # Encoder associated with SequenceGroup
@@ -1324,14 +1328,18 @@ def schedule(
                 # Also managed at SequenceGroup level
                 cross_block_table = self.block_manager.get_cross_block_table(
                     seq_group)
+                cross_slot_mapping = self.block_manager.get_cross_slot_mapping(
+                    seq_group)
             else:
                 encoder_seq_data = None
                 cross_block_table = None
+                cross_slot_mapping = None
             for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                 seq_id = seq.seq_id
                 seq_data[seq_id] = seq.data
                 block_tables[seq_id] = self.block_manager.get_block_table(seq)
+                slot_mappings[seq_id] = self.block_manager.get_slot_mapping(seq)
                 self.block_manager.access_all_blocks_in_seq(seq, now)
             if self.cache_config.enable_prefix_caching:
@@ -1368,6 +1376,7 @@ def schedule(
                     seq_data=seq_data,
                     sampling_params=seq_group.sampling_params,
                     block_tables=block_tables,
+                    slot_mappings=slot_mappings,
                     do_sample=do_sample,
                     pooling_params=seq_group.pooling_params,
                     token_chunk_size=token_chunk_size,
@@ -1375,6 +1384,7 @@ def schedule(
                     computed_block_nums=common_computed_block_nums,
                     encoder_seq_data=encoder_seq_data,
                     cross_block_table=cross_block_table,
+                    cross_slot_mapping=cross_slot_mapping,
                     state=seq_group.state,
                     token_type_ids=seq_group.token_type_ids,
                     # `multi_modal_data` will only be present for the 1st comm
@@ -1398,6 +1408,7 @@ def schedule(
                     seq_data_delta,
                     seq_group.request_id,
                     block_tables,
+                    slot_mappings,
                     is_prompt,
                     do_sample=do_sample,
                     token_chunk_size=token_chunk_size,
@@ -1490,11 +1501,12 @@ def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None:
         for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
             seq.status = SequenceStatus.RUNNING
-    def _append_slots(self,
-                      seq_group: SequenceGroup,
-                      blocks_to_copy: List[Tuple[int, int]],
-                      enable_chunking: bool = False) -> None:
-        """Appends new slots to the sequences in the given sequence group.
+    def _add_slots(self,
+                   seq_group: SequenceGroup,
+                   blocks_to_copy: List[Tuple[int, int]],
+                   enable_chunking: bool = False) -> None:
+        """Add new slots to the sequences in the given sequence group; the
+        policy may append them to the blocks or insert them into existing ones.
         Args:
             seq_group (SequenceGroup): The sequence group containing the
@@ -1523,7 +1535,7 @@ def _append_slots,
             seq_status = None
         for seq in seq_group.get_seqs(status=seq_status):
-            cows = self.block_manager.append_slots(seq, num_lookahead_slots)
+            cows = self.block_manager.add_slots(seq, num_lookahead_slots)
             if len(cows) > 0:
                 blocks_to_copy.extend(cows)
diff --git a/vllm/sequence.py b/vllm/sequence.py
index 5857f656dfc10..8084b93b118d1 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -602,6 +602,9 @@ def get_num_computed_tokens(self) -> int:
     def is_prefill(self) -> bool:
         return self.data.stage == SequenceStage.PREFILL
+    def is_decode(self) -> bool:
+        return self.data.stage == SequenceStage.DECODE
+
     def __repr__(self) -> str:
         return (f"Sequence(seq_id={self.seq_id}, "
                 f"status={self.status.name}, "
@@ -868,6 +871,9 @@ def is_finished(self) -> bool:
     def is_prefill(self) -> bool:
         return self.first_seq.is_prefill()
+    def is_decode(self) -> bool:
+        return self.first_seq.is_decode()
+
     def __repr__(self) -> str:
         return (f"SequenceGroup(request_id={self.request_id}, "
                 f"sampling_params={self.sampling_params}, "
@@ -887,6 +893,7 @@ class SequenceGroupMetadataDelta(
     seq_data_delta: Dict[int, SequenceDataDelta]
     request_id: str
     block_tables: Dict[int, List[int]]
+    slot_mappings: Dict[int, List[int]]
     is_prompt: bool
     do_sample: bool = True
     token_chunk_size: Optional[int] = None
@@ -909,6 +916,8 @@ class SequenceGroupMetadata(
         sampling_params: The sampling parameters used to generate the
             outputs.
         block_tables: The block tables. (Seq id -> list of physical block
             numbers)
+        slot_mappings: The token-to-slot mappings. (Seq id -> list of slot
+            indices, one per token)
         do_sample: True if sampling is required. Sampling is not required
             when e.g., prefill is chunked, and the current iteration only
             computes query tokens for prefill, we don't need sampling.
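The `slot_mappings` field documented above carries one slot index per token for each sequence, using the same arithmetic the workers previously recomputed token by token (see the loop removed from enc_dec_model_runner.py below). A minimal standalone sketch of that relationship, with toy values that are not taken from this patch:

from typing import List

# Illustration only: how a per-token slot mapping follows from a block
# table. With block_size = 4 and physical blocks [7, 2], token i lives in
# slot block_table[i // block_size] * block_size + i % block_size.
def slots_for_tokens(block_table: List[int], num_tokens: int,
                     block_size: int) -> List[int]:
    return [
        block_table[i // block_size] * block_size + i % block_size
        for i in range(num_tokens)
    ]

# Tokens 0-3 fill block 7 (slots 28-31); tokens 4-5 start block 2 (slots 8-9).
assert slots_for_tokens([7, 2], 6, 4) == [28, 29, 30, 31, 8, 9]

Because the block manager now precomputes this list once per sequence, consumers can take a plain slice (e.g. `seq_cross_slot_mapping[0:seq_len]` below) instead of re-deriving slots token by token.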
@@ -937,6 +946,7 @@ class SequenceGroupMetadata(
     seq_data: Dict[int, SequenceData]
     sampling_params: Optional[SamplingParams]
     block_tables: Dict[int, List[int]]
+    slot_mappings: Dict[int, List[int]]
     do_sample: bool = True
     pooling_params: Optional[PoolingParams] = None
     lora_request: Optional[LoRARequest] = None
@@ -951,6 +961,7 @@ class SequenceGroupMetadata(
     mm_processor_kwargs: Optional[Dict[str, Any]] = None
     encoder_seq_data: Optional[SequenceData] = None
     cross_block_table: Optional[List[int]] = None
+    cross_slot_mapping: Optional[List[int]] = None
     prompt_adapter_request: Optional[PromptAdapterRequest] = None
     token_chunk_size: Optional[int] = None
@@ -1003,6 +1014,7 @@ def apply_delta(self,
             self.seq_data[id].apply_delta(delta)
         assert self.request_id == sequence_group_metadata_delta.request_id
         self.block_tables = sequence_group_metadata_delta.block_tables
+        self.slot_mappings = sequence_group_metadata_delta.slot_mappings
         self.token_chunk_size = sequence_group_metadata_delta.token_chunk_size
         self.do_sample = sequence_group_metadata_delta.do_sample
         self.is_prompt = sequence_group_metadata_delta.is_prompt
diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py
index 4d5d918087be8..4a783224cd4d3 100644
--- a/vllm/worker/enc_dec_model_runner.py
+++ b/vllm/worker/enc_dec_model_runner.py
@@ -312,8 +312,10 @@ def profile_run(self) -> None:
                 seq_data={group_id: decoder_dummy_data.seq_data},
                 sampling_params=sampling_params,
                 block_tables=None,
+                slot_mappings=None,
                 encoder_seq_data=encoder_dummy_data.seq_data,
                 cross_block_table=None,
+                cross_slot_mapping=None,
                 multi_modal_data=decoder_dummy_data.multi_modal_data
                 or encoder_dummy_data.multi_modal_data,
                 multi_modal_placeholders=decoder_dummy_data.
                 multi_modal_placeholders)
@@ -420,12 +422,8 @@ def _prepare_encoder_model_input_tensors(
                 # In embeddings, the block tables are {seq_id: None}.
                 cross_slot_mapping.extend([PAD_SLOT_ID] * seq_len)
             else:
-                for i in range(0, seq_len):
-                    block_number = seq_group_metadata.cross_block_table[
-                        i // self.block_size]
-                    block_offset = i % self.block_size
-                    slot = block_number * self.block_size + block_offset
-                    cross_slot_mapping.append(slot)
+                seq_cross_slot_mapping = seq_group_metadata.cross_slot_mapping
+                cross_slot_mapping.extend(seq_cross_slot_mapping[0:seq_len])
             # Build encoder input tokens
             encoder_input_tokens.extend(token_ids)
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 1c6d1bbee78ee..e7825b07f0527 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -214,6 +214,7 @@ def __init__(
         seq_ids: List[int],
         is_prompt: bool,
         block_tables: Optional[Dict[int, List[int]]],
+        slot_mappings: Optional[Dict[int, List[int]]],
         computed_block_nums: List[int],
         n_seqs: int = 0,
@@ -266,6 +267,7 @@ def __init__(
         self.request_id = request_id
         self.is_prompt = is_prompt
         self.block_tables = block_tables
+        self.slot_mappings = slot_mappings
         self.computed_block_nums = computed_block_nums
         self.n_seqs = n_seqs
         self.encoder_seq_len = encoder_seq_len
@@ -405,6 +407,7 @@ def gen_inter_data_builder(self, num_seqs: int):
                            seq_ids=[0] * num_seqs,
                            is_prompt=True,
                            block_tables=None,
+                           slot_mappings=None,
                            computed_block_nums=[])
     def init_cached_inter_data(self, *args, **kwargs):
@@ -599,14 +602,10 @@ def _compute_for_sliding_window(self, inter_data: InterDataForSeqGroup,
         if not inter_data.is_prompt and self.sliding_window is not None:
             # TODO(sang): This is a hack to make sliding window work with
             # paged attn. We can remove it if we make paged attn kernel
-            # to properly handle slinding window attn.
+            # to properly handle sliding window attn.
             curr_sliding_window_block = self.sliding_window_blocks
-            # number of elements in last block
-            suff_len = inter_data.seq_lens[seq_idx] % self.block_size
             sliding_seq_len = min(inter_data.seq_lens[seq_idx],
-                                  self.block_aligned_sliding_window + suff_len)
-            if suff_len > 0:
-                curr_sliding_window_block += 1
+                                  self.block_aligned_sliding_window)
             inter_data.curr_sliding_window_blocks[
                 seq_idx] = curr_sliding_window_block
@@ -738,6 +737,7 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
                 seq_ids=seq_ids,
                 is_prompt=is_prompt,
                 block_tables=seq_group_metadata.block_tables,
+                slot_mappings=seq_group_metadata.slot_mappings,
                 computed_block_nums=seq_group_metadata.computed_block_nums,
                 reinit=True,
                 reinit_use_defaults=True,
@@ -1299,6 +1299,7 @@ def profile_run(self) -> None:
                 seq_data={group_id: dummy_data.seq_data},
                 sampling_params=sampling_params,
                 block_tables=None,
+                slot_mappings=None,
                 lora_request=dummy_lora_requests_per_seq[group_id]
                 if dummy_lora_requests_per_seq else None,
                 multi_modal_data=dummy_data.multi_modal_data,
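As a closing note on the sizing change in vllm/core/block_manager.py above: `num_sliding_window_blocks` is a plain ceiling division. A standalone sanity check with toy numbers, not part of the patch:

def num_sliding_window_blocks(sliding_window: int, block_size: int) -> int:
    # Ceiling division: the fewest blocks that can hold `sliding_window`
    # tokens.
    return (sliding_window + block_size - 1) // block_size

assert num_sliding_window_blocks(8, 4) == 2  # window is exactly two blocks
assert num_sliding_window_blocks(9, 4) == 3  # one token spills into a third
assert num_sliding_window_blocks(3, 4) == 1  # window fits inside one block

For comparison, the removed code reserved `sliding_window // block_size + 2` blocks (e.g. 2 blocks for sliding_window=3, block_size=4, per its own comment), so the new bound is deliberately tighter.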