Skip to content

Commit

Permalink
[Bugfix] Move the _touch(computed_blocks) call in the allocate_slots …
Browse files Browse the repository at this point in the history
…method to after the check for allocating new blocks.

Signed-off-by: Qianjun Zhou <[email protected]>
  • Loading branch information
sakunkun committed Dec 27, 2024
1 parent 6c6f7fe commit febcdf7
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions vllm/v1/core/kv_cache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,11 @@ def allocate_slots(
raise ValueError(
f"num_tokens must be greater than 0, got {num_tokens}")

num_required_blocks = cdiv(num_tokens, self.block_size)
if (num_required_blocks > self.free_block_queue.num_free_blocks):
# Cannot allocate new blocks.
return None

# Touch the computed blocks to make sure they won't be evicted.
if self.enable_caching:
self._touch(computed_blocks)
Expand All @@ -208,11 +213,6 @@ def allocate_slots(
"Computed blocks should be empty when "
"prefix caching is disabled")

num_required_blocks = cdiv(num_tokens, self.block_size)
if (num_required_blocks > self.free_block_queue.num_free_blocks):
# Cannot allocate new blocks.
return None

# Determine the number of new blocks to allocate considering
# preallocated blocks.
num_new_blocks = min(
Expand Down

0 comments on commit febcdf7

Please sign in to comment.