diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py
index daff76051a956..9c25b2cc2ba97 100644
--- a/vllm/attention/backends/flashinfer.py
+++ b/vllm/attention/backends/flashinfer.py
@@ -20,6 +20,7 @@
 from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
                                            compute_slot_mapping_start_idx,
                                            is_block_tables_empty)
+from vllm.attention.ops.paged_attn import PagedAttention
 from vllm.sequence import SequenceGroupMetadata
 from vllm.utils import get_kv_cache_torch_dtype, make_tensor_with_pad
 
@@ -61,14 +62,14 @@ def swap_blocks(
         dst_kv_cache: torch.Tensor,
         src_to_dst: torch.Tensor,
     ) -> None:
-        raise NotImplementedError
+        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
 
     @staticmethod
     def copy_blocks(
         kv_caches: List[torch.Tensor],
         src_to_dists: torch.Tensor,
     ) -> None:
-        raise NotImplementedError
+        PagedAttention.copy_blocks(kv_caches, src_to_dists)
 
     @staticmethod
     def get_supported_head_sizes() -> List[int]:
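
For context, the two methods wired up here give the FlashInfer backend KV-cache block movement by delegating to the PagedAttention ops. The sketch below is a plain-PyTorch illustration of the intended semantics only: swap_blocks copies whole blocks from a source cache into a destination cache according to a (src_block, dst_block) index tensor, and copy_blocks duplicates blocks within every layer's cache. The [num_blocks, block_size, head_size] cache layout, the [num_pairs, 2] mapping shape, and the *_reference function names are assumptions for illustration, not vLLM's actual kernels or layout.

# Reference sketch of the block-movement semantics delegated to
# PagedAttention in the diff above. Layouts and names here are assumed
# for illustration; vLLM's real ops run device-side on its paged layout.
import torch
from typing import List


def swap_blocks_reference(src_kv_cache: torch.Tensor,
                          dst_kv_cache: torch.Tensor,
                          src_to_dst: torch.Tensor) -> None:
    # src_to_dst: [num_pairs, 2] tensor of (src_block, dst_block) indices.
    for src_block, dst_block in src_to_dst.tolist():
        dst_kv_cache[dst_block].copy_(src_kv_cache[src_block])


def copy_blocks_reference(kv_caches: List[torch.Tensor],
                          src_to_dists: torch.Tensor) -> None:
    # src_to_dists: [num_pairs, 2] tensor of (src_block, dst_block) indices,
    # applied to every layer's cache in place.
    for kv_cache in kv_caches:
        for src_block, dst_block in src_to_dists.tolist():
            kv_cache[dst_block].copy_(kv_cache[src_block])


if __name__ == "__main__":
    # Toy caches: 4 blocks, block_size=2, head_size=3.
    src = torch.arange(24, dtype=torch.float32).reshape(4, 2, 3)
    dst = torch.zeros_like(src)
    swap_blocks_reference(src, dst, torch.tensor([[0, 2], [3, 1]]))
    assert torch.equal(dst[2], src[0]) and torch.equal(dst[1], src[3])

    caches = [torch.arange(24, dtype=torch.float32).reshape(4, 2, 3)]
    copy_blocks_reference(caches, torch.tensor([[0, 3]]))
    assert torch.equal(caches[0][3], caches[0][0])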