From 2382c2400d8e4b531725caadd1a212e0f9981bb7 Mon Sep 17 00:00:00 2001 From: Lily Liu Date: Wed, 7 Aug 2024 16:26:52 -0700 Subject: [PATCH] [Kernel] Fix Flashinfer Correctness (#7284) --- vllm/attention/backends/flashinfer.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 03188164a9637..64b1f4a89f23c 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -127,6 +127,7 @@ def __post_init__(self): raise ValueError( f"Only {supported_head_sizes} are supported for head_dim,", f"received {self.head_dim}.") + self.is_profile_run = is_block_tables_empty(self.block_tables) def begin_forward(self): if self.num_prefill_tokens > 0: @@ -140,11 +141,14 @@ def begin_forward(self): assert self.paged_kv_last_page_len is not None batch_size = self.query_start_loc.shape[0] - 1 assert batch_size >= 0 - # The prefill stage does not read kv cache. + # The profile run does not read kv cache. # Both paged_kv_indices and paged_kv_last_page_len are empty. # paged_kv_indptr is a zero tensor with size batch_size + 1. - self.paged_kv_indptr = torch.zeros(batch_size + 1, - device=self.device) + if self.is_profile_run: + self.paged_kv_indptr = torch.zeros(batch_size + 1, + device=self.device) + else: + self.paged_kv_indptr = self.paged_kv_indptr.to(self.device) self.paged_kv_last_page_len = self.paged_kv_last_page_len.to( self.device) self.paged_kv_indices = self.paged_kv_indices.to(self.device)