[Misc] Add attention sinks #3515

Status: Draft. Wants to merge 90 commits into base: main.

Changes shown below are from 25 of the 90 commits.

Commits
7914879
temp
felixzhu555 Mar 15, 2024
5b672d9
wip
felixzhu555 Mar 18, 2024
b35d7ba
wip
felixzhu555 Mar 18, 2024
e90cb58
wip
felixzhu555 Mar 19, 2024
831f18b
wip
felixzhu555 Mar 19, 2024
c8d86e6
change q pos
felixzhu555 Mar 21, 2024
0bd7566
evict
felixzhu555 Mar 21, 2024
f0263a4
edit xformers
felixzhu555 Mar 31, 2024
15b68ca
wip
Mar 31, 2024
9fe1895
wip
Apr 1, 2024
595638d
wip
felixzhu555 Apr 1, 2024
217743d
wip
felixzhu555 Apr 4, 2024
fd83c78
wip
felixzhu555 Apr 10, 2024
12e0e97
pull from main
felixzhu555 Apr 13, 2024
a9b094c
wip
felixzhu555 Apr 14, 2024
25e599d
cuda illegal memory access
felixzhu555 Apr 14, 2024
d14b94e
wip
Apr 17, 2024
8bb1840
cache current prerope key inside llama instead of xformers
felixzhu555 Apr 18, 2024
339305b
early eos
felixzhu555 Apr 18, 2024
1157cf3
fix small bugs
felixzhu555 Apr 18, 2024
0f0a414
wip
felixzhu555 Apr 21, 2024
6f01606
fix prefill
felixzhu555 Apr 22, 2024
740cbdb
wip
felixzhu555 Apr 23, 2024
15d586a
starting to work!
felixzhu555 Apr 24, 2024
c4a50b4
blockwise speedup
felixzhu555 Apr 24, 2024
455c814
wip
felixzhu555 Apr 25, 2024
ee12294
start removing loop over blocks
felixzhu555 May 15, 2024
d29b559
wip
felixzhu555 May 15, 2024
94ebe4d
wip
felixzhu555 May 16, 2024
899a7b3
wip: after refactor, generation abruptly ends
felixzhu555 May 17, 2024
016a6c6
speedup to 4 tok/s done
felixzhu555 May 18, 2024
2186c13
cache phys_bnums, 3x speedup
felixzhu555 May 18, 2024
e7acfbe
move logic out of xformers.py
felixzhu555 May 19, 2024
18042c6
refactor into new layer
felixzhu555 May 19, 2024
d2af329
pull from main
felixzhu555 May 21, 2024
8fe15d4
wip
felixzhu555 May 21, 2024
3ae06f5
add use_attention_sinks args
felixzhu555 May 21, 2024
67c3bdf
investigating eviction issue
felixzhu555 May 23, 2024
e09296b
flash attn works
felixzhu555 May 24, 2024
8f152d5
remove eviction & some refactoring
felixzhu555 May 30, 2024
a766775
start mixtral
felixzhu555 Jun 1, 2024
1e44278
refactor, start alibi, try falcon/bloom
felixzhu555 Jun 2, 2024
7413279
tiny
felixzhu555 Jun 2, 2024
05d7aa9
add mpt
felixzhu555 Jun 3, 2024
19a90f6
alibi not working
felixzhu555 Jun 3, 2024
34df763
fix seq len bug -> eviction and alibi work
felixzhu555 Jun 4, 2024
afb754c
eviction moved to block manager -> rope works, alibi not yet
felixzhu555 Jun 5, 2024
d7db6e1
fix alibi bug
felixzhu555 Jun 5, 2024
3d0929c
beam search not supported
felixzhu555 Jun 5, 2024
13b48c4
small fix
felixzhu555 Jun 5, 2024
9475536
pull main
felixzhu555 Jun 5, 2024
88a77d3
refactor models
felixzhu555 Jun 5, 2024
b3cfffb
pull main
felixzhu555 Jun 6, 2024
3e229a0
small
felixzhu555 Jun 7, 2024
b834de8
tests wip
felixzhu555 Jun 8, 2024
56b448a
tests failing
felixzhu555 Jun 10, 2024
7d9723c
wip
felixzhu555 Jun 11, 2024
2f92168
test correctness done
felixzhu555 Jun 12, 2024
c8416a0
add attn backend to tests, add eviction test
felixzhu555 Jun 12, 2024
b31ae95
small
felixzhu555 Jun 12, 2024
f241532
start refactor
felixzhu555 Jun 19, 2024
5c7f802
add wrapper method
felixzhu555 Jun 19, 2024
143db31
wip
felixzhu555 Jun 19, 2024
0722ff0
pull main
felixzhu555 Jun 20, 2024
e0848e3
refactor wip
felixzhu555 Jun 20, 2024
7abb285
fix test
felixzhu555 Jun 21, 2024
5bf0d5c
small
felixzhu555 Jun 21, 2024
ae31b1d
chunked prefill wip
felixzhu555 Jun 21, 2024
779b2a3
wip
felixzhu555 Jun 22, 2024
d527920
cuda mem error
felixzhu555 Jun 22, 2024
87bd485
chunked prefill working
felixzhu555 Jun 23, 2024
0a1abf8
wip
felixzhu555 Jun 23, 2024
08fd48f
fix paxos paper
felixzhu555 Jun 25, 2024
65f5f6d
wip
felixzhu555 Jun 26, 2024
cb12d5f
Merge branch 'main' of https://github.com/vllm-project/vllm into add_…
felixzhu555 Jun 26, 2024
fdc1365
chunked prefill for alibi
felixzhu555 Jun 27, 2024
da75ff6
add some docstrings
felixzhu555 Jun 28, 2024
fa8a253
fix test
felixzhu555 Jun 28, 2024
1763a44
pull main
felixzhu555 Jun 29, 2024
ef65724
fix after removal of logical block table
felixzhu555 Jul 16, 2024
38bd15f
change pos arange
felixzhu555 Jul 17, 2024
b0b8d0b
pull main
felixzhu555 Jul 17, 2024
7de1a21
small
felixzhu555 Aug 4, 2024
1ecec38
small
felixzhu555 Aug 4, 2024
5f03373
pull main, breaking changes to be fixed
felixzhu555 Aug 4, 2024
2da86a8
fix updates from pull main
felixzhu555 Aug 4, 2024
71ca701
refactor forward: remove rem logic, move torch ops out of loop
felixzhu555 Aug 4, 2024
bce7902
fix flash_attn.py
felixzhu555 Aug 4, 2024
be779fb
fix tests
felixzhu555 Aug 4, 2024
9d97b8d
pull main
felixzhu555 Aug 5, 2024
76 changes: 76 additions & 0 deletions temp/chat.py
@@ -0,0 +1,76 @@
import json
from typing import List, Tuple
from vllm import EngineArgs, LLMEngine, SamplingParams, RequestOutput
from transformers import AutoTokenizer


MAX_GEN_TOKENS = 200


def get_chat_prompts(file_path="./mt_bench.jsonl") -> List[Tuple[str, SamplingParams]]:
list_data_dict = []
with open(file_path, "r") as f:
for line in f:
list_data_dict.append(json.loads(line))

prompts = []
for sample in list_data_dict:
prompts += sample["turns"]

return [(prompt, SamplingParams(max_tokens=MAX_GEN_TOKENS)) for prompt in prompts]


def get_long_prompt(file_path="./paxos_paper.txt") -> List[Tuple[str, SamplingParams]]:
# this file is 4060 tokens
with open(file_path, "r") as f:
prompt = f.read()

return [(prompt, SamplingParams(max_tokens=MAX_GEN_TOKENS))]


def process_requests(engine: LLMEngine,
test_prompts: List[Tuple[str, SamplingParams]],
tokenizer):
"""Continuously process a list of prompts and handle the outputs."""
request_id = 0

while test_prompts or engine.has_unfinished_requests():
if test_prompts:
prompt, sampling_params = test_prompts.pop(0)
engine.add_request(str(request_id), prompt, sampling_params)
request_id += 1

request_outputs: List[RequestOutput] = engine.step()

for request_output in request_outputs:
if request_output.finished:
# print("\nPROMPT:")
# print(request_output.prompt)

text = request_output.outputs[0].text
num_tokens = len(tokenizer.tokenize(text))
print(f"\nOUTPUT: ({num_tokens} tokens)")
print(text, "\n")
print(request_output.outputs)


def main():
# context length 4096
model = "lmsys/vicuna-7b-v1.5"
args = EngineArgs(
model=model,
enforce_eager=True,
max_model_len=4096,
block_size=16
)

engine = LLMEngine.from_engine_args(args)
print("max model len", engine.scheduler_config.max_model_len)
tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
# prompts = get_chat_prompts()
prompts = get_long_prompt()
process_requests(engine, prompts, tokenizer)


if __name__ == "__main__":
main()
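
This driver is what pushes the engine past its context window: the paxos_paper.txt prompt is about 4060 tokens and up to 200 more are generated, so the sequence exceeds the 4096-token max_model_len and exercises the sink/eviction path. Below is a minimal usage sketch; the use_attention_sinks flag name is only assumed from commit 3ae06f5 ("add use_attention_sinks args") and is not a released vLLM option.

# Hypothetical sketch: the use_attention_sinks flag is assumed, not part of released vLLM.
from vllm import EngineArgs, LLMEngine, SamplingParams

args = EngineArgs(
    model="lmsys/vicuna-7b-v1.5",
    enforce_eager=True,
    max_model_len=4096,
    block_size=16,
    use_attention_sinks=True,  # flag assumed to be added by this PR
)
engine = LLMEngine.from_engine_args(args)

# A ~4060-token prompt plus up to 200 generated tokens exceeds max_model_len,
# so decoding past 4096 tokens goes through the attention-sink path.
engine.add_request("0", open("./paxos_paper.txt").read(),
                   SamplingParams(max_tokens=200))
while engine.has_unfinished_requests():
    for out in engine.step():
        if out.finished:
            print(out.outputs[0].text)
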
79 changes: 79 additions & 0 deletions temp/mt_bench.jsonl

Large diffs are not rendered by default.

450 changes: 450 additions & 0 deletions temp/out.txt

Large diffs are not rendered by default.

259 changes: 259 additions & 0 deletions temp/paxos_paper.txt

Large diffs are not rendered by default.

30 changes: 30 additions & 0 deletions vllm/attention/backends/xformers.py
@@ -173,6 +173,7 @@ def forward(
kv_cache: Optional[torch.Tensor],
attn_metadata: AttentionMetadata[XFormersMetadata],
kv_scale: float,
key_original: Optional[torch.Tensor],
) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention.

@@ -189,6 +190,9 @@ def forward(
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)

use_attn_sinks = True
value_copy = value.clone() if use_attn_sinks else None

if kv_cache is not None:
key_cache, value_cache = PagedAttention.split_kv_cache(
@@ -229,6 +233,19 @@ def forward(
query, key, value, prefill_meta)
assert out.shape == output[:num_prefill_tokens].shape
output[:num_prefill_tokens] = out

# in prefill, rewrite all keys with original
if use_attn_sinks and kv_cache is not None and key_original is not None:
key_original = key_original.view(-1, self.num_kv_heads, self.head_size)
PagedAttention.write_to_paged_cache(
key_original,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
attn_metadata.kv_cache_dtype,
kv_scale
)
else:
# prefix-enabled attention
# TODO(Hai) this triton kernel has regression issue (broke) to
@@ -265,6 +282,19 @@ def forward(
kv_scale,
)

# attention sinks: revert cur key in cache to pre-rotated state
if use_attn_sinks and kv_cache is not None:
key_original = key_original.view(-1, self.num_kv_heads, self.head_size)
PagedAttention.write_to_paged_cache(
key_original,
value_copy,
key_cache,
value_cache,
attn_metadata.slot_mapping,
attn_metadata.kv_cache_dtype,
kv_scale
)

# Reshape the output tensor.
return output.view(-1, self.num_heads * self.head_size)

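The added write_to_paged_cache calls above write the un-rotated keys (key_original) back into the paged cache after attention, so later decode steps can re-rotate every cached key by its position inside the retained window rather than its absolute position (the StreamingLLM rolling-cache idea). A minimal standalone sketch of that idea, assuming NeoX-style RoPE, a plain Python list in place of vLLM's paged cache, and illustrative sink/window sizes:

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope(x: torch.Tensor, positions: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    # x: [num_tokens, num_heads, head_size]; positions: [num_tokens]
    head_size = x.shape[-1]
    inv_freq = 1.0 / (base ** (torch.arange(0, head_size, 2, dtype=torch.float32) / head_size))
    freqs = torch.outer(positions.float(), inv_freq)        # [num_tokens, head_size/2]
    emb = torch.cat((freqs, freqs), dim=-1)                 # [num_tokens, head_size]
    cos, sin = emb.cos()[:, None, :], emb.sin()[:, None, :]
    return x * cos + rotate_half(x) * sin

# Keys are stored UN-rotated; rotation happens per attention call using each key's
# slot in the retained cache (sink tokens keep slots 0..num_sinks-1 forever).
num_sinks, window, num_heads, head_size = 4, 16, 8, 64
key_cache = []                                              # pre-RoPE keys, one [num_heads, head_size] tensor per token

for step in range(40):                                      # simulate 40 decode steps
    key_cache.append(torch.randn(num_heads, head_size))     # new pre-RoPE key
    if len(key_cache) > num_sinks + window:
        del key_cache[num_sinks]                            # evict the oldest non-sink entry

    keys = torch.stack(key_cache)                           # [cache_len, num_heads, head_size]
    pos_in_cache = torch.arange(len(key_cache))             # positions by cache slot, not absolute index
    rotated_keys = apply_rope(keys, pos_in_cache)           # what attention would actually consume
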
3 changes: 2 additions & 1 deletion vllm/attention/layer.py
@@ -44,6 +44,7 @@ def forward(
kv_cache: Optional[torch.Tensor],
attn_metadata: AttentionMetadata[AttentionMetadataPerStage],
kv_scale: float = 1.0,
key_original: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return self.impl.forward(query, key, value, kv_cache, attn_metadata,
kv_scale)
kv_scale, key_original)
3 changes: 2 additions & 1 deletion vllm/core/block_manager_v1.py
@@ -116,7 +116,8 @@ def allocate(self,

def free(self, block: PhysicalTokenBlock) -> None:
if block.ref_count == 0:
raise ValueError(f"Double free! {block} is already freed.")
return
# raise ValueError(f"Double free! {block} is already freed.")
block.ref_count -= 1
if block.ref_count == 0:
assert block.block_hash not in self.evictor
11 changes: 11 additions & 0 deletions vllm/core/scheduler.py
@@ -909,6 +909,17 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
block_tables[seq_id] = self.block_manager.get_block_table(seq)
self.block_manager.access_all_blocks_in_seq(seq, now)

use_attn_sinks = False
max_context_len = self.scheduler_config.max_model_len
seq_len = seq.get_len()
block_size = 16 # where do we get this in Scheduler?
if use_attn_sinks and seq_len > max_context_len:
# 0th block is attention sink
block_idx_to_free = (seq_len - max_context_len - 1) // block_size + 1
block_to_free = self.block_manager.block_tables[seq_id][block_idx_to_free]
if block_to_free.ref_count > 0:
self.block_manager.gpu_allocator.free(block_to_free)

common_computed_block_nums = (
self.block_manager.get_common_computed_block_ids(
seq_group.get_seqs(status=SequenceStatus.RUNNING)))
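The eviction arithmetic above frees one block each time the sequence grows a block past the context window, always keeping block 0 as the attention sink. A quick worked example with the block_size of 16 hard-coded in the hunk and a max_model_len of 4096:

max_context_len = 4096
block_size = 16

# The first token past the window (seq_len = 4097) frees the block right after the sink:
seq_len = 4097
assert (seq_len - max_context_len - 1) // block_size + 1 == 1

# Sixteen tokens later (seq_len = 4113) the next block is freed:
seq_len = 4113
assert (seq_len - max_context_len - 1) // block_size + 1 == 2
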
11 changes: 8 additions & 3 deletions vllm/engine/llm_engine.py
@@ -720,6 +720,7 @@ def step(self) -> List[RequestOutput]:
else:
output = []

print("\t\tFREE BLOCKS", self.scheduler.block_manager.get_num_free_gpu_blocks())
return self._process_model_outputs(output, scheduler_outputs)

def do_log_stats(self) -> None:
@@ -839,10 +840,14 @@ def _check_stop(self, seq: Sequence, new_char_count: int,
seq.stop_reason = stop_str
return

'''FIXME
The max_model_len check below is commented out to temporarily bypass the
context-length cap; without attention sinks the output would degrade to gibberish.
'''
# Check if the sequence has reached max_model_len.
if seq.get_len() > self.scheduler_config.max_model_len:
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return
# if seq.get_len() > self.scheduler_config.max_model_len:
# seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
# return

# Check if the sequence has reached max_tokens.
if seq.get_output_len() == sampling_params.max_tokens:
33 changes: 33 additions & 0 deletions vllm/model_executor/layers/rotary_embedding.py
@@ -135,6 +135,39 @@ def _forward(
key = key.flatten(-2)
return query, key

def _forward_single(self, positions, x, offsets=None):
"""
Same thing as above, except only for either q or k.
PyTorch-native implementation equivalent to forward()."""
x = x.view(*x.shape[:-1], -1, self.head_size)

x_rot = x[..., :self.rotary_dim]
if self.rotary_dim < self.head_size:
x_pass = x[..., self.rotary_dim:]

self.cos_sin_cache = self.cos_sin_cache.to(positions.device)
cos_sin = self.cos_sin_cache[torch.add(positions, offsets)
if offsets is not None else positions]
cos, sin = cos_sin.chunk(2, dim=-1)
if self.is_neox_style:
# NOTE(woosuk): Here we assume that the positions tensor has the
# shape [batch_size, seq_len].
cos = cos.repeat(1, 1, 2).unsqueeze(-2)
sin = sin.repeat(1, 1, 2).unsqueeze(-2)
else:
cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)

rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
x_rot = x_rot * cos + rotate_fn(x_rot) * sin

if self.rotary_dim < self.head_size:
x = torch.cat((x_rot, x_pass), dim=-1)
else:
x = x_rot
x = x.flatten(-2).squeeze(0) # very sus!
return x

def forward(
self,
positions: torch.Tensor,
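_forward_single applies the same rotation as forward(), but to a single flattened tensor; the llama decode path below uses it to re-rotate keys pulled out of the paged cache at their in-window positions. A hedged usage sketch mirroring that call (the get_rope arguments here are illustrative, and the shapes follow the comments in the llama.py diff):

import torch
from vllm.model_executor.layers.rotary_embedding import get_rope

num_heads, head_size, block_size = 32, 128, 16
rope = get_rope(head_size=head_size, rotary_dim=head_size,
                max_position=4096, base=10000)

# One cached block of pre-RoPE keys, flattened to [num_tokens, num_heads * head_size].
past_key = torch.randn(block_size, num_heads * head_size)

# Re-rotate with the keys' positions inside the retained window,
# not their absolute positions in the sequence.
pos_in_cache = torch.arange(100, 100 + block_size)
rotated = rope._forward_single(pos_in_cache, past_key)
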
139 changes: 133 additions & 6 deletions vllm/model_executor/models/llama.py
@@ -47,7 +47,8 @@
hf_model_weights_iterator,
kv_cache_scales_loader)
from vllm.sequence import SamplerOutput
from vllm.utils import is_hip
from vllm.utils import is_hip, make_tensor_with_pad
from vllm.attention.ops.paged_attn import PagedAttention


class LlamaMLP(nn.Module):
@@ -163,11 +164,137 @@ def forward(
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata,
self.kv_scale)
output, _ = self.o_proj(attn_output)
return output
# q k v all have shape [num_tokens, num_heads * head_size] i.e. [1, 4096] for decode

use_attn_sinks = True
llama_context_len = 4096
norm = lambda x: torch.linalg.norm(x).item()

if not use_attn_sinks:
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata,
self.kv_scale)
output, _ = self.o_proj(attn_output)
return output

# what if metadata has both prefill and decode?
if attn_metadata.prefill_metadata is not None:
# for prefill, storing original keys happens in xformers.py
k_original = k.clone()
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata,
self.kv_scale, k_original)
output, _ = self.o_proj(attn_output)
return output

elif attn_metadata.decode_metadata is not None:
k_original = k.clone()
# streamingLLM: use pos in cache
positions = torch.clamp(positions, max=llama_context_len - 1)
q, k = self.rotary_emb(positions, q, k)

# key cache reshape: [num_blocks, num_heads, head_size/x, block_size, x]
key_cache, _ = PagedAttention.split_kv_cache(kv_cache, self.num_kv_heads, self.head_dim)
block_size = key_cache.shape[-2]

original_block_tables = attn_metadata.decode_metadata.block_tables.clone()
block_tables_tensor = attn_metadata.decode_metadata.block_tables
block_tables: List[List[int]] = []
context_lens = attn_metadata.decode_metadata.context_lens.tolist()

# batch size = num sequences
batch_size = block_tables_tensor.shape[0]
original_keys: List[dict] = []
for i in range(batch_size):
# see paged_attn.py line 19 for context_lens definition
num_past_tokens = context_lens[i] - 1
within_context_len = num_past_tokens < llama_context_len

past_keys = {} # logic bnum => block of keys
block_table = block_tables_tensor[i]
start = 0 if within_context_len else num_past_tokens - llama_context_len + 1 + block_size
end = num_past_tokens
abs_pos = start
# loop should have 4096 - 1 iterations if not within_context_len
while abs_pos < end:
logic_bnum = abs_pos // block_size
phys_bnum = block_table[logic_bnum]
offset_start = abs_pos % block_size
offset_end = block_size if end - abs_pos > block_size else end - abs_pos
num_tokens = offset_end - offset_start

# past_key shape: [num_heads, head_size/x, num_tokens, x]
past_key = key_cache[phys_bnum, :, :, offset_start : offset_end, :]
past_keys[logic_bnum] = past_key.clone()

# rotate k based on new relative pos
p = abs_pos if within_context_len else abs_pos - start + block_size
pos = [p + off for off in range(num_tokens)] # sus
pos = torch.tensor(pos, device=positions.device)

# sus reshapes
past_key = past_key.permute((2, 0, 1, 3)).reshape(num_tokens, -1)
past_key = self.rotary_emb._forward_single(pos, past_key)
past_key = past_key.reshape(num_tokens, key_cache.shape[1], key_cache.shape[2], key_cache.shape[4])
key_cache[phys_bnum, :, :, offset_start : offset_end, :] = past_key.permute((1, 2, 0, 3))

# rotary emb kernel has almost 2x speedup BUT it's incorrect for some reason

abs_pos += num_tokens

original_keys.append(past_keys)

if not within_context_len:
blocks_to_ignore = (num_past_tokens - llama_context_len) // block_size + 1
# block_table[0] is attention sink
capped_block_table = [block_table[0].item()] + block_table[blocks_to_ignore + 1:].tolist()
block_tables.append(capped_block_table)

if block_tables:
attn_metadata.decode_metadata.block_tables = make_tensor_with_pad(
block_tables,
max_len=llama_context_len // block_size,
pad=0,
dtype=torch.int,
device=original_block_tables.device
)
attn_metadata.decode_metadata.context_lens = torch.clamp(
attn_metadata.decode_metadata.context_lens,
max=llama_context_len
)

# compute attention in kernel
attn_output = self.attn(q, k, v, kv_cache, attn_metadata,
self.kv_scale, k_original)

# put original keys back in cache
for i in range(batch_size):
num_past_tokens = context_lens[i] - 1
# if num_past_tokens < llama_context_len: continue
within_context_len = num_past_tokens < llama_context_len  # recompute per sequence; the value left over from the loop above may be stale

past_keys = original_keys[i]
block_table = original_block_tables[i]
start = 0 if within_context_len else num_past_tokens - llama_context_len + 1 + block_size
end = num_past_tokens
abs_pos = start
while abs_pos < end:
logic_bnum = abs_pos // block_size
phys_bnum = block_table[logic_bnum]
offset_start = abs_pos % block_size
offset_end = block_size if end - abs_pos > block_size else end - abs_pos
num_tokens = offset_end - offset_start

key_cache[phys_bnum, :, :, offset_start : offset_end, :] = past_keys[logic_bnum]
abs_pos += num_tokens

# revert block_tables and context_lens inside metadata
# so that next attn layer starts with same fields
attn_metadata.decode_metadata.block_tables = original_block_tables
attn_metadata.decode_metadata.context_lens = torch.tensor(
context_lens, dtype=torch.int, device=positions.device)

output, _ = self.o_proj(attn_output)
return output


class LlamaDecoderLayer(nn.Module):
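
The block-table capping in the decode path above can be sanity-checked with a small worked example: once a sequence outgrows llama_context_len, the table handed to the attention kernel becomes the sink block plus the most recent blocks, and context_lens is clamped to the window. A sketch with the constants from the diff (the physical block numbers are made up):

llama_context_len = 4096
block_size = 16

# A sequence whose cache already holds 5000 past tokens (context_lens[i] = 5001).
num_past_tokens = 5000
within_context_len = num_past_tokens < llama_context_len              # False

# First absolute position that still gets re-rotated (everything earlier,
# except the sink block, is ignored):
start = num_past_tokens - llama_context_len + 1 + block_size          # 921
# Leading blocks dropped from the table; block 0 is kept as the attention sink:
blocks_to_ignore = (num_past_tokens - llama_context_len) // block_size + 1   # 57

full_table = list(range(1000, 1000 + num_past_tokens // block_size + 1))     # fake physical block numbers
capped_table = [full_table[0]] + full_table[blocks_to_ignore + 1:]
assert len(capped_table) <= llama_context_len // block_size                  # at most 256 block entries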