Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[torch.compile] Hide KV cache behind torch.compile boundary #11677

Merged
merged 22 commits into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions tests/kernels/test_encoder_decoder_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,12 +142,18 @@ class that Attention will automatically select when it is constructed.
torch.tensor([], dtype=torch.float32, device=CUDA_DEVICE))

# Construct KV cache
kv_cache = make_kv_cache(test_pt.num_blocks,
test_pt.num_heads,
test_pt.head_size,
test_pt.block_size,
device=CUDA_DEVICE,
backend=test_pt.backend_name)
if test_pt.attn_type in (AttentionType.DECODER,
AttentionType.ENCODER_DECODER):
kv_cache = make_kv_cache(test_pt.num_blocks,
test_pt.num_heads,
test_pt.head_size,
test_pt.block_size,
device=CUDA_DEVICE,
backend=test_pt.backend_name)
else:
kv_cache = torch.tensor([])

attn.kv_cache = [kv_cache]
return TestResources(scale, attn, kv_cache)


Expand Down
85 changes: 83 additions & 2 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
import torch
from vllm_test_utils import monitor

from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.utils import (FlexibleArgumentParser, PlaceholderModule,
StoreBoolean, deprecate_kwargs, get_open_port,
memory_profiling, merge_async_iterators, supports_kw)
StoreBoolean, bind_kv_cache, deprecate_kwargs,
get_open_port, memory_profiling, merge_async_iterators,
supports_kw)

from .utils import error_on_warning, fork_new_process_for_each_test

Expand Down Expand Up @@ -325,6 +327,85 @@ def measure_current_non_torch():
lib.cudaFree(handle2)


def test_bind_kv_cache():
    """Check that plain decoder layers get their cache tensors in order.

    Four attention layers named ``layers.0`` .. ``layers.3`` are registered
    and a single virtual engine with four cache tensors is bound; layer ``i``
    must end up holding exactly the ``i``-th tensor object.
    """
    from vllm.attention import Attention

    layer_names = (
        'layers.0.self_attn',
        'layers.1.self_attn',
        'layers.2.self_attn',
        'layers.3.self_attn',
    )
    ctx = {name: Attention(32, 128, 0.1) for name in layer_names}
    kv_cache = [torch.zeros((1, )) for _ in layer_names]

    # One virtual engine, hence the extra list level.
    bind_kv_cache(ctx, [kv_cache])

    # Identity (not equality) matters: the layer must reference the very
    # tensor that was passed in.
    for name, expected in zip(layer_names, kv_cache):
        assert ctx[name].kv_cache[0] is expected

def test_bind_kv_cache_non_attention():
    """Check binding when layer indices are sparse and do not start at zero.

    The registered attention layers have indices 20 and 28, yet only two
    cache tensors are supplied; they must be assigned by rank of the layer
    index, not by the raw index.
    """
    from vllm.attention import Attention

    # example from Jamba PP=2
    layer_names = ('model.layers.20.attn', 'model.layers.28.attn')
    ctx = {name: Attention(32, 128, 0.1) for name in layer_names}
    kv_cache = [torch.zeros((1, )) for _ in layer_names]

    bind_kv_cache(ctx, [kv_cache])

    for position, name in enumerate(layer_names):
        assert ctx[name].kv_cache[0] is kv_cache[position]


def test_bind_kv_cache_encoder_decoder():
    """Check binding for an encoder-decoder layout.

    The encoder self-attention layer must keep its original ``kv_cache``
    object untouched, while the decoder self-attention and the
    encoder-decoder attention of the same layer index must both reference
    the single supplied cache tensor.
    """
    from vllm.attention import Attention, AttentionType

    # example from bart
    encoder_self = 'encoder.layers.0.self_attn.attn'
    cross_attn = 'decoder.layers.0.encoder_attn.attn'
    decoder_self = 'decoder.layers.0.self_attn.attn'
    ctx = {
        encoder_self:
        Attention(32, 128, 0.1, attn_type=AttentionType.ENCODER),
        cross_attn:
        Attention(32, 128, 0.1, attn_type=AttentionType.ENCODER_DECODER),
        decoder_self:
        Attention(32, 128, 0.1, attn_type=AttentionType.DECODER),
    }

    shared_tensor = torch.zeros((1, ))
    kv_cache = [shared_tensor]
    # Remember the encoder's list object so we can prove it was not replaced.
    encoder_kv_cache = ctx[encoder_self].kv_cache

    bind_kv_cache(ctx, [kv_cache])

    assert ctx[encoder_self].kv_cache is encoder_kv_cache
    for name in (cross_attn, decoder_self):
        assert ctx[name].kv_cache[0] is shared_tensor


def test_bind_kv_cache_pp():
    """Check binding with two virtual engines (pipeline_parallel_size=2).

    The attention layer is built while a config with
    ``pipeline_parallel_size=2`` is current, and ``bind_kv_cache`` receives
    one cache list per virtual engine. Slot ``ve`` of the layer's
    ``kv_cache`` must reference the tensor supplied for engine ``ve``.
    """
    parallel_config = ParallelConfig(pipeline_parallel_size=2)
    cfg = VllmConfig(parallel_config=parallel_config)
    with set_current_vllm_config(cfg):
        from vllm.attention import Attention

        attn_layer = Attention(32, 128, 0.1)

    ctx = {'layers.0.self_attn': attn_layer}
    # One single-layer cache list per virtual engine.
    kv_cache = [[torch.zeros((1, ))] for _ in range(2)]

    bind_kv_cache(ctx, kv_cache)

    for ve, ve_kv_cache in enumerate(kv_cache):
        assert ctx['layers.0.self_attn'].kv_cache[ve] is ve_kv_cache[0]


def test_placeholder_module_error_handling():
placeholder = PlaceholderModule("placeholder_1234")

Expand Down
3 changes: 3 additions & 0 deletions tests/v1/engine/test_engine_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest
from transformers import AutoTokenizer

from tests.utils import fork_new_process_for_each_test
from vllm import SamplingParams
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
Expand Down Expand Up @@ -36,6 +37,7 @@ def make_request() -> EngineCoreRequest:
)


@fork_new_process_for_each_test
def test_engine_core(monkeypatch):

with monkeypatch.context() as m:
Expand Down Expand Up @@ -138,6 +140,7 @@ def test_engine_core(monkeypatch):
assert len(engine_core.scheduler.running) == 0


@fork_new_process_for_each_test
def test_engine_core_advanced_sampling(monkeypatch):
"""
A basic end-to-end test to verify that the engine functions correctly
Expand Down
3 changes: 3 additions & 0 deletions tests/v1/engine/test_engine_core_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest
from transformers import AutoTokenizer

from tests.utils import fork_new_process_for_each_test
from vllm import SamplingParams
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
Expand Down Expand Up @@ -75,6 +76,7 @@ async def loop_until_done_async(client: EngineCoreClient, outputs: Dict):
break


@fork_new_process_for_each_test
@pytest.mark.parametrize("multiprocessing_mode", [True, False])
def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):

Expand Down Expand Up @@ -143,6 +145,7 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):
client.abort_requests([request.request_id])


@fork_new_process_for_each_test
@pytest.mark.asyncio
async def test_engine_core_client_asyncio(monkeypatch):

Expand Down
29 changes: 17 additions & 12 deletions vllm/attention/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,13 @@ def __init__(
compilation_config.static_forward_context[prefix] = self
self.layer_name = prefix
self.attn_type = attn_type
# use a placeholder kv cache tensor during init, which will be replaced
# by bind_kv_cache
# this variable will not be accessed if use_direct_call is True
self.kv_cache = [
torch.tensor([]) for _ in range(get_current_vllm_config(
).parallel_config.pipeline_parallel_size)
]

def forward(
self,
Expand Down Expand Up @@ -148,11 +155,11 @@ def forward(
if value is not None:
value = value.view(-1, self.num_kv_heads, self.head_size)
torch.ops.vllm.unified_attention_with_output(
query, key, value, output, kv_cache, self.layer_name)
query, key, value, output, self.layer_name)
return output.view(-1, hidden_size)
else:
return torch.ops.vllm.unified_attention(query, key, value,
kv_cache, self.layer_name)
self.layer_name)

def extra_repr(self) -> str:
s = f"head_size={self.impl.head_size}" # type: ignore
Expand Down Expand Up @@ -230,12 +237,12 @@ def unified_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
layer_name: str,
) -> torch.Tensor:
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.dynamic_forward_context
self = forward_context.static_forward_context[layer_name]
attn_metadata = forward_context.attn_metadata
self = forward_context.attn_layers[layer_name]
kv_cache = self.kv_cache[forward_context.virtual_engine]
return self.impl.forward(query, key, value, kv_cache, attn_metadata,
self._k_scale, self._v_scale)

Expand All @@ -244,7 +251,6 @@ def unified_attention_fake(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
layer_name: str,
) -> torch.Tensor:
return torch.empty_like(query).contiguous()
Expand All @@ -253,7 +259,7 @@ def unified_attention_fake(
direct_register_custom_op(
op_name="unified_attention",
op_func=unified_attention,
mutates_args=["kv_cache"],
mutates_args=[],
fake_impl=unified_attention_fake,
dispatch_key=current_platform.dispatch_key,
)
Expand All @@ -264,12 +270,12 @@ def unified_attention_with_output(
key: torch.Tensor,
value: torch.Tensor,
output: torch.Tensor,
kv_cache: torch.Tensor,
layer_name: str,
) -> None:
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.dynamic_forward_context
self = forward_context.static_forward_context[layer_name]
attn_metadata = forward_context.attn_metadata
self = forward_context.attn_layers[layer_name]
kv_cache = self.kv_cache[forward_context.virtual_engine]
self.impl.forward(query,
key,
value,
Expand All @@ -285,7 +291,6 @@ def unified_attention_with_output_fake(
key: torch.Tensor,
value: torch.Tensor,
output: torch.Tensor,
kv_cache: torch.Tensor,
layer_name: str,
) -> None:
return
Expand All @@ -294,7 +299,7 @@ def unified_attention_with_output_fake(
direct_register_custom_op(
op_name="unified_attention_with_output",
op_func=unified_attention_with_output,
mutates_args=["kv_cache", "output"],
mutates_args=["output"],
fake_impl=unified_attention_with_output_fake,
dispatch_key=current_platform.dispatch_key,
)
1 change: 0 additions & 1 deletion vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2780,7 +2780,6 @@ def model_post_init(self, __context: Any) -> None:
compilation_time: float = PrivateAttr

# Per-model forward context
# Mainly used to store attention cls
# Map from layer name to the attention cls
static_forward_context: Dict[str, Any] = PrivateAttr

Expand Down
33 changes: 20 additions & 13 deletions vllm/forward_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,17 @@
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Dict, Optional
from typing import TYPE_CHECKING, Any, Dict, Optional

import torch

import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.logger import init_logger

if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata

logger = init_logger(__name__)

track_batchsize: bool = envs.VLLM_LOG_BATCHSIZE_INTERVAL >= 0
Expand All @@ -21,9 +24,12 @@

@dataclass
class ForwardContext:
static_forward_context: Dict[str, Any]
# copy from vllm_config.compilation_config.static_forward_context
attn_layers: Dict[str, Any]
# TODO: extend to support per-layer dynamic forward context
dynamic_forward_context: Any
attn_metadata: "AttentionMetadata" # set dynamically for each forward pass
# TODO: remove after making all virtual_engines share the same kv cache
virtual_engine: int # set dynamically for each forward pass


_forward_context: Optional[ForwardContext] = None
Expand All @@ -38,34 +44,35 @@ def get_forward_context() -> ForwardContext:


@contextmanager
def set_forward_context(context: Any, vllm_config: VllmConfig):
def set_forward_context(attn_metadata: Any,
vllm_config: VllmConfig,
virtual_engine: int = 0):
"""A context manager that stores the current forward context,
can be attention metadata, etc.
Here we can inject common logic for every model forward pass.
"""
global forward_start_time
need_to_track_batchsize = track_batchsize and context is not None
need_to_track_batchsize = track_batchsize and attn_metadata is not None
if need_to_track_batchsize:
forward_start_time = time.perf_counter()
global _forward_context
prev_context = _forward_context
_forward_context = ForwardContext(
static_forward_context=vllm_config.compilation_config.
static_forward_context,
dynamic_forward_context=context)
attn_layers=vllm_config.compilation_config.static_forward_context,
virtual_engine=virtual_engine,
attn_metadata=attn_metadata)
try:
yield
finally:
global batchsize_counter
global last_logging_time, batchsize_logging_interval
if need_to_track_batchsize:
if hasattr(context, "num_prefill_tokens"):
if hasattr(attn_metadata, "num_prefill_tokens"):
# for v0 attention backends
batchsize = context.num_prefill_tokens + \
context.num_decode_tokens
batchsize = attn_metadata.num_prefill_tokens + \
attn_metadata.num_decode_tokens
else:
# for v1 attention backends
batchsize = context.num_input_tokens
batchsize = attn_metadata.num_input_tokens
# we use synchronous scheduling right now,
# adding a sync point here should not affect
# scheduling of the next batch
Expand Down
35 changes: 35 additions & 0 deletions vllm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2138,3 +2138,38 @@ def get_mp_context():
_check_multiproc_method()
mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD
return multiprocessing.get_context(mp_method)


def bind_kv_cache(
    ctx: Dict[str, Any],
    kv_cache: List[List[torch.Tensor]],  # [virtual_engine][layer_index]
) -> None:
    """Bind KV cache tensors to the attention layers registered in ``ctx``.

    Conceptually this performs, for every layer that needs a cache::

        ctx[layer_name].kv_cache[ve] = kv_cache[ve][slot(layer_name)]

    where ``slot`` is the rank of the layer's index among all distinct layer
    indices that need a cache (so sparse or offset indices still map onto a
    dense list of tensors).

    Special things handled here:
      1. Some models have non-attention layers, e.g., Jamba.
      2. Pipeline parallelism, each rank only has a subset of layers.
      3. Encoder attention has no kv cache.
      4. Encoder-decoder models: encoder-decoder attention and decoder-only
         attention of the same layer (e.g., bart's
         ``decoder.layers.1.self_attn`` and ``decoder.layers.1.encoder_attn``)
         are mapped to the same kv cache tensor.

    Args:
        ctx: Mapping from layer name to a layer object exposing ``attn_type``
            and a per-virtual-engine ``kv_cache`` list. Mutated in place.
        kv_cache: Cache tensors indexed as ``[virtual_engine][slot]``.

    Raises:
        AssertionError: If a layer's ``kv_cache`` list length differs from the
            number of virtual engines, or if a slot being bound does not hold
            an empty (``numel() == 0``) tensor.
    """
    from vllm.attention import AttentionType
    from vllm.model_executor.models.utils import extract_layer_index

    cached_attn_types = (AttentionType.DECODER,
                         AttentionType.ENCODER_DECODER)

    # Resolve each cache-bearing layer's index exactly once, preserving the
    # iteration order of ``ctx``.
    layer_index_by_name = {
        layer_name: extract_layer_index(layer_name)
        for layer_name, layer in ctx.items()
        if layer.attn_type in cached_attn_types
    }

    # Rank the distinct layer indices so that a lookup is O(1) instead of a
    # linear ``list.index`` scan per layer.
    slot_by_layer_index = {
        layer_index: slot
        for slot, layer_index in enumerate(
            sorted(set(layer_index_by_name.values())))
    }

    for layer_name, layer_index in layer_index_by_name.items():
        kv_cache_idx = slot_by_layer_index[layer_index]
        attn_layer = ctx[layer_name]
        assert len(attn_layer.kv_cache) == len(kv_cache)
        for ve, ve_kv_cache in enumerate(kv_cache):
            # Only an empty tensor may be overwritten.
            assert attn_layer.kv_cache[ve].numel() == 0
            attn_layer.kv_cache[ve] = ve_kv_cache[kv_cache_idx]
Loading
Loading