From e76466dde2bc9525d55165ceaa600d298c7bf773 Mon Sep 17 00:00:00 2001 From: Alexander Matveev <59768536+alexm-neuralmagic@users.noreply.github.com> Date: Wed, 17 Jul 2024 17:30:28 -0400 Subject: [PATCH 01/13] [Core] draft_model_runner: Implement prepare_inputs on GPU for advance_step (#6338) --- CMakeLists.txt | 1 + csrc/ops.h | 5 + csrc/prepare_inputs/advance_step.cu | 131 +++++++++ csrc/prepare_inputs/advance_step.cuh | 19 ++ csrc/torch_bindings.cpp | 4 + tests/spec_decode/e2e/conftest.py | 1 + tests/spec_decode/test_multi_step_worker.py | 48 +++ vllm/_custom_ops.py | 12 + vllm/model_executor/layers/sampler.py | 147 +++++++--- vllm/model_executor/sampling_metadata.py | 10 + vllm/spec_decode/draft_model_runner.py | 305 +++++++++++++++----- vllm/spec_decode/multi_step_worker.py | 15 +- 12 files changed, 568 insertions(+), 130 deletions(-) create mode 100644 csrc/prepare_inputs/advance_step.cu create mode 100644 csrc/prepare_inputs/advance_step.cuh diff --git a/CMakeLists.txt b/CMakeLists.txt index ced73ca03bfbc..335623bd2677d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -151,6 +151,7 @@ set(VLLM_EXT_SRC "csrc/quantization/fp8/common.cu" "csrc/cuda_utils_kernels.cu" "csrc/moe_align_block_size_kernels.cu" + "csrc/prepare_inputs/advance_step.cu" "csrc/torch_bindings.cpp") if(VLLM_GPU_LANG STREQUAL "CUDA") diff --git a/csrc/ops.h b/csrc/ops.h index f9feb3deff5e4..1e94a9f45ef08 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -52,6 +52,11 @@ void gelu_fast(torch::Tensor& out, torch::Tensor& input); void gelu_quick(torch::Tensor& out, torch::Tensor& input); +void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size, + torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids, + torch::Tensor& input_positions, torch::Tensor& seq_lens, + torch::Tensor& slot_mapping, torch::Tensor& block_tables); + #ifndef USE_ROCM torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes, const torch::Tensor& codebooks, diff --git a/csrc/prepare_inputs/advance_step.cu b/csrc/prepare_inputs/advance_step.cu new file mode 100644 index 0000000000000..0e537ddd6c4cd --- /dev/null +++ b/csrc/prepare_inputs/advance_step.cu @@ -0,0 +1,131 @@ +/* + * The goal of this GPU kernel is to advance input tensors on the GPU directly + * PR: https://github.com/vllm-project/vllm/pull/6338 + * Current restrictions: + * 1. Specialized for DraftModelRunner + * 2. Supports flash_attn only + */ + +#include "advance_step.cuh" + +namespace prepare_inputs { + +// +template +__global__ void advance_step_kernel(int num_seqs, int num_queries, + int block_size, long* input_tokens_ptr, + long const* sampled_token_ids_ptr, + long* input_positions_ptr, + int* seq_lens_ptr, long* slot_mapping_ptr, + int const* block_tables_ptr, + int64_t const block_tables_stride) { + int num_query_blocks = div_ceil(num_queries, num_threads); + + if (blockIdx.x >= num_query_blocks) { + return; + } + + int cur_query_id = blockIdx.x * num_threads + threadIdx.x; + + if (cur_query_id >= num_queries) { + return; + } + + // Update input_tokens + input_tokens_ptr[cur_query_id] = sampled_token_ids_ptr[cur_query_id]; + + int seq_len = seq_lens_ptr[cur_query_id]; + int next_seq_len = seq_len + 1; + int next_input_pos = next_seq_len - 1; + + // Update seq_lens + seq_lens_ptr[cur_query_id] = next_seq_len; + // Update input_positions + input_positions_ptr[cur_query_id] = next_input_pos; + + int const* seq_block_tables_ptr = + block_tables_ptr + block_tables_stride * cur_query_id; + + int block_index = next_input_pos / block_size; + int block_offset = next_input_pos % block_size; + + int slot_num = seq_block_tables_ptr[block_index] * block_size + block_offset; + // Update slot_mapping + slot_mapping_ptr[cur_query_id] = slot_num; +} + +inline void verify_tensor(std::string const& name, torch::Tensor& t, + int64_t const size_0, int64_t const size_1, + c10::ScalarType const type) { + bool size_0_cond = true; + if (size_0 != -1) { + size_0_cond = t.size(0) == size_0; + } + + bool size_1_cond = true; + if (size_1 != -1) { + size_1_cond = t.size(1) == size_1; + } + + bool is_contiguous = t.is_contiguous(); + bool same_type = t.dtype() == type; + + bool pass = size_0_cond && size_1_cond && is_contiguous && same_type; + if (!pass) { + TORCH_CHECK(false, "tensor: name = ", name, ", shape = ", t.sizes(), + " is_cont = ", t.is_contiguous(), ", type = ", t.dtype(), + " is not as expected: shape = [", size_0, ", ", size_1, + "], type = ", type); + } +} + +void advance_step(int num_seqs, int num_queries, int block_size, + torch::Tensor& input_tokens, // type: long + torch::Tensor& sampled_token_ids, // type: long + torch::Tensor& input_positions, // type: long + torch::Tensor& seq_lens, // type: int + torch::Tensor& slot_mapping, // type: long + torch::Tensor& block_tables) { // type: int + + if (logging) { + printf("advance_step:\n"); + printf(" num_seqs = %d\n", num_seqs); + printf(" num_queries = %d\n", num_queries); + printf(" block_size = %d\n", block_size); + } + // Verify all tensors + verify_tensor("input_tokens", input_tokens, num_seqs, -1, at::kLong); + verify_tensor("sampled_token_ids", sampled_token_ids, num_queries, 1, + at::kLong); + verify_tensor("input_positions", input_positions, num_seqs, -1, at::kLong); + verify_tensor("seq_lens", seq_lens, num_seqs, -1, at::kInt); + verify_tensor("slot_mapping", slot_mapping, num_seqs, -1, at::kLong); + verify_tensor("block_tables", block_tables, num_seqs, -1, at::kInt); + + int dev = sampled_token_ids.get_device(); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev); + + int blocks; + cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev); + + advance_step_kernel<<>>( + num_seqs, num_queries, block_size, + reinterpret_cast(input_tokens.data_ptr()), + reinterpret_cast(sampled_token_ids.data_ptr()), + reinterpret_cast(input_positions.data_ptr()), + reinterpret_cast(seq_lens.data_ptr()), + reinterpret_cast(slot_mapping.data_ptr()), + reinterpret_cast(block_tables.data_ptr()), + block_tables.stride(0)); +} + +} // namespace prepare_inputs + +void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size, + torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids, + torch::Tensor& input_positions, torch::Tensor& seq_lens, + torch::Tensor& slot_mapping, torch::Tensor& block_tables) { + prepare_inputs::advance_step(num_seqs, num_queries, block_size, input_tokens, + sampled_token_ids, input_positions, seq_lens, + slot_mapping, block_tables); +} \ No newline at end of file diff --git a/csrc/prepare_inputs/advance_step.cuh b/csrc/prepare_inputs/advance_step.cuh new file mode 100644 index 0000000000000..f21574681b1ab --- /dev/null +++ b/csrc/prepare_inputs/advance_step.cuh @@ -0,0 +1,19 @@ +#pragma once + +#include + +#include +#include +#include +#include +#include +#include + +namespace prepare_inputs { + +static constexpr int max_threads = 256; +static constexpr bool logging = false; + +constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; } + +} // namespace prepare_inputs diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 9dc7cefc404ca..ff9875e0e17a3 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -72,6 +72,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("gelu_quick(Tensor! out, Tensor input) -> ()"); ops.impl("gelu_quick", torch::kCUDA, &gelu_quick); + // prepare_inputs advance_step + ops.def("advance_step", &advance_step); + ops.impl("advance_step", torch::kCUDA, &advance_step); + // Layernorm // Apply Root Mean Square (RMS) Normalization to the input tensor. ops.def( diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py index 34a6c9a393a58..da72f6d503c11 100644 --- a/tests/spec_decode/e2e/conftest.py +++ b/tests/spec_decode/e2e/conftest.py @@ -227,6 +227,7 @@ def get_output_from_llm_generator( maybe_assert_ngram_worker(llm) outputs = llm.generate(prompts, sampling_params, use_tqdm=True) + token_ids = [output.outputs[0].token_ids for output in outputs] tokens = [output.outputs[0].text for output in outputs] diff --git a/tests/spec_decode/test_multi_step_worker.py b/tests/spec_decode/test_multi_step_worker.py index 9832d4f267e8a..442e40f07f0bb 100644 --- a/tests/spec_decode/test_multi_step_worker.py +++ b/tests/spec_decode/test_multi_step_worker.py @@ -642,3 +642,51 @@ def test_draft_proposals_mixed_k(): assert proposals.proposal_lens.tolist() == [ k for _ in range(expected_num_proposal_seqs - 1) ] + [0 for _ in range(expected_num_no_proposal_seqs)] + [k] + + +@torch.inference_mode() +def test_use_draft_model_runner_advance_step(): + """Verify that draft model runner triggers advance step + when applicable. + """ + seed = 100 + model_name = 'JackFram/llama-68m' + + k = 5 + batch_size = 32 + block_size = 32 + num_gpu_blocks = 2048 // block_size + worker = create_worker( + MultiStepWorker, + model_name, + block_size, + num_gpu_blocks, + seed, + model_runner_cls=TP1DraftModelRunner, + ) + + # Mock "_gpu_advance_step" to raise an exception when called. + exception_secret = "artificial stop" + worker.model_runner._gpu_advance_step = MagicMock() + worker.model_runner._gpu_advance_step.side_effect = ValueError( + exception_secret) + + seq_group_metadata_list, _, _ = create_batch(batch_size, k) + + # Fallback (should not call) when num_steps=1. + execute_model_req = ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + num_lookahead_slots=k, + num_steps=1) + worker.execute_model(execute_model_req=execute_model_req) + + # Expect exception if _gpu_advance_step is called. + execute_model_req = ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + num_lookahead_slots=k, + num_steps=k) + + with pytest.raises(ValueError, match=exception_secret): + worker.execute_model(execute_model_req=execute_model_req) + call_args_list = worker.model_runner._gpu_advance_step.call_args_list + assert len(call_args_list) == 1 diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 4ca67224a91b8..143957f7b65f0 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -166,6 +166,18 @@ def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor, torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon) +def advance_step(num_seqs: int, num_queries: int, block_size: int, + input_tokens: torch.Tensor, sampled_token_ids: torch.Tensor, + input_positions: torch.Tensor, seq_lens: torch.Tensor, + slot_mapping: torch.Tensor, + block_tables: torch.Tensor) -> None: + """Advance a step on GPU for existing inputs for a multi-step runner""" + return torch.ops._C.advance_step(num_seqs, num_queries, block_size, + input_tokens, sampled_token_ids, + input_positions, seq_lens, slot_mapping, + block_tables) + + # quantization ops # awq def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor, diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 6d00ea64f7cb8..5c376797a054f 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -47,6 +47,32 @@ def __init__(self): # speculative decoding. self.include_gpu_probs_tensor = False + def _init_sampling_tensors( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ): + """The goal here is to reuse sampling tensors between similar decode + runs. This is possible because sampling logic does not change between + decodes of the same sequences. + """ + _, vocab_size = logits.shape + + # First free any existing stored sampling tensors. + # This is necessary because some sampling tensors may + # have pinned memory. + self._sampling_tensors = None + + # Initialize new sampling tensors + (sampling_tensors, do_penalties, do_top_p_top_k, + do_min_p) = SamplingTensors.from_sampling_metadata( + sampling_metadata, vocab_size, logits.device, logits.dtype) + + self._sampling_tensors = sampling_tensors + self._do_penalties = do_penalties + self._do_top_p_top_k = do_top_p_top_k + self._do_min_p = do_min_p + def forward( self, logits: torch.Tensor, @@ -60,12 +86,23 @@ def forward( assert logits is not None _, vocab_size = logits.shape - logits = _apply_min_tokens_penalty(logits, sampling_metadata) - # Prepare sampling tensors with pinned memory to avoid blocking. - (sampling_tensors, do_penalties, do_top_p_top_k, - do_min_p) = SamplingTensors.from_sampling_metadata( - sampling_metadata, vocab_size, logits.device, logits.dtype) + if not sampling_metadata.reuse_sampling_tensors: + self._init_sampling_tensors(logits, sampling_metadata) + elif self._do_penalties: + # In this case, the sampling tensors logic depends on + # "output_tokens" of a sequence. As a result, we cannot + # reuse sampling tensors, since "output_tokens" changes + # between decode runs. + self._init_sampling_tensors(logits, sampling_metadata) + + assert self._sampling_tensors is not None + sampling_tensors = self._sampling_tensors + do_penalties = self._do_penalties + do_top_p_top_k = self._do_top_p_top_k + do_min_p = self._do_min_p + + logits = _apply_min_tokens_penalty(logits, sampling_metadata) # Apply presence and frequency penalties. if do_penalties: @@ -77,7 +114,7 @@ def forward( # Apply temperature scaling. # Use in-place division to avoid creating a new tensor. - logits.div_(sampling_tensors.temperatures.unsqueeze_(dim=1)) + logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1)) if do_top_p_top_k: logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps, @@ -109,13 +146,19 @@ def forward( on_device_tensors = None # Get the logprobs query results. - prompt_logprobs, sample_logprobs = _get_logprobs( - logprobs, sampling_metadata, sample_results) - return _build_sampler_output(sample_results, - sampling_metadata, - prompt_logprobs, - sample_logprobs, - on_device_tensors=on_device_tensors) + prompt_logprobs = None + sample_logprobs = None + if not sampling_metadata.skip_sampler_cpu_output: + prompt_logprobs, sample_logprobs = _get_logprobs( + logprobs, sampling_metadata, sample_results) + + return _build_sampler_output( + sample_results, + sampling_metadata, + prompt_logprobs, + sample_logprobs, + on_device_tensors=on_device_tensors, + skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output) @property def _should_modify_greedy_probs_inplace(self) -> bool: @@ -535,24 +578,29 @@ def _sample_with_torch( # GPU<->CPU sync happens in the loop below. # This also converts the sample output to Python objects. - for sampling_type in SamplingType: - if sampling_type not in sample_metadata: - continue - (seq_group_id, seq_groups) = sample_metadata[sampling_type] - if sampling_type == SamplingType.GREEDY: - sample_results = _greedy_sample(seq_groups, greedy_samples) - elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED): - sample_results = _random_sample(seq_groups, - multinomial_samples[sampling_type]) - elif sampling_type == SamplingType.BEAM: - sample_results = _beam_search_sample(seq_groups, - beam_search_logprobs) - sample_results_dict.update(zip(seq_group_id, sample_results)) + if not sampling_metadata.skip_sampler_cpu_output: + for sampling_type in SamplingType: + if sampling_type not in sample_metadata: + continue + (seq_group_id, seq_groups) = sample_metadata[sampling_type] + if sampling_type == SamplingType.GREEDY: + sample_results = _greedy_sample(seq_groups, greedy_samples) + elif sampling_type in (SamplingType.RANDOM, + SamplingType.RANDOM_SEED): + sample_results = _random_sample( + seq_groups, multinomial_samples[sampling_type]) + elif sampling_type == SamplingType.BEAM: + sample_results = _beam_search_sample(seq_groups, + beam_search_logprobs) + sample_results_dict.update(zip(seq_group_id, sample_results)) + + sample_results = [ + sample_results_dict.get(i, ([], [])) + for i in range(len(sampling_metadata.seq_groups)) + ] + else: + sample_results = [] - sample_results = [ - sample_results_dict.get(i, ([], [])) - for i in range(len(sampling_metadata.seq_groups)) - ] return sample_results, sampled_token_ids_tensor @@ -997,10 +1045,11 @@ def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor, def _build_sampler_output( sample_results: SampleResultType, sampling_metadata: SamplingMetadata, - prompt_logprobs: List[Optional[PromptLogprobs]], - sample_logprobs: List[SampleLogprobs], + prompt_logprobs: Optional[List[Optional[PromptLogprobs]]], + sample_logprobs: Optional[List[SampleLogprobs]], on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]], + skip_sampler_cpu_output: bool = False, ) -> SamplerOutput: """Construct Python objects with the output of sampling. @@ -1010,22 +1059,26 @@ def _build_sampler_output( allows post-processing without copies to CPU/serialization, e.g. in speculative decoding rejection sampling. """ - sampler_output: List[CompletionSequenceGroupOutput] = [] - for (seq_group, sample_result, group_prompt_logprobs, - group_sample_logprobs) in zip(sampling_metadata.seq_groups, - sample_results, prompt_logprobs, - sample_logprobs): - seq_ids = seq_group.seq_ids - next_token_ids, parent_ids = sample_result - seq_outputs: List[SequenceOutput] = [] - for parent_id, next_token_id, logprobs in zip(parent_ids, - next_token_ids, - group_sample_logprobs): - seq_outputs.append( - SequenceOutput(seq_ids[parent_id], next_token_id, logprobs)) - sampler_output.append( - CompletionSequenceGroupOutput(seq_outputs, group_prompt_logprobs)) + if not skip_sampler_cpu_output: + assert prompt_logprobs is not None + assert sample_logprobs is not None + + for (seq_group, sample_result, group_prompt_logprobs, + group_sample_logprobs) in zip(sampling_metadata.seq_groups, + sample_results, prompt_logprobs, + sample_logprobs): + seq_ids = seq_group.seq_ids + next_token_ids, parent_ids = sample_result + seq_outputs: List[SequenceOutput] = [] + for parent_id, next_token_id, logprobs in zip( + parent_ids, next_token_ids, group_sample_logprobs): + seq_outputs.append( + SequenceOutput(seq_ids[parent_id], next_token_id, + logprobs)) + sampler_output.append( + CompletionSequenceGroupOutput(seq_outputs, + group_prompt_logprobs)) # If not specified, store None values in SamplerOutput. if on_device_tensors is not None: diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index c346cd0562867..29b077cf6d912 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -87,6 +87,12 @@ def sample(logits): The first tuple is [1, 2] (sampled index within original logit), and the second tuple is [0, 1] (sampled index within pruned logit). num_prompts: Number of prompt sequence groups in seq_groups. + skip_sampler_cpu_output: Indicates if we want to skip the GPU=>CPU + serialization of token outputs. + reuse_sampling_tensors: Indicates if we want to reuse sampling + tensors that are part of the sampler forward pass. Currently, + it is mainly used for multi-step decode. + """ def __init__( @@ -95,11 +101,15 @@ def __init__( selected_token_indices: torch.Tensor, categorized_sample_indices: Dict[SamplingType, torch.Tensor], num_prompts: int, + skip_sampler_cpu_output: bool = False, + reuse_sampling_tensors: bool = False, ) -> None: self.seq_groups = seq_groups self.selected_token_indices = selected_token_indices self.categorized_sample_indices = categorized_sample_indices self.num_prompts = num_prompts + self.skip_sampler_cpu_output = skip_sampler_cpu_output + self.reuse_sampling_tensors = reuse_sampling_tensors @staticmethod def prepare( diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index 90bba96ee8acb..3cb7ec58da4c1 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -2,17 +2,22 @@ import torch +from vllm import _custom_ops as ops +from vllm.attention.backends.flash_attn import FlashAttentionMetadata from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig) from vllm.logger import init_logger -from vllm.sequence import (IntermediateTensors, SamplerOutput, - SequenceGroupMetadata) +from vllm.sequence import (ExecuteModelRequest, IntermediateTensors, + SamplerOutput) from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata, ModelRunner) logger = init_logger(__name__) +debug_advance_input = False +enable_gpu_advance_step = True + class TP1DraftModelRunner(ModelRunner): """Specialized model runner for speculative decoding draft model. @@ -21,18 +26,9 @@ class TP1DraftModelRunner(ModelRunner): we could get rid of most CPU-GPU synchronization and data transfer overheads by keeping model input and output tensors on GPU all the time. - This runner is still under development so there's no performance gain - at this moment. Currently we adopt a temporary solution that caches the - seq_group_metadata_list for multi-step execution, so that we can - leverage existing prepare_model_input to be compatible with the current - execution flow, but we plan to remove this cache and avoid calling - prepare_model_input in execute_model at all. - - The detail development plan includes: - 1. Use "update_model_input" to update existing model_input without - creating a new one. - 2. Improve the performance of "update_model_input" with a GPU kernel. - 3. Support TP > 1 (this requires some designs because we do not expect + TODOs: + 1. Currently supports only flash-attn, add support for other attn_backends. + 2. Support TP > 1 (this requires some designs because we do not expect any broadcasting inside execute_model). """ @@ -71,51 +67,156 @@ def __init__( return_hidden_states=return_hidden_states, ) - # TODO: Remove this cache when we are able to update model_input - # directly in advance_step. - self.cached_seq_group_metadata_list: Optional[ - List[SequenceGroupMetadata]] = None + def _update_flash_attn_metadata(self, attn_metadata, num_seqs, + num_queries): + assert isinstance(attn_metadata, FlashAttentionMetadata) - def prepare_model_input( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - virtual_engine: int = 0, - finished_requests_ids: Optional[List[str]] = None - ) -> ModelInputForGPUWithSamplingMetadata: - """A temporary solution that caches the seq_group_metadata_list - for multi-step execution. - TODO: In-place update model_input and remove this function. - """ - self.cached_seq_group_metadata_list = seq_group_metadata_list - return super().prepare_model_input( - seq_group_metadata_list, - finished_requests_ids=finished_requests_ids) + if num_seqs != num_queries: + assert num_seqs > num_queries + assert attn_metadata.use_cuda_graph + + assert attn_metadata.num_prefills == 0 + assert attn_metadata.num_prefill_tokens == 0 + assert attn_metadata.num_decode_tokens == num_seqs + assert attn_metadata.slot_mapping.shape == (num_seqs, ) + + assert len(attn_metadata.seq_lens) == num_seqs + assert attn_metadata.seq_lens_tensor.shape == (num_seqs, ) + assert attn_metadata.max_query_len == 1 + assert attn_metadata.max_prefill_seq_len == 0 + assert attn_metadata.max_decode_seq_len == max(attn_metadata.seq_lens) + + assert attn_metadata.query_start_loc.shape == (num_queries + 1, ) + assert attn_metadata.seq_start_loc.shape == (num_seqs + 1, ) + + assert attn_metadata.context_lens_tensor.shape == (num_queries, ) + + assert attn_metadata.block_tables.shape[0] == num_seqs + + # Update query lengths. Note that we update only queries and not seqs, + # since tensors may be padded due to captured cuda graph batch size + for i in range(num_queries): + attn_metadata.seq_lens[i] += 1 + attn_metadata.max_decode_seq_len = max(attn_metadata.seq_lens) - def update_model_input( + def _update_sampling_metadata(self, sampling_metadata, num_seqs, + num_queries): + + assert sampling_metadata.num_prompts == 0 + assert len(sampling_metadata.seq_groups) == num_queries + assert sampling_metadata.selected_token_indices.shape == ( + num_queries, ) + # assert sampling_metadata.categorized_sample_indices == TODO: Add if needed # noqa: E501 + + # Verify that all sequences are decodes + for i in range(num_queries): + seq_group = sampling_metadata.seq_groups[i] + + assert seq_group.is_prompt is False # No prompt + assert seq_group.prompt_logprob_indices == [] # No prompt + assert seq_group.sample_indices == [i] # Simple + assert seq_group.seq_len is None # Decode + assert seq_group.query_len is None # Decode + + def _gpu_advance_step( self, model_input: ModelInputForGPUWithSamplingMetadata, last_output: SamplerOutput ) -> ModelInputForGPUWithSamplingMetadata: - """Prepare the model inputs for the next step. - TODO: In-place update model_input instead of calling - prepare_model_input. + # Currently, we expect "decode mode" only + assert not model_input.is_prompt + + # Get num_seqs + num_seqs = len(model_input.seq_lens) + num_queries = len(model_input.query_lens) + + # Get output tokens GPU tensor + sampled_token_ids = last_output.sampled_token_ids + assert sampled_token_ids is not None + + # Update attn_metadata + attn_metadata = model_input.attn_metadata + assert isinstance(attn_metadata, FlashAttentionMetadata) + self._update_flash_attn_metadata(attn_metadata, num_seqs, num_queries) + + # Update GPU tensors + ops.advance_step(num_seqs=num_seqs, + num_queries=num_queries, + block_size=self.block_size, + input_tokens=model_input.input_tokens, + sampled_token_ids=sampled_token_ids, + input_positions=model_input.input_positions, + seq_lens=attn_metadata.seq_lens_tensor, + slot_mapping=attn_metadata.slot_mapping, + block_tables=attn_metadata.block_tables) + + # Update sampling_metadata + sampling_metadata = model_input.sampling_metadata + self._update_sampling_metadata(sampling_metadata, num_seqs, + num_queries) + + # Create new input + new_model_input = self._model_input_cls( + input_tokens=model_input.input_tokens, + input_positions=model_input.input_positions, + attn_metadata=attn_metadata, + seq_lens=attn_metadata.seq_lens, + query_lens=model_input.query_lens, + lora_mapping=model_input.lora_mapping, + lora_requests=model_input.lora_requests, + multi_modal_kwargs=model_input.multi_modal_kwargs, + sampling_metadata=model_input.sampling_metadata, + is_prompt=False, + ) + + # Ensure we skip CPU samples + assert new_model_input.sampling_metadata.skip_sampler_cpu_output is True + # We can reuse sampling tensors since every decode iteration is the same + new_model_input.sampling_metadata.reuse_sampling_tensors = True + + if debug_advance_input: + logger.debug("NEW INPUT: ") + logger.debug(" input_tokens = %s", new_model_input.input_tokens) + logger.debug(" input_positions = %s", + new_model_input.input_positions) + logger.debug(" seq_lens = %d", new_model_input.seq_lens) + logger.debug(" query_lens = %d", new_model_input.query_lens) + logger.debug(" attn_metadata:") + logger.debug(" seq_lens_tensor: %s", + attn_metadata.seq_lens_tensor) + logger.debug(" slot_mapping: %s", attn_metadata.slot_mapping) + logger.debug(" block_tables: %s", attn_metadata.block_tables) + + return new_model_input + + def supports_gpu_multi_step(self, execute_model_req: ExecuteModelRequest): + """Determines if draft_model_runner GPU multi-step can be used. + Currently required conditions are: + 1. Only decodes + 2. Only flash-attn + 3. No LORA + 4. No prompt_adapter_config """ + if not enable_gpu_advance_step: + return False - # Append the output token to the sequence data. - assert self.cached_seq_group_metadata_list is not None - for seq_group_metadata, sequence_group_outputs in zip( - self.cached_seq_group_metadata_list, last_output.outputs): - seq_group_metadata.is_prompt = False + # We allow multi-step GPU only in decode mode + for seq_group in execute_model_req.seq_group_metadata_list: + if seq_group.is_prompt: + return False - for seq_output in sequence_group_outputs.samples: - seq = seq_group_metadata.seq_data[seq_output.parent_seq_id] + # TODO: Add support for other attn backends + if self.attn_backend.get_name() != "flash-attn": + return False - token_id = seq_output.output_token - token_logprob = seq_output.logprobs[token_id] + # TODO: Add support for LORA + if self.lora_config: + return False - seq.append_token_id(token_id, token_logprob.logprob) - seq.update_num_computed_tokens(1) + # TODO: Add soft-tuning prompt adapter support + if self.prompt_adapter_config: + return False - return self.prepare_model_input(self.cached_seq_group_metadata_list) + return True @torch.inference_mode() def execute_model( @@ -125,42 +226,86 @@ def execute_model( intermediate_tensors: Optional[IntermediateTensors] = None, num_steps: int = 1, ) -> Optional[List[SamplerOutput]]: - # Since we do not broadcast data inside execute_model anymore, - # we need to figure out the best way to support TP > 1 in this - # case, because we will at least need to broadcast the sampled - # tokens to all workers. - if not self.is_driver_worker: - raise ValueError("TP1DraftModelRunner only supports TP=1.") + """Executes num_steps forward passes with advacement of input tensors + on the GPU. Look at supports_gpu_multi_step(..) for pre-conditions. - if self.lora_config: - assert model_input.lora_requests is not None - assert model_input.lora_mapping is not None - self.set_active_loras(model_input.lora_requests, - model_input.lora_mapping) + Optimizations used: + 1. Input tensors are updated on the GPU directly + 2. Skips GPU=>CPU serialization of sampler outputs (we don't need + them since we do batch expansion later that uses GPU outputs) + 3. Reuses sampling tensors (since we run only decodes and they have + a repeating sampling logic) + """ - if self.prompt_adapter_config: - assert model_input.prompt_adapter_requests is not None - assert model_input.prompt_adapter_mapping is not None - self.set_active_prompt_adapters( - model_input.prompt_adapter_requests, - model_input.prompt_adapter_mapping) + # When num_steps == 1, we execute the fallback here for the GPU + # advance_step, which runs prepare_inputs on CPU and for each spec + # iteration invokes this function only once + # (Look at multi-step-worker code) + is_fallback = num_steps == 1 + if not is_fallback: + # Since we do not broadcast data inside execute_model anymore, + # we need to figure out the best way to support TP > 1 in this + # case, because we will at least need to broadcast the sampled + # tokens to all workers. + if not self.is_driver_worker: + raise ValueError("TP1DraftModelRunner only supports TP=1.") + + # Sanity + if self.lora_config is not None: + raise ValueError("TP1DraftModelRunner has no support for LORA") + if self.prompt_adapter_config is not None: + raise ValueError("TP1DraftModelRunner has no support for " + "prompt_adapter_config") + if model_input.multi_modal_kwargs: + raise ValueError( + "TP1DraftModelRunner has no support for multi_modal_kwargs" + ) + else: + if self.lora_config: + assert model_input.lora_requests is not None + assert model_input.lora_mapping is not None + self.set_active_loras(model_input.lora_requests, + model_input.lora_mapping) + + if self.prompt_adapter_config: + assert model_input.prompt_adapter_requests is not None + assert model_input.prompt_adapter_mapping is not None + self.set_active_prompt_adapters( + model_input.prompt_adapter_requests, + model_input.prompt_adapter_mapping) + + # Detect exec mode + assert model_input.attn_metadata is not None + use_cuda_graph = False + if model_input.attn_metadata.num_prefills > 0: + # In this case, execute_model(..) was called directly + if num_steps > 1: + raise ValueError( + "execute_model(..) of draft_model_runner can be called " + "directly only with a single-step prefill") + else: + # We can skip CPU samples for spec token generation. + # (We do allow CPU samples for num_steps == 1 to support the + # fallback case, where supports_gpu_multi_step(..) does not pass) + model_input.sampling_metadata.skip_sampler_cpu_output = ( + not is_fallback) + + # Attn attr defines if we use cuda graphs + use_cuda_graph = model_input.attn_metadata.use_cuda_graph + + # Get model + if use_cuda_graph: + graph_batch_size = model_input.input_tokens.shape[0] + model_executable = (self.graph_runners[model_input.virtual_engine] + [graph_batch_size]) + else: + model_executable = self.model - virtual_engine = model_input.virtual_engine outputs: List[SamplerOutput] = [] for step in range(num_steps): - # Currently cuda graph is only supported by the decode phase. - assert model_input.attn_metadata is not None - prefill_meta = model_input.attn_metadata.prefill_metadata - decode_meta = model_input.attn_metadata.decode_metadata - if prefill_meta is None and decode_meta.use_cuda_graph: - assert model_input.input_tokens is not None - graph_batch_size = model_input.input_tokens.shape[0] - model_executable = ( - self.graph_runners[virtual_engine][graph_batch_size]) - else: - model_executable = self.model - multi_modal_kwargs = model_input.multi_modal_kwargs or {} + + # Run model hidden_states = model_executable( input_ids=model_input.input_tokens, positions=model_input.input_positions, @@ -181,8 +326,8 @@ def execute_model( sampling_metadata=model_input.sampling_metadata, )) - # Prepare the inputs for the next step. + # Prepare inputs for the next step if step != num_steps - 1: - model_input = self.update_model_input(model_input, outputs[-1]) + model_input = self._gpu_advance_step(model_input, outputs[-1]) return outputs diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py index 11e99882e3f0b..91689324557b5 100644 --- a/vllm/spec_decode/multi_step_worker.py +++ b/vllm/spec_decode/multi_step_worker.py @@ -67,14 +67,23 @@ def sampler_output( expanded_request, indices_of_seq_with_bonus_tokens =\ self._expand_execute_model_request( execute_model_req, seq_ids_with_bonus_token_in_last_step) + # Run model sample_len times. model_outputs: List[SamplerOutput] = [] - if isinstance(self.model_runner, TP1DraftModelRunner): + if isinstance( + self.model_runner, TP1DraftModelRunner + ) and self.model_runner.supports_gpu_multi_step(expanded_request): + # Here we run the draft_model_runner with multi-step prepare + # on the GPU directly expanded_request.num_steps = sample_len model_outputs = self.execute_model( execute_model_req=expanded_request) else: - # TODO: Remove this branch once DraftModelRunner supports TP>1. + # Here we run multi-step directly, with every step prepared + # on the CPU. + # TODO: Remove this branch once DraftModelRunner supports TP>1 + # and other restrictions that are part of DraftModelRunner's + # supports_gpu_multi_step(..) for _ in range(sample_len): model_output: List[SamplerOutput] = super().execute_model( execute_model_req=expanded_request) @@ -171,7 +180,7 @@ def _filter_model_output( outputs=[ expanded_batch_output.outputs[i] for i in output_indices_to_retain - ], + ] if len(expanded_batch_output.outputs) > 0 else [], sampled_token_probs=( expanded_batch_output. sampled_token_probs[output_indices_to_retain] From b5241e41d9fef56a89dcfda367a7eff87a07e3f7 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Wed, 17 Jul 2024 21:38:35 -0400 Subject: [PATCH 02/13] [ Kernel ] FP8 Dynamic-Per-Token Quant Kernel (#6511) Co-authored-by: Varun Sundar Rabindranath --- csrc/ops.h | 10 ++- csrc/quantization/fp8/common.cu | 144 ++++++++++++++++++++++++++----- csrc/torch_bindings.cpp | 10 ++- tests/kernels/quant_utils.py | 56 ++++++++++++ tests/kernels/test_fp8_quant.py | 54 ++++++++++++ tests/kernels/test_int8_quant.py | 26 +++--- vllm/_custom_ops.py | 11 +++ 7 files changed, 271 insertions(+), 40 deletions(-) create mode 100644 tests/kernels/quant_utils.py create mode 100644 tests/kernels/test_fp8_quant.py diff --git a/csrc/ops.h b/csrc/ops.h index 1e94a9f45ef08..c0f924c09b515 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -128,12 +128,16 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit); -void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input, - torch::Tensor& scale); +void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input, + torch::Tensor const& scale); -void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input, +void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale); +void dynamic_per_token_scaled_fp8_quant(torch::Tensor& out, + torch::Tensor const& input, + torch::Tensor& scale); + void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size, torch::Tensor sorted_token_ids, torch::Tensor experts_ids, diff --git a/csrc/quantization/fp8/common.cu b/csrc/quantization/fp8/common.cu index 6120086d72df2..0938c0707679f 100644 --- a/csrc/quantization/fp8/common.cu +++ b/csrc/quantization/fp8/common.cu @@ -7,6 +7,8 @@ #include "cuda_compat.h" #include "dispatch_utils.h" +#include "../../reduction_utils.cuh" + namespace vllm { __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) { @@ -88,25 +90,48 @@ typedef struct __align__(4) { float8x4_t; template -__global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out, - const scalar_t* __restrict__ input, - const float* __restrict__ scale, - int64_t num_elems) { - int tid = blockDim.x * blockIdx.x + threadIdx.x; +__device__ float thread_max_vec(scalar_t const* __restrict__ input, + int64_t const num_elems, int const tid, + int const step) { + // Vectorized input/output to better utilize memory bandwidth. + vec4_t const* vectorized_in = + reinterpret_cast const*>(input); - // Invert the scale so that we can use multiplications to avoid expensive - // division. - const float inverted_scale = 1.0f / (*scale); + int const num_vec_elems = num_elems >> 2; + float absmax_val = 0.0f; + +#pragma unroll 4 + for (int i = tid; i < num_vec_elems; i += step) { + vec4_t in_vec = vectorized_in[i]; + absmax_val = max(absmax_val, fabs(in_vec.x)); + absmax_val = max(absmax_val, fabs(in_vec.y)); + absmax_val = max(absmax_val, fabs(in_vec.z)); + absmax_val = max(absmax_val, fabs(in_vec.w)); + } + // Handle the remaining elements if num_elems is not divisible by 4 + for (int i = num_vec_elems * 4 + tid; i < num_elems; i += step) { + absmax_val = max(absmax_val, fabs(input[i])); + } + + return absmax_val; +} + +template +__device__ void scaled_fp8_conversion_vec(c10::Float8_e4m3fn* __restrict__ out, + scalar_t const* __restrict__ input, + float const inverted_scale, + int64_t const num_elems, + int const tid, int const step) { // Vectorized input/output to better utilize memory bandwidth. - const vec4_t* vectorized_in = - reinterpret_cast*>(input); + vec4_t const* vectorized_in = + reinterpret_cast const*>(input); float8x4_t* vectorized_out = reinterpret_cast(out); - int num_vec_elems = num_elems >> 2; + int const num_vec_elems = num_elems >> 2; #pragma unroll 4 - for (int i = tid; i < num_vec_elems; i += blockDim.x * gridDim.x) { + for (int i = tid; i < num_vec_elems; i += step) { vec4_t in_vec = vectorized_in[i]; float8x4_t out_vec; @@ -118,17 +143,74 @@ __global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out, } // Handle the remaining elements if num_elems is not divisible by 4 - for (int i = num_vec_elems * 4 + tid; i < num_elems; - i += blockDim.x * gridDim.x) { + for (int i = num_vec_elems * 4 + tid; i < num_elems; i += step) { out[i] = scaled_fp8_conversion(input[i], inverted_scale); } } +template +__global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out, + const scalar_t* __restrict__ input, + const float* __restrict__ scale, + int64_t num_elems) { + int tid = blockDim.x * blockIdx.x + threadIdx.x; + + // Invert the scale so that we can use multiplications to avoid expensive + // division. + const float inverted_scale = 1.0f / (*scale); + + scaled_fp8_conversion_vec(out, input, inverted_scale, num_elems, tid, + blockDim.x * gridDim.x); +} + +template +__global__ void dynamic_per_token_scaled_fp8_quant_kernel( + c10::Float8_e4m3fn* __restrict__ out, float* __restrict__ scale, + scalar_t const* __restrict__ input, const int hidden_size) { + int const tid = threadIdx.x; + int const token_idx = blockIdx.x; + + scalar_t const* __restrict__ token_input = &input[token_idx * hidden_size]; + c10::Float8_e4m3fn* __restrict__ token_output = &out[token_idx * hidden_size]; + + // For vectorization, token_input and token_output pointers need to be + // aligned at 8-byte and 4-byte addresses respectively. + bool const can_vectorize = hidden_size % 4 == 0; + + float absmax_val = 0.0f; + if (can_vectorize) { + absmax_val = thread_max_vec(token_input, hidden_size, tid, blockDim.x); + } else { + for (int i = tid; i < hidden_size; i += blockDim.x) { + float const x = static_cast(token_input[i]); + absmax_val = max(absmax_val, fabs(x)); + } + } + + float const block_absmax_val_maybe = blockReduceMax(absmax_val); + __shared__ float block_absmax_val; + if (tid == 0) { + block_absmax_val = block_absmax_val_maybe; + scale[token_idx] = block_absmax_val / FP8_E4M3_MAX; + } + __syncthreads(); + + float const inverted_scale = FP8_E4M3_MAX / block_absmax_val; + if (can_vectorize) { + scaled_fp8_conversion_vec(token_output, token_input, inverted_scale, + hidden_size, tid, blockDim.x); + } else { + for (int i = tid; i < hidden_size; i += blockDim.x) { + token_output[i] = scaled_fp8_conversion(token_input[i], inverted_scale); + } + } +} + } // namespace vllm -void static_scaled_fp8_quant(torch::Tensor& out, // [..., d] - torch::Tensor& input, // [..., d] - torch::Tensor& scale) // [1] +void static_scaled_fp8_quant(torch::Tensor& out, // [..., d] + torch::Tensor const& input, // [..., d] + torch::Tensor const& scale) // [1] { int64_t num_tokens = input.numel() / input.size(-1); int64_t num_elems = input.numel(); @@ -144,9 +226,9 @@ void static_scaled_fp8_quant(torch::Tensor& out, // [..., d] }); } -void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d] - torch::Tensor& input, // [..., d] - torch::Tensor& scale) // [1] +void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d] + torch::Tensor const& input, // [..., d] + torch::Tensor& scale) // [1] { int64_t num_tokens = input.numel() / input.size(-1); int64_t num_elems = input.numel(); @@ -163,3 +245,25 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d] scale.data_ptr(), num_elems); }); } + +void dynamic_per_token_scaled_fp8_quant(torch::Tensor& out, // [..., d] + torch::Tensor const& input, // [..., d] + torch::Tensor& scales) { + TORCH_CHECK(input.is_contiguous()); + TORCH_CHECK(out.is_contiguous()); + + int const hidden_size = input.size(-1); + int const num_tokens = input.numel() / hidden_size; + dim3 const grid(num_tokens); + dim3 const block(std::min(hidden_size, 1024)); + + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "dynamic_per_token_scaled_fp8_quant_kernel", [&] { + vllm::dynamic_per_token_scaled_fp8_quant_kernel + <<>>( + out.data_ptr(), scales.data_ptr(), + input.data_ptr(), hidden_size); + }); +} diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index ff9875e0e17a3..55ccc6f53b455 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -179,12 +179,20 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "static_scaled_fp8_quant(Tensor! out, Tensor input, Tensor scale) -> ()"); ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant); - // Compute FP8 quantized tensor and scaling factor. + // Compute dynamic-per-tensor FP8 quantized tensor and scaling factor. ops.def( "dynamic_scaled_fp8_quant(Tensor! out, Tensor input, Tensor! scale) -> " "()"); ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant); + // Compute dynamic-per-token FP8 quantized tensor and scaling factor. + ops.def( + "dynamic_per_token_scaled_fp8_quant(Tensor! out, Tensor input, Tensor! " + "scale) -> " + "()"); + ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA, + &dynamic_per_token_scaled_fp8_quant); + // Aligning the number of tokens to be processed by each expert such // that it is divisible by the block size. ops.def( diff --git a/tests/kernels/quant_utils.py b/tests/kernels/quant_utils.py new file mode 100644 index 0000000000000..a1513bdffe768 --- /dev/null +++ b/tests/kernels/quant_utils.py @@ -0,0 +1,56 @@ +from typing import Tuple, Union + +import torch + + +def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor: + return torch.as_tensor(x, dtype=torch.float32, device='cuda') + +def ref_dynamic_per_token_quant(x: torch.tensor, + quant_dtype: torch.dtype) \ + -> Tuple[torch.tensor, torch.tensor]: + + assert quant_dtype in [torch.int8, torch.float8_e4m3fn] + qtype_traits = torch.iinfo(quant_dtype) if quant_dtype == torch.int8 \ + else torch.finfo(quant_dtype) + qtype_max = as_float32_tensor(qtype_traits.max) + + # For fp8, in order to match the cuda kernel output, we have to do exactly + # the same operations as in the corresponding fp8 kernel to prevent + # rounding errors. + + # Compute scales + x_token_max, _ = x.abs().max(dim=-1) + x_token_max = as_float32_tensor(x_token_max) + scales = (x_token_max / qtype_max)[:, None] + + # Quant + iscales = (qtype_max / x_token_max)[:, None] + torch_out = as_float32_tensor(x) * iscales + torch_out = torch_out.round() if quant_dtype == torch.int8 else torch_out + torch_out = torch_out.clamp(qtype_traits.min, + qtype_traits.max).to(quant_dtype) + + return torch_out, scales + + +# The int8 version is very similar. Incorporate the int8 version, like in +# ref_dynamic_per_token_quant, when we have a dynamic_per_tensor int8 quant +# kernel +def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \ + -> Tuple[torch.tensor, torch.tensor]: + + fp8_traits = torch.finfo(torch.float8_e4m3fn) + fp8_max = as_float32_tensor(fp8_traits.max) + one = as_float32_tensor(1.0) + + # For fp8, in order to match the cuda kernel output, we have to do exactly + # the same operations as in the corresponding fp8 kernel to prevent + # rounding errors. + + x_max = as_float32_tensor(x.abs().max()) + ref_scale = x_max / fp8_max + ref_iscale = one / ref_scale + ref_out = (as_float32_tensor(x) * ref_iscale).clamp( + fp8_traits.min, fp8_traits.max).to(dtype=torch.float8_e4m3fn) + return ref_out, ref_scale diff --git a/tests/kernels/test_fp8_quant.py b/tests/kernels/test_fp8_quant.py new file mode 100644 index 0000000000000..6b555c8e242ad --- /dev/null +++ b/tests/kernels/test_fp8_quant.py @@ -0,0 +1,54 @@ +import pytest +import torch + +import vllm._custom_ops as ops +from tests.kernels.quant_utils import (ref_dynamic_per_tensor_fp8_quant, + ref_dynamic_per_token_quant) + +DTYPES = [torch.half, torch.bfloat16, torch.float] +HIDDEN_SIZES = [1, 2, 3, 4, 16, 67, 768, 2048, 5120, 5137, 8192, + 8193] # Arbitrary values for testing +HIDDEN_SIZES += list(range(1024, 1033)) # vectorized conversion edge cases +NUM_TOKENS = [1, 7, 83, 4096] # Arbitrary values for testing +SEEDS = [0] + + +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@torch.inference_mode() +def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int, + dtype: torch.dtype, seed: int) -> None: + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + + x = torch.rand(num_tokens, hidden_size, dtype=dtype, + device="cuda") + 1e-6 # avoid nans + + ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.float8_e4m3fn) + ops_out, ops_scales = ops.dynamic_per_token_scaled_fp8_quant(x) + + assert torch.allclose(ref_scales, ops_scales) + assert torch.allclose(ref_out.to(dtype=torch.float32), + ops_out.to(dtype=torch.float32)) + + +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@torch.inference_mode() +def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int, + dtype: torch.dtype, seed: int) -> None: + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + + x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") + + ref_out, ref_scale = ref_dynamic_per_tensor_fp8_quant(x) + ops_out, ops_scale = ops.scaled_fp8_quant(x) + + assert torch.allclose(ref_scale, ops_scale) + assert torch.allclose(ref_out.to(dtype=torch.float32), + ops_out.to(dtype=torch.float32)) diff --git a/tests/kernels/test_int8_quant.py b/tests/kernels/test_int8_quant.py index 0daf7439468aa..03acbf7968ff1 100644 --- a/tests/kernels/test_int8_quant.py +++ b/tests/kernels/test_int8_quant.py @@ -3,6 +3,8 @@ # ruff: noqa: F401 import vllm._C +from tests.kernels.quant_utils import ref_dynamic_per_token_quant +from vllm._custom_ops import scaled_int8_quant DTYPES = [torch.half, torch.bfloat16, torch.float] HIDDEN_SIZES = [16, 67, 768, 2048, 5120, 5137, 8192, @@ -21,23 +23,16 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int, dtype: torch.dtype, seed: int) -> None: torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) - int8_traits = torch.iinfo(torch.int8) x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000 - x_token_max, _ = x.max(dim=1) - x_token_max = x_token_max.to(dtype=torch.float32) - scales = (x_token_max / float(127.0))[:, None].to(device="cuda", - dtype=torch.float32) - torch_out = (x / scales).round().clamp(int8_traits.min, - int8_traits.max).to(torch.int8) - - ops_out = torch.empty_like(x, dtype=torch.int8, device="cuda") - scales_out = torch.empty_like(scales, dtype=torch.float32, device="cuda") - torch.ops._C.dynamic_scaled_int8_quant(ops_out, x, scales_out) + # reference + ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.int8) + # kernel + ops_out, ops_scales = scaled_int8_quant(x) - assert torch.allclose(scales_out, scales) - assert torch.allclose(torch_out, ops_out, + assert torch.allclose(ops_scales, ref_scales) + assert torch.allclose(ops_out, ref_out, atol=1) # big atol to account for rounding errors @@ -55,12 +50,11 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int, int8_traits = torch.iinfo(torch.int8) x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000 + scale = torch.tensor([scale], dtype=torch.float32, device="cuda") out1 = (x / scale).round().clamp(int8_traits.min, int8_traits.max).to(torch.int8) - out2 = torch.empty_like(x, dtype=torch.int8) - scale_argument = torch.tensor([scale], dtype=torch.float32, device="cuda") + out2, _ = scaled_int8_quant(x, scale) - torch.ops._C.static_scaled_int8_quant(out2, x, scale_argument) assert torch.allclose(out1, out2, atol=1) # big atol to account for rounding errors diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 143957f7b65f0..07646ae582a28 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -335,6 +335,17 @@ def scaled_fp8_quant( return output, scale +def dynamic_per_token_scaled_fp8_quant( + input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + + output = torch.empty_like(input, dtype=torch.float8_e4m3fn) + scales = torch.empty((input.numel() // input.shape[-1], 1), + device=input.device, + dtype=torch.float32) + torch.ops._C.dynamic_per_token_scaled_fp8_quant(output, input, scales) + return output, scales + + # int8 def scaled_int8_quant( input: torch.Tensor, From b5af8c223c3d70557e7d73ba3c4c2e9d56fc9694 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Wed, 17 Jul 2024 19:26:04 -0700 Subject: [PATCH 03/13] [Model] Pipeline parallel support for Mixtral (#6516) --- tests/distributed/test_pipeline_parallel.py | 17 ++++-- vllm/config.py | 1 + vllm/model_executor/models/mixtral.py | 61 ++++++++++++++++----- 3 files changed, 60 insertions(+), 19 deletions(-) diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index 123a77e14ad74..d7e640ce96995 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -1,4 +1,5 @@ import pytest +from transformers import AutoTokenizer from ..utils import RemoteOpenAIServer @@ -12,6 +13,8 @@ (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B"), ]) def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME): + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + pp_args = [ # use half precision for speed and memory savings in CI environment "--dtype", @@ -34,7 +37,7 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME): "--dtype", "bfloat16", "--tensor-parallel-size", - str(max(TP_SIZE, 2)), # use at least TP_SIZE=2 to hold the model + str(max(TP_SIZE, 2)), # We only use 2 GPUs in the CI. "--distributed-executor-backend", "mp", ] @@ -45,8 +48,10 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME): pp_args.append("--enforce-eager") tp_args.append("--enforce-eager") + prompt = "Hello, my name is" + token_ids = tokenizer(prompt)["input_ids"] results = [] - for args in [pp_args, tp_args]: + for args in (pp_args, tp_args): with RemoteOpenAIServer(MODEL_NAME, args) as server: client = server.get_client() @@ -62,7 +67,7 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME): # test with text prompt completion = client.completions.create(model=MODEL_NAME, - prompt="Hello, my name is", + prompt=prompt, max_tokens=5, temperature=0.0) @@ -76,7 +81,7 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME): # test using token IDs completion = client.completions.create( model=MODEL_NAME, - prompt=[0, 0, 0, 0, 0], + prompt=token_ids, max_tokens=5, temperature=0.0, ) @@ -91,7 +96,7 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME): # test simple list batch = client.completions.create( model=MODEL_NAME, - prompt=["Hello, my name is", "Hello, my name is"], + prompt=[prompt, prompt], max_tokens=5, temperature=0.0, ) @@ -105,7 +110,7 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME): # test streaming batch = client.completions.create( model=MODEL_NAME, - prompt=["Hello, my name is", "Hello, my name is"], + prompt=[prompt, prompt], max_tokens=5, temperature=0.0, stream=True, diff --git a/vllm/config.py b/vllm/config.py index de7bb3943a45f..a20e830955671 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -34,6 +34,7 @@ "MistralForCausalLM", "Phi3ForCausalLM", "GPT2LMHeadModel", + "MixtralForCausalLM", ] diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index e739df87cf96a..28dbcb30bdf55 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -29,7 +29,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, LoRAConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (QKVParallelLinear, @@ -48,6 +48,7 @@ from vllm.sequence import IntermediateTensors, SamplerOutput from .interfaces import SupportsLoRA +from .utils import is_pp_missing_parameter, make_layers class MixtralMoE(nn.Module): @@ -255,12 +256,11 @@ def __init__( config.hidden_size, org_num_embeddings=config.vocab_size, ) - self.layers = nn.ModuleList([ - MixtralDecoderLayer(config, - cache_config, - quant_config=quant_config) - for _ in range(config.num_hidden_layers) - ]) + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, lambda: MixtralDecoderLayer( + config, cache_config, quant_config=quant_config)) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( @@ -269,14 +269,25 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) - residual = None - for i in range(len(self.layers)): + if get_pp_group().is_first_rank: + hidden_states = self.embed_tokens(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for i in range(self.start_layer, self.end_layer): layer = self.layers[i] hidden_states, residual = layer(positions, hidden_states, - kv_caches[i], attn_metadata, - residual) + kv_caches[i - self.start_layer], + attn_metadata, residual) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states @@ -347,7 +358,7 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, intermediate_tensors) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, @@ -356,6 +367,20 @@ def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata) return logits + def make_empty_intermediate_tensors( + self, batch_size: int, dtype: torch.dtype, + device: torch.device) -> IntermediateTensors: + return IntermediateTensors({ + "hidden_states": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + "residual": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + }) + def sample( self, logits: Optional[torch.Tensor], @@ -392,6 +417,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -402,6 +431,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if weight_name not in name: continue name = name.replace(weight_name, param_name) + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, @@ -414,6 +446,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue # Remapping the name of FP8 kv-scale. name = maybe_remap_kv_scale_name(name, params_dict) if name is None: From 18fecc3559cad615fd8565d1a001557ec355f285 Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> Date: Wed, 17 Jul 2024 23:18:13 -0400 Subject: [PATCH 04/13] [ Kernel ] Fp8 Channelwise Weight Support (#6487) --- vllm/config.py | 3 +- .../compressed_tensors/compressed_tensors.py | 18 +++-- .../schemes/compressed_tensors_w8a8_fp8.py | 80 ++++++++++++------- .../quantization/compressed_tensors/utils.py | 10 +++ 4 files changed, 76 insertions(+), 35 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index a20e830955671..c87974d0df16d 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -238,7 +238,8 @@ def _verify_quantization(self) -> None: f"{self.quantization} quantization is currently not " f"supported in ROCm.") if (self.quantization - not in ("fp8", "marlin", "gptq_marlin_24", "gptq_marlin")): + not in ("fp8", "marlin", "gptq_marlin_24", "gptq_marlin", + "compressed_tensors")): logger.warning( "%s quantization is not fully " "optimized yet. The speed can be slower than " diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 524b4c894b9b5..1424c620ae675 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -13,7 +13,8 @@ CompressedTensorsWNA16) from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( CompressionFormat, QuantizationArgs, QuantizationStrategy, - QuantizationType, find_first_name_or_class_match) + QuantizationType, find_first_name_or_class_match, + is_activation_quantization_format) from vllm.platforms import current_platform @@ -132,10 +133,11 @@ def _is_fp8_w8a8(self, weight_quant: BaseModel, # Confirm weight scheme is supported. is_symmetric_weight = weight_quant.symmetric is_static_weight = not weight_quant.dynamic - is_per_tensor_weight = ( - weight_quant.strategy == QuantizationStrategy.TENSOR) + is_per_tensor_or_channel_weight = (weight_quant.strategy in [ + QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL + ]) if not (is_symmetric_weight and is_static_weight - and is_per_tensor_weight): + and is_per_tensor_or_channel_weight): return False # Dynamic quantization is always supported if weights supported. @@ -167,6 +169,7 @@ def _is_wNa16_group_channel(self, weight_quant: BaseModel, def _get_schema(self, weight_quant: BaseModel, input_quant: BaseModel) -> "CompressedTensorsScheme": + # Detect If Mixed Precision if self._is_wNa16_group_channel(weight_quant, input_quant): self._check_gptq_and_marlin_can_run() if (self.quant_format == CompressionFormat.marlin_24.value @@ -182,11 +185,12 @@ def _get_schema(self, weight_quant: BaseModel, strategy=weight_quant.strategy, group_size=weight_quant.group_size) - if (self.quant_format == CompressionFormat.int_quantized.value or - self.quant_format == CompressionFormat.float_quantized.value): + # Detect If Activation Quantization. + if is_activation_quantization_format(self.quant_format): if self._is_fp8_w8a8(weight_quant, input_quant): return CompressedTensorsW8A8Fp8( - input_dynamic=input_quant.dynamic) + strategy=weight_quant.strategy, + is_static_input_scheme=(not input_quant.dynamic)) if self._is_static_tensor_w8a8(weight_quant, input_quant): return CompressedTensorsW8A8Int8( diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index b93425fb2d629..f1ca9510d92aa 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -1,11 +1,15 @@ from typing import Callable, List, Optional import torch +from torch.nn import Parameter from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme) +from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( + QuantizationStrategy) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - apply_fp8_linear, create_per_tensor_scale_param, cutlass_fp8_supported, + apply_fp8_linear, create_per_channel_scale_param, + create_per_tensor_scale_param, cutlass_fp8_supported, requantize_with_max_scale) from vllm.model_executor.utils import set_weight_attrs @@ -14,39 +18,56 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): - def __init__(self, input_dynamic: bool): - self.input_dynamic = input_dynamic + def __init__(self, strategy: str, is_static_input_scheme: bool): + self.strategy = strategy + self.is_static_input_scheme = is_static_input_scheme self.cutlass_fp8_supported = cutlass_fp8_supported() - # W8A8-Fp8 kernels support only per-tensor and per-channel cases. - # So if we have a fused module (QKV, MLP) with per tensor scales (thus N - # scales being passed to the kernel), we requantize with a single scale. + # On Lovelace, fail for now if channelwise. + # TODO: (@tms) fallback + if (not self.cutlass_fp8_supported + and self.strategy == QuantizationStrategy.CHANNEL): + raise ValueError( + "Channelwise fp8 quantization requires vLLM's custom " + "cutlass kernels, which are not supported on your device." + "Consider quantizing with per tensor scales or upgrading " + "to Hopper.") + def process_weights_after_loading(self, layer) -> None: - # Dequant -> Quant with max scale. - max_w_scale, weight = requantize_with_max_scale( - weight=layer.weight, - weight_scale=layer.weight_scale, - logical_widths=layer.logical_widths, - ) - - # Update layer with new values. - layer.weight = torch.nn.Parameter(weight.t(), requires_grad=False) - layer.weight_scale = torch.nn.Parameter(max_w_scale, - requires_grad=False) - if self.input_dynamic: - layer.input_scale = None + # If per tensor, when we have a fused module (e.g. QKV) with per + # tensor scales (thus N scales being passed to the kernel), + # requantize so we can always run per tensor + if self.strategy == QuantizationStrategy.TENSOR: + max_w_scale, weight = requantize_with_max_scale( + weight=layer.weight, + weight_scale=layer.weight_scale, + logical_widths=layer.logical_widths, + ) + + layer.weight = Parameter(weight.t(), requires_grad=False) + layer.weight_scale = Parameter(max_w_scale, requires_grad=False) + + # If channelwise, scales are already lined up, so just transpose. + elif self.strategy == QuantizationStrategy.CHANNEL: + assert self.cutlass_fp8_supported + weight = layer.weight + layer.weight = Parameter(weight.t(), requires_grad=False) + + else: + raise ValueError(f"Unknown quantization strategy {self.strategy}") + + # INPUT SCALE + if self.is_static_input_scheme: + layer.input_scale = Parameter(layer.input_scale.max(), + requires_grad=False) else: - layer.input_scale = torch.nn.Parameter(layer.input_scale.max(), - requires_grad=False) + layer.input_scale = None def create_weights(self, layer: torch.nn.Module, output_partition_sizes: List[int], input_size_per_partition: int, params_dtype: torch.dtype, weight_loader: Callable, **kwargs): - - del params_dtype - output_size_per_partition = sum(output_partition_sizes) layer.logical_widths = output_partition_sizes @@ -63,12 +84,17 @@ def create_weights(self, layer: torch.nn.Module, }) # WEIGHT SCALE - weight_scale = create_per_tensor_scale_param( - output_partition_sizes, weight_loader=weight_loader) + if self.strategy == QuantizationStrategy.CHANNEL: + weight_scale = create_per_channel_scale_param( + output_partition_sizes, weight_loader=weight_loader) + else: + assert self.strategy == QuantizationStrategy.TENSOR + weight_scale = create_per_tensor_scale_param( + output_partition_sizes, weight_loader=weight_loader) layer.register_parameter("weight_scale", weight_scale) # INPUT SCALE - if not self.input_dynamic: + if self.is_static_input_scheme: input_scale = create_per_tensor_scale_param( output_partition_sizes, weight_loader=weight_loader) layer.register_parameter("input_scale", input_scale) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py index 5b44c215535b5..25db308753eee 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py @@ -9,6 +9,7 @@ class CompressionFormat(Enum): dense = "dense" sparse_bitmask = "sparse-bitmask" + naive_quantized = "naive-quantized" float_quantized = "float-quantized" int_quantized = "int-quantized" pack_quantized = "pack-quantized" @@ -76,6 +77,15 @@ class QuantizationArgs(BaseModel): ) +def is_activation_quantization_format(format: str) -> bool: + _ACTIVATION_QUANTIZATION_FORMATS = [ + CompressionFormat.naive_quantized.value, + CompressionFormat.int_quantized.value, + CompressionFormat.float_quantized.value + ] + return format in _ACTIVATION_QUANTIZATION_FORMATS + + def find_first_name_or_class_match( name: str, module: Module, From 1c27d25fb57285d2c5bb6d17d70549ab4b8f45a7 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 17 Jul 2024 20:54:35 -0700 Subject: [PATCH 05/13] [core][model] yet another cpu offload implementation (#6496) Co-authored-by: Michael Goin --- .buildkite/test-pipeline.yaml | 1 + examples/cpu_offload.py | 22 +++++++++ vllm/config.py | 2 + vllm/engine/arg_utils.py | 24 +++++++++- vllm/entrypoints/llm.py | 6 +++ vllm/model_executor/models/utils.py | 73 +++++++++++++++++++++++++++-- vllm/worker/model_runner.py | 4 ++ 7 files changed, 128 insertions(+), 4 deletions(-) create mode 100644 examples/cpu_offload.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 445d74d6d9bbe..00fa86b4c448f 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -140,6 +140,7 @@ steps: # install tensorizer for tensorize_vllm_model.py - pip install awscli tensorizer - python3 offline_inference.py + - python3 cpu_offload.py - python3 offline_inference_with_prefix.py - python3 llm_engine_example.py - python3 llava_example.py diff --git a/examples/cpu_offload.py b/examples/cpu_offload.py new file mode 100644 index 0000000000000..b152e5bc37e6d --- /dev/null +++ b/examples/cpu_offload.py @@ -0,0 +1,22 @@ +from vllm import LLM, SamplingParams + +# Sample prompts. +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] +# Create a sampling params object. +sampling_params = SamplingParams(temperature=0.8, top_p=0.95) + +# Create an LLM. +llm = LLM(model="meta-llama/Llama-2-13b-chat-hf", cpu_offload_gb=10) +# Generate texts from the prompts. The output is a list of RequestOutput objects +# that contain the prompt, generated text, and other information. +outputs = llm.generate(prompts, sampling_params) +# Print the outputs. +for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") diff --git a/vllm/config.py b/vllm/config.py index c87974d0df16d..419118375e704 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -433,6 +433,7 @@ def __init__( num_gpu_blocks_override: Optional[int] = None, sliding_window: Optional[int] = None, enable_prefix_caching: bool = False, + cpu_offload_gb: float = 0, ) -> None: self.block_size = block_size self.gpu_memory_utilization = gpu_memory_utilization @@ -441,6 +442,7 @@ def __init__( self.cache_dtype = cache_dtype self.sliding_window = sliding_window self.enable_prefix_caching = enable_prefix_caching + self.cpu_offload_gb = cpu_offload_gb self._verify_args() self._verify_cache_dtype() self._verify_prefix_caching() diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index b972573c0258e..28ae3448fb495 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -45,6 +45,7 @@ class EngineArgs: disable_sliding_window: bool = False use_v2_block_manager: bool = False swap_space: int = 4 # GiB + cpu_offload_gb: int = 0 # GiB gpu_memory_utilization: float = 0.90 max_num_batched_tokens: Optional[int] = None max_num_seqs: int = 256 @@ -303,6 +304,20 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: type=int, default=EngineArgs.swap_space, help='CPU swap space size (GiB) per GPU.') + parser.add_argument( + '--cpu-offload-gb', + type=float, + default=0, + help='The space in GiB to offload to CPU, per GPU. ' + 'Default is 0, which means no offloading. Intuitively, ' + 'this argument can be seen as a virtual way to increase ' + 'the GPU memory size. For example, if you have one 24 GB ' + 'GPU and set this to 10, virtually you can think of it as ' + 'a 34 GB GPU. Then you can load a 13B model with BF16 weight,' + 'which requires at least 26GB GPU memory. Note that this ' + 'requires fast CPU-GPU interconnect, as part of the model is' + 'loaded from CPU memory to GPU memory on the fly in each ' + 'model forward pass.') parser.add_argument( '--gpu-memory-utilization', type=float, @@ -633,6 +648,11 @@ def create_engine_config(self, ) -> EngineConfig: raise ValueError( "BitsAndBytes load format and QLoRA adapter only support " f"'bitsandbytes' quantization, but got {self.quantization}") + + assert self.cpu_offload_gb >= 0, ( + "CPU offload space must be non-negative" + f", but got {self.cpu_offload_gb}") + multimodal_config = MultiModalConfig() device_config = DeviceConfig(device=self.device) @@ -666,7 +686,9 @@ def create_engine_config(self, ) -> EngineConfig: cache_dtype=self.kv_cache_dtype, num_gpu_blocks_override=self.num_gpu_blocks_override, sliding_window=model_config.get_sliding_window(), - enable_prefix_caching=self.enable_prefix_caching) + enable_prefix_caching=self.enable_prefix_caching, + cpu_offload_gb=self.cpu_offload_gb, + ) parallel_config = ParallelConfig( pipeline_parallel_size=self.pipeline_parallel_size, tensor_parallel_size=self.tensor_parallel_size, diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 57e81a6317725..cadaffa0e30cf 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -69,6 +69,10 @@ class LLM: when their `best_of` sampling parameters are larger than 1. If all requests will have `best_of=1`, you can safely set this to 0. Otherwise, too small values may cause out-of-memory (OOM) errors. + cpu_offload_gb: The size (GiB) of CPU memory to use for offloading + the model weights. This virtually increases the GPU memory space + you can use to hold the model weights, at the cost of CPU-GPU data + transfer for every forward pass. enforce_eager: Whether to enforce eager execution. If True, we will disable CUDA graph and always execute the model in eager mode. If False, we will use CUDA graph and eager execution in hybrid. @@ -114,6 +118,7 @@ def __init__( seed: int = 0, gpu_memory_utilization: float = 0.9, swap_space: int = 4, + cpu_offload_gb: float = 0, enforce_eager: bool = False, max_context_len_to_capture: Optional[int] = None, max_seq_len_to_capture: int = 8192, @@ -141,6 +146,7 @@ def __init__( seed=seed, gpu_memory_utilization=gpu_memory_utilization, swap_space=swap_space, + cpu_offload_gb=cpu_offload_gb, enforce_eager=enforce_eager, max_context_len_to_capture=max_context_len_to_capture, max_seq_len_to_capture=max_seq_len_to_capture, diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index c135b20352203..b505d32db5985 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -1,8 +1,10 @@ from typing import Callable, Dict, List, Tuple import torch +from torch.func import functional_call from vllm.multimodal import BatchedTensors +from vllm.utils import is_pin_memory_available def merge_vision_embeddings(input_ids: torch.Tensor, @@ -52,6 +54,70 @@ def __init__(self, *args, **kwargs): super().__init__() +_CPU_OFFLOAD_BYTES = 0 +_CPU_OFFLOAD_MAX_BYTES = 0 + + +def set_cpu_offload_max_bytes(max_bytes: int) -> None: + global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES + _CPU_OFFLOAD_BYTES = 0 + _CPU_OFFLOAD_MAX_BYTES = max_bytes + + +def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module: + device = next(module.parameters()).device + + if device == torch.device("cpu"): + return module + + global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES + if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES: + return module + + pin_memory = is_pin_memory_available() + + # offload parameters to CPU + # use pin_memory if possible, which helps cudagraph capture speed + for p in module.parameters(): + if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES: + # we use per-parameter offloading + # one module might have some parameters offloaded and some not + break + + # `torch.empty_like` does not support `pin_memory` argument + cpu_data = torch.empty(size=p.data.size(), + dtype=p.data.dtype, + layout=p.data.layout, + device='cpu', + pin_memory=pin_memory) + cpu_data.copy_(p.data) + p.data = cpu_data + _CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size() + + state_dict: Dict[str, torch.Tensor] = module.state_dict() + + original_forward = module.forward + + def forward(*args, **kwargs): + module.forward = original_forward + device_state = { + # here we blindly call `to(device)` + # if the parameter is already on the device, it will be a no-op + k: v.to(device, non_blocking=True) + for k, v in state_dict.items() + } + output = functional_call(module, + device_state, + args=args, + kwargs=kwargs) + module.forward = forward + return output + + module.forward = forward + + return module + + def make_layers( num_hidden_layers: int, layer_fn: Callable[[], torch.nn.Module] ) -> Tuple[int, int, torch.nn.ModuleList]: @@ -64,9 +130,10 @@ def make_layers( get_pp_group().rank_in_group, get_pp_group().world_size) modules = torch.nn.ModuleList( - [PPMissingLayer() for _ in range(start_layer)] + - [layer_fn() for _ in range(start_layer, end_layer)] + - [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)]) + [PPMissingLayer() for _ in range(start_layer)] + [ + maybe_offload_to_cpu(layer_fn()) + for _ in range(start_layer, end_layer) + ] + [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)]) return start_layer, end_layer, modules diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 75a2607d0d9c4..d810443665024 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -39,6 +39,7 @@ from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.model_executor.models.interfaces import (supports_lora, supports_vision) +from vllm.model_executor.models.utils import set_cpu_offload_max_bytes from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors, MultiModalInputs) from vllm.prompt_adapter.layers import PromptAdapterMapping @@ -544,6 +545,9 @@ def __init__( self.flashinfer_prefill_workspace_buffer = None self.flashinfer_prefill_wrapper = None + set_cpu_offload_max_bytes( + int(self.cache_config.cpu_offload_gb * 1024**3)) + def load_model(self) -> None: with CudaMemoryProfiler() as m: self.model = get_model(model_config=self.model_config, From d25877dd9b7a09d5f9552d9d185a86dc6497bc2f Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Wed, 17 Jul 2024 22:24:43 -0700 Subject: [PATCH 06/13] [BugFix] Avoid secondary error in ShmRingBuffer destructor (#6530) --- vllm/distributed/device_communicators/shm_broadcast.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/distributed/device_communicators/shm_broadcast.py b/vllm/distributed/device_communicators/shm_broadcast.py index 151b08c1b996c..bfea106bc027d 100644 --- a/vllm/distributed/device_communicators/shm_broadcast.py +++ b/vllm/distributed/device_communicators/shm_broadcast.py @@ -119,9 +119,10 @@ def __reduce__(self): ) def __del__(self): - self.shared_memory.close() - if self.is_creator: - self.shared_memory.unlink() + if hasattr(self, "shared_memory"): + self.shared_memory.close() + if self.is_creator: + self.shared_memory.unlink() @contextmanager def get_data(self, current_idx: int): @@ -428,7 +429,6 @@ def enqueue(self, obj): def dequeue(self): if self._is_local_reader: - overflow = False with self.acquire_read() as buf: overflow = buf[0] == 1 if not overflow: From 61e592747c28c9fbd6861e48b825c796e09da02f Mon Sep 17 00:00:00 2001 From: Rui Qiao <161574667+ruisearch42@users.noreply.github.com> Date: Wed, 17 Jul 2024 22:27:09 -0700 Subject: [PATCH 07/13] [Core] Introduce SPMD worker execution using Ray accelerated DAG (#6032) Signed-off-by: Rui Qiao Co-authored-by: Stephanie Wang --- .buildkite/test-pipeline.yaml | 3 + vllm/engine/llm_engine.py | 5 + vllm/envs.py | 8 + vllm/executor/distributed_gpu_executor.py | 8 +- vllm/executor/ray_gpu_executor.py | 212 ++++++++++++++-------- vllm/executor/ray_utils.py | 14 +- vllm/executor/ray_xpu_executor.py | 58 +++--- vllm/worker/worker_base.py | 31 +++- 8 files changed, 218 insertions(+), 121 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 00fa86b4c448f..59b683437987c 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -84,6 +84,8 @@ steps: - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py + - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray VLLM_USE_RAY_SPMD_WORKER=1 VLLM_USE_RAY_COMPILED_DAG=1 pytest -v -s distributed/test_basic_distributed_correctness.py + - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray VLLM_USE_RAY_SPMD_WORKER=1 VLLM_USE_RAY_COMPILED_DAG=1 pytest -v -s distributed/test_basic_distributed_correctness.py - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py - TEST_DIST_MODEL=llava-hf/llava-1.5-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_multimodal_broadcast.py @@ -108,6 +110,7 @@ steps: # We want to test that models which use 2 GPUs work with 4 GPUs, which is why we duplicate them here. # See https://github.com/vllm-project/vllm/pull/5473#issuecomment-2166601837 for context. - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py + - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray VLLM_USE_RAY_SPMD_WORKER=1 VLLM_USE_RAY_COMPILED_DAG=1 pytest -v -s distributed/test_basic_distributed_correctness.py - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py - pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 68ca9a97a3c61..77539eab0db23 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -6,6 +6,7 @@ from transformers import PreTrainedTokenizer +import vllm.envs as envs from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ObservabilityConfig, ParallelConfig, @@ -414,6 +415,9 @@ def from_engine_args( elif distributed_executor_backend == "mp": from vllm.executor.multiproc_gpu_executor import ( MultiprocessingGPUExecutor) + assert not envs.VLLM_USE_RAY_SPMD_WORKER, ( + "multiprocessing distributed executor backend does not " + "support VLLM_USE_RAY_SPMD_WORKER=1") executor_class = MultiprocessingGPUExecutor else: from vllm.executor.gpu_executor import GPUExecutor @@ -426,6 +430,7 @@ def from_engine_args( usage_context=usage_context, stat_loggers=stat_loggers, ) + return engine def __reduce__(self): diff --git a/vllm/envs.py b/vllm/envs.py index f3b6d2788d392..595992e51db87 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -34,6 +34,7 @@ VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS: bool = False VLLM_XLA_CACHE_PATH: str = os.path.join(VLLM_CACHE_ROOT, "xla_cache") VLLM_FUSED_MOE_CHUNK_SIZE: int = 64 * 1024 + VLLM_USE_RAY_SPMD_WORKER: bool = False VLLM_USE_RAY_COMPILED_DAG: bool = False VLLM_WORKER_MULTIPROC_METHOD: str = "fork" VLLM_ASSETS_CACHE: str = os.path.join(VLLM_CACHE_ROOT, "assets") @@ -261,6 +262,13 @@ def get_default_config_root(): "VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS": lambda: bool(os.getenv("VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS", False)), + # If the env var is set, then all workers will execute as separate + # processes from the engine, and we use the same mechanism to trigger + # execution on all workers. + # Run vLLM with VLLM_USE_RAY_SPMD_WORKER=1 to enable it. + "VLLM_USE_RAY_SPMD_WORKER": + lambda: bool(os.getenv("VLLM_USE_RAY_SPMD_WORKER", 0)), + # If the env var is set, it uses the Ray's compiled DAG API # which optimizes the control plane overhead. # Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it. diff --git a/vllm/executor/distributed_gpu_executor.py b/vllm/executor/distributed_gpu_executor.py index 3db82eb1fe790..4df54a09e5e8c 100644 --- a/vllm/executor/distributed_gpu_executor.py +++ b/vllm/executor/distributed_gpu_executor.py @@ -64,8 +64,8 @@ def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks=num_cpu_blocks) def execute_model( - self, execute_model_req: ExecuteModelRequest - ) -> Optional[List[SamplerOutput]]: + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: if self.parallel_worker_tasks is None: self.parallel_worker_tasks = self._run_workers( "start_worker_execution_loop", @@ -73,7 +73,9 @@ def execute_model( **self.extra_execute_model_run_workers_kwargs) # Only the driver worker returns the sampling results. - return self._driver_execute_model(execute_model_req) + driver_outputs = self._driver_execute_model(execute_model_req) + assert driver_outputs is not None + return driver_outputs def stop_remote_worker_execution_loop(self) -> None: if self.parallel_worker_tasks is None: diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index edff9b6c93e09..92899ba5b0217 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -1,6 +1,5 @@ import asyncio import os -import pickle from collections import defaultdict from itertools import islice, repeat from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple @@ -23,12 +22,30 @@ logger = init_logger(__name__) -USE_RAY_COMPILED_DAG = envs.VLLM_USE_RAY_COMPILED_DAG - class RayGPUExecutor(DistributedGPUExecutor): def _init_executor(self) -> None: + # If the env var is set, it uses the Ray's compiled DAG API + # which optimizes the control plane overhead. + # Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it. + # Currently, this requires USE_RAY_SPMD_WORKER=True. + self.use_ray_compiled_dag = envs.VLLM_USE_RAY_COMPILED_DAG + # If the env var is set, then we do not distinguish between the + # "driver worker" vs other workers. Also, the rank 0 worker will + # be executed in a remote Ray worker. Currently this requires + # USE_RAY_COMPILED_DAG=True. + self.use_ray_spmd_worker = envs.VLLM_USE_RAY_SPMD_WORKER + if self.use_ray_compiled_dag: + assert self.use_ray_spmd_worker, ( + "VLLM_USE_RAY_COMPILED_DAG=1 requires " + "VLLM_USE_RAY_SPMD_WORKER=1") + if self.use_ray_spmd_worker: + # TODO: Support SPMD worker for non-DAG Ray executor. + assert self.use_ray_compiled_dag, ( + "VLLM_USE_RAY_SPMD_WORKER=1 requires " + "VLLM_USE_RAY_COMPILED_DAG=1") + assert self.parallel_config.distributed_executor_backend == "ray" placement_group = self.parallel_config.placement_group @@ -40,11 +57,7 @@ def _init_executor(self) -> None: # Create the parallel GPU workers. self._init_workers_ray(placement_group) - self.forward_dag = None - if USE_RAY_COMPILED_DAG: - self.forward_dag = self._compiled_ray_dag() - self.extra_execute_model_run_workers_kwargs[ - "use_ray_compiled_dag"] = True + self.forward_dag: Optional["ray.dag.CompiledDAG"] = None def _configure_ray_workers_use_nsight(self, ray_remote_kwargs) -> Dict[str, Any]: @@ -110,21 +123,24 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", trust_remote_code=self.model_config.trust_remote_code, ) - worker_ip = ray.get(worker.get_node_ip.remote()) - if worker_ip == driver_ip and self.driver_dummy_worker is None: - # If the worker is on the same node as the driver, we use it - # as the resource holder for the driver process. - self.driver_dummy_worker = worker - self.driver_worker = RayWorkerWrapper( - worker_module_name=worker_module_name, - worker_class_name=worker_class_name, - trust_remote_code=self.model_config.trust_remote_code, - ) - else: - # Else, added to the list of workers. + if self.use_ray_spmd_worker: self.workers.append(worker) - - if self.driver_dummy_worker is None: + else: + worker_ip = ray.get(worker.get_node_ip.remote()) + if worker_ip == driver_ip and self.driver_dummy_worker is None: + # If the worker is on the same node as the driver, we use it + # as the resource holder for the driver process. + self.driver_dummy_worker = worker + self.driver_worker = RayWorkerWrapper( + worker_module_name=worker_module_name, + worker_class_name=worker_class_name, + trust_remote_code=self.model_config.trust_remote_code, + ) + else: + # Else, added to the list of workers. + self.workers.append(worker) + + if not self.use_ray_spmd_worker and self.driver_dummy_worker is None: raise ValueError( "Ray does not allocate any GPUs on the driver node. Consider " "adjusting the Ray placement group or running the driver on a " @@ -254,9 +270,23 @@ def _driver_execute_model( Passing None will cause the driver to stop the model execution loop running in each of the remote workers. """ + assert not self.use_ray_spmd_worker, ( + "driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1") return self.driver_worker.execute_method("execute_model", execute_model_req) + def execute_model( + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + if not self.use_ray_spmd_worker: + return super().execute_model(execute_model_req) + + if self.forward_dag is None: + self.forward_dag = self._compiled_ray_dag(enable_asyncio=False) + + outputs = ray.get(self.forward_dag.execute(execute_model_req)) + return outputs[0] + def _run_workers( self, method: str, @@ -266,7 +296,6 @@ def _run_workers( all_kwargs: Optional[List[Dict[str, Any]]] = None, use_dummy_driver: bool = False, max_concurrent_workers: Optional[int] = None, - use_ray_compiled_dag: bool = False, **kwargs, ) -> Any: """Runs the given method on all workers. Can be used in the following @@ -281,6 +310,10 @@ def _run_workers( - all_args/all_kwargs: args/kwargs for each worker are specified individually """ + if self.use_ray_spmd_worker: + assert not async_run_tensor_parallel_workers_only, ( + "async_run_tensor_parallel_workers_only is not supported for " + "spmd mode.") if max_concurrent_workers: raise NotImplementedError( @@ -289,71 +322,69 @@ def _run_workers( count = len(self.workers) if not \ async_run_tensor_parallel_workers_only \ else len(self.non_driver_workers) + # If using SPMD worker, all workers are the same, so we should execute + # the args on all workers. Otherwise, we skip the first worker's args + # because those args will go to the driver worker. + first_worker_args_index: int = 0 if self.use_ray_spmd_worker else 1 all_worker_args = repeat(args, count) if all_args is None \ - else islice(all_args, 1, None) + else islice(all_args, first_worker_args_index, None) all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \ - else islice(all_kwargs, 1, None) - - if use_ray_compiled_dag: - # Right now, compiled DAG can only accept a single - # input. TODO(sang): Fix it. - assert self.forward_dag is not None - output_channels = self.forward_dag.execute(1) - ray_worker_outputs = [] - else: - # Start the ray workers first. - ray_workers = self.workers - if async_run_tensor_parallel_workers_only: - ray_workers = self.non_driver_workers - ray_worker_outputs = [ - worker.execute_method.remote(method, *worker_args, - **worker_kwargs) - for (worker, worker_args, worker_kwargs - ) in zip(ray_workers, all_worker_args, all_worker_kwargs) - ] + else islice(all_kwargs, first_worker_args_index, None) + + # Start the ray workers first. + ray_workers = self.workers + if async_run_tensor_parallel_workers_only: + ray_workers = self.non_driver_workers + ray_worker_outputs = [ + worker.execute_method.remote(method, *worker_args, **worker_kwargs) + for (worker, worker_args, worker_kwargs + ) in zip(ray_workers, all_worker_args, all_worker_kwargs) + ] if async_run_tensor_parallel_workers_only: # Just return futures return ray_worker_outputs - driver_args = args if all_args is None else all_args[0] - driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0] + driver_worker_output = [] + # In SPMD mode, the driver worker is the same as any other worker, + # so we only explicitly execute on the driver worker if using a + # non-SPMD worker class. + if not self.use_ray_spmd_worker: + driver_args = args if all_args is None else all_args[0] + driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0] + + # Start the driver worker after all the ray workers. + if not use_dummy_driver: + driver_worker_output = [ + self.driver_worker.execute_method(method, *driver_args, + **driver_kwargs) + ] + else: + assert self.driver_dummy_worker is not None + driver_worker_output = [ + ray.get( + self.driver_dummy_worker.execute_method.remote( + method, *driver_args, **driver_kwargs)) + ] - # Start the driver worker after all the ray workers. - if not use_dummy_driver: - driver_worker_output = self.driver_worker.execute_method( - method, *driver_args, **driver_kwargs) - else: - assert self.driver_dummy_worker is not None - driver_worker_output = ray.get( - self.driver_dummy_worker.execute_method.remote( - method, *driver_args, **driver_kwargs)) # Get the results of the ray workers. if self.workers: - if use_ray_compiled_dag: - try: - ray_worker_outputs = [ - pickle.loads(chan.begin_read()) - for chan in output_channels - ] - finally: - # Has to call end_read in order to reuse the DAG. - for chan in output_channels: - chan.end_read() - else: - ray_worker_outputs = ray.get(ray_worker_outputs) + ray_worker_outputs = ray.get(ray_worker_outputs) - return [driver_worker_output] + ray_worker_outputs + return driver_worker_output + ray_worker_outputs def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None: """Wait for futures returned from _run_workers() with async_run_remote_workers_only to complete.""" ray.get(parallel_worker_tasks) - def _compiled_ray_dag(self): + def _compiled_ray_dag(self, enable_asyncio: bool): import pkg_resources - required_version = "2.9" - current_version = pkg_resources.get_distribution("ray").version + from packaging import version + + required_version = version.parse("2.32") + current_version = version.parse( + pkg_resources.get_distribution("ray").version) if current_version < required_version: raise ValueError(f"Ray version {required_version} or greater is " f"required, but found {current_version}") @@ -365,23 +396,47 @@ def _compiled_ray_dag(self): # a dummy value for now. It will be fixed soon. with InputNode() as input_data: forward_dag = MultiOutputNode([ - worker.execute_model_compiled_dag_remote. - bind( # type: ignore[attr-defined] + worker.execute_model_spmd.bind( # type: ignore[attr-defined] input_data) for worker in self.workers ]) - return forward_dag.experimental_compile() + return forward_dag.experimental_compile(enable_asyncio=enable_asyncio) + + def __del__(self): + if self.forward_dag is not None: + self.forward_dag.teardown() + import ray + for worker in self.workers: + ray.kill(worker) class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.driver_exec_method = make_async(self.driver_worker.execute_method) + self.use_ray_spmd_worker = envs.VLLM_USE_RAY_SPMD_WORKER + if not self.use_ray_compiled_dag: + self.driver_exec_method = make_async( + self.driver_worker.execute_method) + + async def execute_model_async( + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + if not self.use_ray_spmd_worker: + return await super().execute_model_async(execute_model_req) + + if self.forward_dag is None: + self.forward_dag = self._compiled_ray_dag(enable_asyncio=True) + + dag_future = await self.forward_dag.execute_async(execute_model_req) + outputs = await dag_future + return outputs[0] async def _driver_execute_model_async( self, execute_model_req: Optional[ExecuteModelRequest] = None ) -> List[SamplerOutput]: + assert not self.use_ray_spmd_worker, ( + "driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1") if self.pp_locks is None: # This locks each pipeline parallel stage so multiple virtual # engines can't execute on the same stage at the same time @@ -415,8 +470,17 @@ async def _run_task_with_lock(task, lock, *args, **kwargs): return results[-1] async def _start_worker_execution_loop(self): + assert not self.use_ray_spmd_worker, ( + "worker loop is disabled for VLLM_USE_RAY_SPMD_WORKER=1") coros = [ worker.execute_method.remote("start_worker_execution_loop") for worker in self.non_driver_workers ] return await asyncio.gather(*coros) + + def __del__(self): + if self.forward_dag is not None: + self.forward_dag.teardown() + import ray + for worker in self.workers: + ray.kill(worker) diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py index 242d6c136655f..fcbfa30d7a38a 100644 --- a/vllm/executor/ray_utils.py +++ b/vllm/executor/ray_utils.py @@ -1,8 +1,8 @@ -import pickle from typing import List, Optional, Tuple from vllm.config import ParallelConfig from vllm.logger import init_logger +from vllm.sequence import ExecuteModelRequest from vllm.utils import get_ip, is_hip, is_xpu from vllm.worker.worker_base import WorkerWrapperBase @@ -31,16 +31,18 @@ def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]: gpu_ids = ray.get_gpu_ids() return node_id, gpu_ids - def execute_model_compiled_dag_remote(self, ignored): - """Used only when compiled DAG is enabled.""" + def execute_model_spmd(self, execute_model_req: ExecuteModelRequest): + """Used only when SPMD worker and compiled DAG are both + enabled.""" + # TODO(swang): This is needed right now because Ray aDAG executes + # on a background thread, so we need to reset torch's current + # device. import torch if not self.compiled_dag_cuda_device_set: torch.cuda.set_device(self.worker.device) self.compiled_dag_cuda_device_set = True - output = self.worker.execute_model() - output = pickle.dumps(output) - return output + return self.worker._execute_model_spmd(execute_model_req) ray_import_err = None diff --git a/vllm/executor/ray_xpu_executor.py b/vllm/executor/ray_xpu_executor.py index 33f9321b5ff36..2a93616ced06c 100644 --- a/vllm/executor/ray_xpu_executor.py +++ b/vllm/executor/ray_xpu_executor.py @@ -1,11 +1,11 @@ import asyncio import os -import pickle from collections import defaultdict from itertools import islice, repeat from typing import (TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Set, Tuple, Union) +import vllm.envs as envs from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig, @@ -30,7 +30,7 @@ # If the env var is set, it uses the Ray's compiled DAG API # which optimizes the control plane overhead. # Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it. -USE_RAY_COMPILED_DAG = bool(os.getenv("VLLM_USE_RAY_COMPILED_DAG", 0)) +USE_RAY_COMPILED_DAG = envs.VLLM_USE_RAY_COMPILED_DAG class RayXPUExecutor(DistributedGPUExecutor): @@ -72,10 +72,9 @@ def __init__( # Create the parallel GPU workers. self._init_workers_ray(placement_group) - # Profile the memory usage and initialize the cache. self.forward_dag = None if USE_RAY_COMPILED_DAG: - self.forward_dag = self._compiled_ray_dag() + self.forward_dag = self._compiled_ray_dag(enable_asyncio=False) # This is non-None when the execute model loop is running # in the parallel workers. It's a coroutine in the AsyncLLMEngine case. @@ -270,7 +269,6 @@ def _run_workers( all_kwargs: Optional[List[Dict[str, Any]]] = None, use_dummy_driver: bool = False, max_concurrent_workers: Optional[int] = None, - use_ray_compiled_dag: bool = False, **kwargs, ) -> Any: """Runs the given method on all workers. Can be used in the following @@ -293,26 +291,20 @@ def _run_workers( all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \ else islice(all_kwargs, 1, None) - if use_ray_compiled_dag: - # Right now, compiled DAG can only accept a single - # input. TODO(sang): Fix it. - assert self.forward_dag is not None - output_channels = self.forward_dag.execute(1) - else: - # Start the ray workers first. - ray_worker_outputs = [ - worker.execute_method.remote(method, *worker_args, - **worker_kwargs) - for (worker, worker_args, worker_kwargs - ) in zip(self.workers, all_worker_args, all_worker_kwargs) - ] + # Start the ray workers first. + ray_worker_outputs = [ + worker.execute_method.remote(method, *worker_args, **worker_kwargs) + for (worker, worker_args, worker_kwargs + ) in zip(self.workers, all_worker_args, all_worker_kwargs) + ] + if async_run_remote_workers_only: # Just return futures return ray_worker_outputs + driver_worker_output = [] driver_args = args if all_args is None else all_args[0] driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0] - # Start the driver worker after all the ray workers. if not use_dummy_driver: driver_worker_output = self.driver_worker.execute_method( @@ -324,36 +316,28 @@ def _run_workers( method, *driver_args, **driver_kwargs)) # Get the results of the ray workers. if self.workers: - if use_ray_compiled_dag: - try: - ray_worker_outputs = [ - pickle.loads(chan.begin_read()) - for chan in output_channels - ] - finally: - # Has to call end_read in order to reuse the DAG. - for chan in output_channels: - chan.end_read() - else: - ray_worker_outputs = ray.get(ray_worker_outputs) + ray_worker_outputs = ray.get(ray_worker_outputs) - return [driver_worker_output] + ray_worker_outputs + return driver_worker_output + ray_worker_outputs def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None: """Wait for futures returned from _run_workers() with async_run_remote_workers_only to complete.""" ray.get(parallel_worker_tasks) - def _compiled_ray_dag(self): + def _compiled_ray_dag(self, enable_asyncio: bool): import pkg_resources - required_version = "2.9" - current_version = pkg_resources.get_distribution("ray").version + from packaging import version + + required_version = version.parse("2.32") + current_version = version.parse( + pkg_resources.get_distribution("ray").version) if current_version < required_version: raise ValueError(f"Ray version {required_version} or greater is " f"required, but found {current_version}") from ray.dag import InputNode, MultiOutputNode - assert self.parallel_config.worker_use_ray + assert self.parallel_config.distributed_executor_backend == "ray" # Right now, compiled DAG requires at least 1 arg. We send # a dummy value for now. It will be fixed soon. @@ -363,7 +347,7 @@ def _compiled_ray_dag(self): bind( # type: ignore[attr-defined] input_data) for worker in self.workers ]) - return forward_dag.experimental_compile() + return forward_dag.experimental_compile(enable_asyncio=enable_asyncio) def check_health(self) -> None: """Raises an error if engine is unhealthy.""" diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 93ffea9106501..a10281b02db89 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -281,6 +281,33 @@ def execute_model( # list to conform to interface. return output + def _execute_model_spmd( + self, execute_model_req: ExecuteModelRequest + ) -> Optional[List[SamplerOutput]]: + """ + Execute model in Single Program Multiple Data (SPMD) fashion. + All workers take the same request, prepare the input and + execute the model. + """ + assert execute_model_req is not None, ( + "_execute_model_spmd() requires each worker to take in an " + "ExecuteModelRequest") + worker_input: WorkerInput = self.prepare_worker_input( + execute_model_req=execute_model_req) + model_input: ModelRunnerInputBase = ( + self.model_runner.prepare_model_input( + execute_model_req.seq_group_metadata_list)) + + self.execute_worker(worker_input) + + # If there is no input, we don't need to execute the model. + if worker_input.num_seq_groups == 0: + return [] + + return self.model_runner.execute_model( + model_input, self.kv_cache[worker_input.virtual_engine] + if self.kv_cache is not None else None) + class WorkerWrapperBase: """ @@ -296,7 +323,7 @@ def __init__(self, trust_remote_code: bool = False) -> None: self.worker_module_name = worker_module_name self.worker_class_name = worker_class_name - self.worker = None + self.worker: Optional[WorkerBase] = None if trust_remote_code: # note: lazy import to avoid importing torch before initializing from vllm.utils import init_cached_hf_modules @@ -323,7 +350,9 @@ def init_worker(self, *args, **kwargs): mod = importlib.import_module(self.worker_module_name) worker_class = getattr(mod, self.worker_class_name) + self.worker = worker_class(*args, **kwargs) + assert self.worker is not None def execute_method(self, method, *args, **kwargs): try: From 8a74c68bd1ae48cb71e4c3bf9d7ff9a2ef8f9dae Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Wed, 17 Jul 2024 23:06:21 -0700 Subject: [PATCH 08/13] [Misc] Minor patch for draft model runner (#6523) --- vllm/spec_decode/draft_model_runner.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index 3cb7ec58da4c1..d2c7e6e3710a8 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -15,8 +15,12 @@ logger = init_logger(__name__) +# A flag to enable debug prints for the updated input tensors +# before each step. debug_advance_input = False -enable_gpu_advance_step = True +# A flag to allow GPU advance step for draft model runner. +# Set to False for debugging. +allow_gpu_advance_step = True class TP1DraftModelRunner(ModelRunner): @@ -196,7 +200,7 @@ def supports_gpu_multi_step(self, execute_model_req: ExecuteModelRequest): 3. No LORA 4. No prompt_adapter_config """ - if not enable_gpu_advance_step: + if not allow_gpu_advance_step: return False # We allow multi-step GPU only in decode mode From e2fbaee7258810bcd43725e9ca7f1444a88f91f3 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Thu, 18 Jul 2024 00:13:30 -0700 Subject: [PATCH 09/13] [BugFix][Frontend] Use LoRA tokenizer in OpenAI APIs (#6227) Co-authored-by: Cyrus Leung --- tests/async_engine/test_chat_template.py | 33 ++------- tests/entrypoints/openai/test_chat.py | 13 +--- tests/entrypoints/openai/test_completion.py | 51 ++++++++++++- tests/entrypoints/openai/test_serving_chat.py | 3 +- tests/entrypoints/openai/test_tokenization.py | 56 ++++++++++---- vllm/engine/async_llm_engine.py | 13 +++- vllm/engine/llm_engine.py | 7 +- vllm/entrypoints/openai/api_server.py | 3 +- vllm/entrypoints/openai/chat_utils.py | 72 +++++++++--------- vllm/entrypoints/openai/serving_chat.py | 73 +++++++++++-------- vllm/entrypoints/openai/serving_completion.py | 47 ++++++------ vllm/entrypoints/openai/serving_embedding.py | 11 +-- vllm/entrypoints/openai/serving_engine.py | 30 ++++---- .../openai/serving_tokenization.py | 30 +++++--- vllm/transformers_utils/detokenizer.py | 8 ++ vllm/transformers_utils/tokenizer.py | 3 + 16 files changed, 267 insertions(+), 186 deletions(-) diff --git a/tests/async_engine/test_chat_template.py b/tests/async_engine/test_chat_template.py index 536a7c96a1e9e..528d6ff182dd0 100644 --- a/tests/async_engine/test_chat_template.py +++ b/tests/async_engine/test_chat_template.py @@ -1,6 +1,5 @@ import os import pathlib -from dataclasses import dataclass import pytest @@ -50,23 +49,9 @@ ] -@dataclass -class MockTokenizer: - chat_template = None - - -@dataclass -class MockServingChat: - tokenizer: MockTokenizer - - def test_load_chat_template(): # Testing chatml template - tokenizer = MockTokenizer() - mock_serving_chat = MockServingChat(tokenizer) - load_chat_template(mock_serving_chat, chat_template=chatml_jinja_path) - - template_content = tokenizer.chat_template + template_content = load_chat_template(chat_template=chatml_jinja_path) # Test assertions assert template_content is not None @@ -78,22 +63,16 @@ def test_load_chat_template(): def test_no_load_chat_template_filelike(): # Testing chatml template template = "../../examples/does_not_exist" - tokenizer = MockTokenizer() - - mock_serving_chat = MockServingChat(tokenizer) with pytest.raises(ValueError, match="looks like a file path"): - load_chat_template(mock_serving_chat, chat_template=template) + load_chat_template(chat_template=template) def test_no_load_chat_template_literallike(): # Testing chatml template template = "{{ messages }}" - tokenizer = MockTokenizer() - mock_serving_chat = MockServingChat(tokenizer) - load_chat_template(mock_serving_chat, chat_template=template) - template_content = tokenizer.chat_template + template_content = load_chat_template(chat_template=template) assert template_content == template @@ -105,8 +84,7 @@ def test_get_gen_prompt(model, template, add_generation_prompt, expected_output): # Initialize the tokenizer tokenizer = get_tokenizer(tokenizer_name=model) - mock_serving_chat = MockServingChat(tokenizer) - load_chat_template(mock_serving_chat, chat_template=template) + template_content = load_chat_template(chat_template=template) # Create a mock request object using keyword arguments mock_request = ChatCompletionRequest( @@ -118,7 +96,8 @@ def test_get_gen_prompt(model, template, add_generation_prompt, result = tokenizer.apply_chat_template( conversation=mock_request.messages, tokenize=False, - add_generation_prompt=mock_request.add_generation_prompt) + add_generation_prompt=mock_request.add_generation_prompt, + chat_template=mock_request.chat_template or template_content) # Test assertion assert result == expected_output, ( diff --git a/tests/entrypoints/openai/test_chat.py b/tests/entrypoints/openai/test_chat.py index 8f67dd54edff0..1abaa01ae192a 100644 --- a/tests/entrypoints/openai/test_chat.py +++ b/tests/entrypoints/openai/test_chat.py @@ -7,11 +7,11 @@ import openai # use the official client for correctness check import pytest import torch -# downloading lora to test lora requests -from huggingface_hub import snapshot_download from openai import BadRequestError from ...utils import RemoteOpenAIServer +from .test_completion import zephyr_lora_added_tokens_files # noqa: F401 +from .test_completion import zephyr_lora_files # noqa: F401 # any model with a chat template should work here MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" @@ -21,12 +21,7 @@ @pytest.fixture(scope="module") -def zephyr_lora_files(): - return snapshot_download(repo_id=LORA_NAME) - - -@pytest.fixture(scope="module") -def server(zephyr_lora_files): +def server(zephyr_lora_files, zephyr_lora_added_tokens_files): # noqa: F811 args = [ # use half precision for speed and memory savings in CI environment "--dtype", @@ -38,7 +33,7 @@ def server(zephyr_lora_files): "--enable-lora", "--lora-modules", f"zephyr-lora={zephyr_lora_files}", - f"zephyr-lora2={zephyr_lora_files}", + f"zephyr-lora2={zephyr_lora_added_tokens_files}", "--max-lora-rank", "64", "--max-cpu-loras", diff --git a/tests/entrypoints/openai/test_completion.py b/tests/entrypoints/openai/test_completion.py index 59151b9c4e99e..0896e337b5d24 100644 --- a/tests/entrypoints/openai/test_completion.py +++ b/tests/entrypoints/openai/test_completion.py @@ -1,6 +1,8 @@ # imports for guided decoding tests import json import re +import shutil +from tempfile import TemporaryDirectory from typing import List import jsonschema @@ -9,6 +11,7 @@ # downloading lora to test lora requests from huggingface_hub import snapshot_download from openai import BadRequestError +from transformers import AutoTokenizer from vllm.transformers_utils.tokenizer import get_tokenizer @@ -30,13 +33,29 @@ def zephyr_lora_files(): return snapshot_download(repo_id=LORA_NAME) +@pytest.fixture(scope="module") +def zephyr_lora_added_tokens_files(zephyr_lora_files): + tmp_dir = TemporaryDirectory() + tmp_model_dir = f"{tmp_dir.name}/zephyr" + shutil.copytree(zephyr_lora_files, tmp_model_dir) + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + # Copy tokenizer to adapter and add some unique tokens + # 32000, 32001, 32002 + added = tokenizer.add_tokens(["vllm1", "vllm2", "vllm3"], + special_tokens=True) + assert added == 3 + tokenizer.save_pretrained(tmp_model_dir) + yield tmp_model_dir + tmp_dir.cleanup() + + @pytest.fixture(scope="module") def zephyr_pa_files(): return snapshot_download(repo_id=PA_NAME) @pytest.fixture(scope="module") -def server(zephyr_lora_files, zephyr_pa_files): +def server(zephyr_lora_files, zephyr_lora_added_tokens_files, zephyr_pa_files): args = [ # use half precision for speed and memory savings in CI environment "--dtype", @@ -50,7 +69,7 @@ def server(zephyr_lora_files, zephyr_pa_files): "--enable-lora", "--lora-modules", f"zephyr-lora={zephyr_lora_files}", - f"zephyr-lora2={zephyr_lora_files}", + f"zephyr-lora2={zephyr_lora_added_tokens_files}", "--max-lora-rank", "64", "--max-cpu-loras", @@ -111,6 +130,34 @@ async def test_single_completion(client: openai.AsyncOpenAI, model_name: str, assert len(completion.choices[0].text) >= 1 +@pytest.mark.asyncio +async def test_added_lora_tokens(client: openai.AsyncOpenAI): + # test using token IDs + completion = await client.completions.create( + model="zephyr-lora2", + prompt=[0, 0, 32000, 32001, 32002], + echo=True, + max_tokens=5, + temperature=0.0, + ) + # Added tokens should appear in tokenized prompt + assert completion.choices[0].text.startswith("vllm1vllm2vllm3") + + +@pytest.mark.asyncio +async def test_added_lora_tokens_base_model(client: openai.AsyncOpenAI): + # test using token IDs + completion = await client.completions.create( + model=MODEL_NAME, + prompt=[0, 0, 32000, 32001, 32002], + echo=True, + max_tokens=5, + temperature=0.0, + ) + # Added tokens should not appear in tokenized prompt + assert "vllm" not in completion.choices[0].text + + @pytest.mark.asyncio @pytest.mark.parametrize( # first test base model, then test loras, then test prompt adapters diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index 74b49726734b5..9a7abcfe5e590 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -38,5 +38,4 @@ async def _async_serving_chat_init(): def test_async_serving_chat_init(): serving_completion = asyncio.run(_async_serving_chat_init()) - assert serving_completion.tokenizer is not None - assert serving_completion.tokenizer.chat_template == CHAT_TEMPLATE + assert serving_completion.chat_template == CHAT_TEMPLATE diff --git a/tests/entrypoints/openai/test_tokenization.py b/tests/entrypoints/openai/test_tokenization.py index f32abba225d40..18c51c560b511 100644 --- a/tests/entrypoints/openai/test_tokenization.py +++ b/tests/entrypoints/openai/test_tokenization.py @@ -5,13 +5,15 @@ from vllm.transformers_utils.tokenizer import get_tokenizer from ...utils import RemoteOpenAIServer +from .test_completion import zephyr_lora_added_tokens_files # noqa: F401 +from .test_completion import zephyr_lora_files # noqa: F401 # any model with a chat template should work here MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" @pytest.fixture(scope="module") -def server(): +def server(zephyr_lora_added_tokens_files: str): # noqa: F811 args = [ # use half precision for speed and memory savings in CI environment "--dtype", @@ -21,12 +23,25 @@ def server(): "--enforce-eager", "--max-num-seqs", "128", + # lora config + "--enable-lora", + "--lora-modules", + f"zephyr-lora2={zephyr_lora_added_tokens_files}", + "--max-lora-rank", + "64", ] with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: yield remote_server +@pytest.fixture(scope="module") +def tokenizer_name(model_name: str, + zephyr_lora_added_tokens_files: str): # noqa: F811 + return zephyr_lora_added_tokens_files if ( + model_name == "zephyr-lora2") else model_name + + @pytest.fixture(scope="module") def client(server): return server.get_async_client() @@ -34,16 +49,18 @@ def client(server): @pytest.mark.asyncio @pytest.mark.parametrize( - "model_name", - [MODEL_NAME], + "model_name,tokenizer_name", + [(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")], + indirect=["tokenizer_name"], ) async def test_tokenize_completions(client: openai.AsyncOpenAI, - model_name: str): + model_name: str, tokenizer_name: str): base_url = str(client.base_url)[:-3].strip("/") - tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast") + tokenizer = get_tokenizer(tokenizer_name=tokenizer_name, + tokenizer_mode="fast") for add_special in [False, True]: - prompt = "This is a test prompt." + prompt = "vllm1 This is a test prompt." tokens = tokenizer.encode(prompt, add_special_tokens=add_special) response = requests.post(base_url + "/tokenize", @@ -63,12 +80,15 @@ async def test_tokenize_completions(client: openai.AsyncOpenAI, @pytest.mark.asyncio @pytest.mark.parametrize( - "model_name", - [MODEL_NAME], + "model_name,tokenizer_name", + [(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")], + indirect=["tokenizer_name"], ) -async def test_tokenize_chat(client: openai.AsyncOpenAI, model_name: str): +async def test_tokenize_chat(client: openai.AsyncOpenAI, model_name: str, + tokenizer_name: str): base_url = str(client.base_url)[:-3].strip("/") - tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast") + tokenizer = get_tokenizer(tokenizer_name=tokenizer_name, + tokenizer_mode="fast") for add_generation in [False, True]: for add_special in [False, True]: @@ -80,7 +100,7 @@ async def test_tokenize_chat(client: openai.AsyncOpenAI, model_name: str): "content": "Nice to meet you!" }, { "role": "user", - "content": "Can I ask a question?" + "content": "Can I ask a question? vllm1" }] prompt = tokenizer.apply_chat_template( @@ -108,16 +128,20 @@ async def test_tokenize_chat(client: openai.AsyncOpenAI, model_name: str): @pytest.mark.asyncio @pytest.mark.parametrize( - "model_name", - [MODEL_NAME], + "model_name,tokenizer_name", + [(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")], + indirect=["tokenizer_name"], ) -async def test_detokenize(client: openai.AsyncOpenAI, model_name: str): +async def test_detokenize(client: openai.AsyncOpenAI, model_name: str, + tokenizer_name: str): base_url = str(client.base_url)[:-3].strip("/") - tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast") + tokenizer = get_tokenizer(tokenizer_name=tokenizer_name, + tokenizer_mode="fast") - prompt = "This is a test prompt." + prompt = "This is a test prompt. vllm1" tokens = tokenizer.encode(prompt, add_special_tokens=False) + print(f"CALLING {base_url} FOR {model_name}") response = requests.post(base_url + "/detokenize", json={ "model": model_name, diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 0e63506e7c367..8bced12a14347 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -480,11 +480,16 @@ def _error_callback(self, exc: Exception) -> None: self.set_errored(exc) self._request_tracker.propagate_exception(exc) - async def get_tokenizer(self) -> "PreTrainedTokenizer": + async def get_tokenizer( + self, + lora_request: Optional[LoRARequest] = None, + ) -> "PreTrainedTokenizer": if self.engine_use_ray: - return await self.engine.get_tokenizer.remote() # type: ignore - else: - return self.engine.get_tokenizer() + return await self.engine.get_tokenizer.remote( # type: ignore + lora_request) + + return await (self.engine.get_tokenizer_group(). + get_lora_tokenizer_async(lora_request)) def start_background_loop(self) -> None: """Start the background loop.""" diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 77539eab0db23..0937349827eda 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -455,8 +455,11 @@ def get_tokenizer_group( return self.tokenizer - def get_tokenizer(self) -> "PreTrainedTokenizer": - return self.get_tokenizer_group().get_lora_tokenizer(None) + def get_tokenizer( + self, + lora_request: Optional[LoRARequest] = None + ) -> "PreTrainedTokenizer": + return self.get_tokenizer_group().get_lora_tokenizer(lora_request) def get_tokenizer_for_seq(self, sequence: Sequence) -> "PreTrainedTokenizer": diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index a35dcbbd6545e..b6bf08e5fae60 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -257,7 +257,8 @@ def run_server(args, llm_engine=None): openai_serving_embedding = OpenAIServingEmbedding(engine, model_config, served_model_names) openai_serving_tokenization = OpenAIServingTokenization( - engine, model_config, served_model_names, args.chat_template) + engine, model_config, served_model_names, args.lora_modules, + args.chat_template) app.root_path = args.root_path logger.info("Available routes are:") diff --git a/vllm/entrypoints/openai/chat_utils.py b/vllm/entrypoints/openai/chat_utils.py index 27115391d5b27..b3d5ca77ac16d 100644 --- a/vllm/entrypoints/openai/chat_utils.py +++ b/vllm/entrypoints/openai/chat_utils.py @@ -5,10 +5,11 @@ from openai.types.chat import (ChatCompletionContentPartImageParam, ChatCompletionContentPartTextParam) +from transformers import PreTrainedTokenizer +from vllm.config import ModelConfig from vllm.entrypoints.openai.protocol import (ChatCompletionContentPartParam, ChatCompletionMessageParam) -from vllm.entrypoints.openai.serving_engine import OpenAIServing from vllm.logger import init_logger from vllm.multimodal import MultiModalDataDict from vllm.multimodal.utils import async_get_and_parse_image @@ -29,40 +30,34 @@ class ChatMessageParseResult: default_factory=list) -def load_chat_template(engine: OpenAIServing, chat_template: Optional[str]): - tokenizer = engine.tokenizer - - if chat_template is not None: - try: - with open(chat_template, "r") as f: - tokenizer.chat_template = f.read() - except OSError as e: - JINJA_CHARS = "{}\n" - if not any(c in chat_template for c in JINJA_CHARS): - msg = (f"The supplied chat template ({chat_template}) " - f"looks like a file path, but it failed to be " - f"opened. Reason: {e}") - raise ValueError(msg) from e +def load_chat_template(chat_template: Optional[str]) -> Optional[str]: + if chat_template is None: + return None + try: + with open(chat_template, "r") as f: + resolved_chat_template = f.read() + except OSError as e: + JINJA_CHARS = "{}\n" + if not any(c in chat_template for c in JINJA_CHARS): + msg = (f"The supplied chat template ({chat_template}) " + f"looks like a file path, but it failed to be " + f"opened. Reason: {e}") + raise ValueError(msg) from e - # If opening a file fails, set chat template to be args to - # ensure we decode so our escape are interpreted correctly - tokenizer.chat_template = codecs.decode(chat_template, - "unicode_escape") + # If opening a file fails, set chat template to be args to + # ensure we decode so our escape are interpreted correctly + resolved_chat_template = codecs.decode(chat_template, "unicode_escape") - logger.info("Using supplied chat template:\n%s", - tokenizer.chat_template) - elif tokenizer.chat_template is not None: - logger.info("Using default chat template:\n%s", - tokenizer.chat_template) - else: - logger.warning("No chat template provided. Chat API will not work.") + logger.info("Using supplied chat template:\n%s", resolved_chat_template) + return resolved_chat_template @lru_cache(maxsize=None) -def _image_token_str(engine: OpenAIServing) -> Optional[str]: +def _image_token_str(model_config: ModelConfig, + tokenizer: PreTrainedTokenizer) -> Optional[str]: # TODO: Let user specify how to insert image tokens into prompt # (similar to chat template) - model_type = engine.model_config.hf_config.model_type + model_type = model_config.hf_config.model_type if model_type == "phi3_v": # Workaround since this token is not defined in the tokenizer return "<|image_1|>" @@ -70,17 +65,14 @@ def _image_token_str(engine: OpenAIServing) -> Optional[str]: # These models do not use image tokens in the prompt return None if model_type.startswith("llava"): - return engine.tokenizer.decode( - engine.model_config.hf_config.image_token_index) + return tokenizer.decode(model_config.hf_config.image_token_index) - else: - raise TypeError("Unknown model type: {model_type}") + raise TypeError("Unknown model type: {model_type}") # TODO: Let user specify how to insert image tokens into prompt # (similar to chat template) -def _get_full_image_text_prompt(engine: OpenAIServing, image_token_str: str, - text_prompt: str) -> str: +def _get_full_image_text_prompt(image_token_str: str, text_prompt: str) -> str: """Combine image and text prompts for vision language model""" # NOTE: For now we assume all model architectures use the same @@ -89,9 +81,10 @@ def _get_full_image_text_prompt(engine: OpenAIServing, image_token_str: str, def _parse_chat_message_content_parts( - engine: OpenAIServing, role: str, parts: Iterable[ChatCompletionContentPartParam], + model_config: ModelConfig, + tokenizer: PreTrainedTokenizer, ) -> ChatMessageParseResult: texts: List[str] = [] mm_futures: List[Awaitable[MultiModalDataDict]] = [] @@ -122,7 +115,7 @@ def _parse_chat_message_content_parts( text_prompt = "\n".join(texts) if mm_futures: - image_token_str = _image_token_str(engine) + image_token_str = _image_token_str(model_config, tokenizer) if image_token_str is not None: if image_token_str in text_prompt: logger.warning( @@ -130,7 +123,6 @@ def _parse_chat_message_content_parts( "Skipping prompt formatting.") else: text_prompt = _get_full_image_text_prompt( - engine, image_token_str=image_token_str, text_prompt=text_prompt, ) @@ -141,8 +133,9 @@ def _parse_chat_message_content_parts( def parse_chat_message_content( - engine: OpenAIServing, message: ChatCompletionMessageParam, + model_config: ModelConfig, + tokenizer: PreTrainedTokenizer, ) -> ChatMessageParseResult: role = message["role"] content = message.get("content") @@ -153,4 +146,5 @@ def parse_chat_message_content( messages = [ConversationMessage(role=role, content=content)] return ChatMessageParseResult(messages=messages, mm_futures=[]) - return _parse_chat_message_content_parts(engine, role, content) + return _parse_chat_message_content_parts(role, content, model_config, + tokenizer) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index dbd4521073da9..0d7eede377ce5 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -5,6 +5,7 @@ from typing import Union from fastapi import Request +from transformers import PreTrainedTokenizer from vllm.config import ModelConfig from vllm.engine.async_llm_engine import AsyncLLMEngine @@ -49,7 +50,9 @@ def __init__(self, lora_modules=lora_modules) self.response_role = response_role - load_chat_template(self, chat_template) + + # If this is None we use the tokenizer's default chat template + self.chat_template = load_chat_template(chat_template) async def create_chat_completion( self, @@ -71,11 +74,15 @@ async def create_chat_completion( return error_check_ret try: + _, lora_request = self._maybe_get_adapter(request) + tokenizer = await self.engine.get_tokenizer(lora_request) + conversation: List[ConversationMessage] = [] mm_futures: List[Awaitable[MultiModalDataDict]] = [] for msg in request.messages: - chat_parsed_result = parse_chat_message_content(self, msg) + chat_parsed_result = parse_chat_message_content( + msg, self.model_config, tokenizer) conversation.extend(chat_parsed_result.messages) mm_futures.extend(chat_parsed_result.mm_futures) @@ -84,13 +91,13 @@ async def create_chat_completion( tool.model_dump() for tool in request.tools ] - prompt = self.tokenizer.apply_chat_template( + prompt = tokenizer.apply_chat_template( conversation=conversation, tokenize=False, add_generation_prompt=request.add_generation_prompt, tools=tool_dicts, documents=request.documents, - chat_template=request.chat_template, + chat_template=request.chat_template or self.chat_template, **(request.chat_template_kwargs or {}), ) except Exception as e: @@ -112,19 +119,19 @@ async def create_chat_completion( request_id = f"cmpl-{random_uuid()}" try: # Tokenize/detokenize depending on prompt format (string/token list) - prompt_ids, prompt_text = self._validate_prompt_and_tokenize( + prompt_ids, prompt_text = await self._validate_prompt_and_tokenize( request, + tokenizer, prompt=prompt, add_special_tokens=request.add_special_tokens) sampling_params = request.to_sampling_params() - _, lora_request = self._maybe_get_adapter(request) decoding_config = await self.engine.get_decoding_config() guided_decoding_backend = request.guided_decoding_backend \ or decoding_config.guided_decoding_backend guided_decode_logits_processor = ( - await get_guided_decoding_logits_processor( - guided_decoding_backend, request, await - self.engine.get_tokenizer())) + await + get_guided_decoding_logits_processor(guided_decoding_backend, + request, tokenizer)) if guided_decode_logits_processor: if sampling_params.logits_processors is None: sampling_params.logits_processors = [] @@ -158,12 +165,12 @@ async def create_chat_completion( # Streaming response if request.stream: return self.chat_completion_stream_generator( - request, result_generator, request_id, conversation) + request, result_generator, request_id, conversation, tokenizer) else: try: return await self.chat_completion_full_generator( request, raw_request, result_generator, request_id, - conversation) + conversation, tokenizer) except ValueError as e: # TODO: Use a vllm-specific Validation Error return self.create_error_response(str(e)) @@ -175,9 +182,12 @@ def get_chat_request_role(self, request: ChatCompletionRequest) -> str: return request.messages[-1]["role"] async def chat_completion_stream_generator( - self, request: ChatCompletionRequest, - result_generator: AsyncIterator[RequestOutput], request_id: str, - conversation: List[ConversationMessage] + self, + request: ChatCompletionRequest, + result_generator: AsyncIterator[RequestOutput], + request_id: str, + conversation: List[ConversationMessage], + tokenizer: PreTrainedTokenizer, ) -> AsyncGenerator[str, None]: model_name = self.served_model_names[0] created_time = int(time.time()) @@ -264,6 +274,7 @@ async def chat_completion_stream_generator( logprobs = self._create_chat_logprobs( token_ids=delta_token_ids, top_logprobs=out_logprobs, + tokenizer=tokenizer, num_output_top_logprobs=request.top_logprobs, ) else: @@ -352,9 +363,13 @@ async def chat_completion_stream_generator( yield "data: [DONE]\n\n" async def chat_completion_full_generator( - self, request: ChatCompletionRequest, raw_request: Optional[Request], - result_generator: AsyncIterator[RequestOutput], request_id: str, - conversation: List[ConversationMessage] + self, + request: ChatCompletionRequest, + raw_request: Optional[Request], + result_generator: AsyncIterator[RequestOutput], + request_id: str, + conversation: List[ConversationMessage], + tokenizer: PreTrainedTokenizer, ) -> Union[ErrorResponse, ChatCompletionResponse]: model_name = self.served_model_names[0] @@ -382,6 +397,7 @@ async def chat_completion_full_generator( token_ids=token_ids, top_logprobs=out_logprobs, num_output_top_logprobs=request.top_logprobs, + tokenizer=tokenizer, ) else: logprobs = None @@ -436,16 +452,14 @@ async def chat_completion_full_generator( return response def _get_top_logprobs( - self, logprobs: Dict[int, Logprob], - top_logprobs: Optional[int]) -> List[ChatCompletionLogProb]: + self, logprobs: Dict[int, Logprob], top_logprobs: Optional[int], + tokenizer: PreTrainedTokenizer) -> List[ChatCompletionLogProb]: return [ ChatCompletionLogProb( - token=self._get_decoded_token(p[1], p[0]), + token=(token := self._get_decoded_token(p[1], p[0], + tokenizer)), logprob=max(p[1].logprob, -9999.0), - bytes=list( - self._get_decoded_token(p[1], - p[0]).encode("utf-8", - errors="replace"))) + bytes=list(token.encode("utf-8", errors="replace"))) for i, p in enumerate(logprobs.items()) if top_logprobs and i < top_logprobs ] @@ -454,6 +468,7 @@ def _create_chat_logprobs( self, token_ids: GenericSequence[int], top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]], + tokenizer: PreTrainedTokenizer, num_output_top_logprobs: Optional[int] = None, ) -> ChatCompletionLogProbs: """Create OpenAI-style logprobs.""" @@ -463,12 +478,11 @@ def _create_chat_logprobs( for i, token_id in enumerate(token_ids): step_top_logprobs = top_logprobs[i] if step_top_logprobs is None: + token = tokenizer.decode(token_id) logprobs_content.append( ChatCompletionLogProbsContent( - token=self.tokenizer.decode(token_id), - bytes=list( - self.tokenizer.decode(token_id).encode( - "utf-8", errors="replace")))) + token=token, + bytes=list(token.encode("utf-8", errors="replace")))) else: logprobs_content.append( ChatCompletionLogProbsContent( @@ -479,6 +493,7 @@ def _create_chat_logprobs( step_top_logprobs[token_id].decoded_token.encode( "utf-8", errors="replace")), top_logprobs=self._get_top_logprobs( - step_top_logprobs, num_output_top_logprobs))) + step_top_logprobs, num_output_top_logprobs, + tokenizer))) return ChatCompletionLogProbs(content=logprobs_content) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 647fc31410647..e61f3fdbf6666 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -5,6 +5,7 @@ from typing import Tuple from fastapi import Request +from transformers import PreTrainedTokenizer from vllm.config import ModelConfig from vllm.engine.async_llm_engine import AsyncLLMEngine @@ -100,20 +101,22 @@ async def create_completion(self, request: CompletionRequest, # Schedule the request and get the result generator. generators: List[AsyncIterator[RequestOutput]] = [] try: - sampling_params = request.to_sampling_params() adapter_type, adapter_request = self._maybe_get_adapter(request) lora_request, prompt_adapter_request = None, None if adapter_type == 'LoRA': lora_request, prompt_adapter_request = adapter_request, None elif adapter_type == 'PromptAdapter': lora_request, prompt_adapter_request = None, adapter_request + tokenizer = await self.engine.get_tokenizer(lora_request) + + sampling_params = request.to_sampling_params() decoding_config = await self.engine.get_decoding_config() guided_decoding_backend = request.guided_decoding_backend \ or decoding_config.guided_decoding_backend guided_decode_logit_processor = ( - await get_guided_decoding_logits_processor( - guided_decoding_backend, request, await - self.engine.get_tokenizer())) + await + get_guided_decoding_logits_processor(guided_decoding_backend, + request, tokenizer)) if guided_decode_logit_processor is not None: if sampling_params.logits_processors is None: sampling_params.logits_processors = [] @@ -122,18 +125,13 @@ async def create_completion(self, request: CompletionRequest, prompt_is_tokens, prompts = parse_prompt_format(request.prompt) for i, prompt in enumerate(prompts): - if prompt_is_tokens: - prompt_formats = self._validate_prompt_and_tokenize( - request, - prompt_ids=prompt, - truncate_prompt_tokens=sampling_params. - truncate_prompt_tokens) - else: - prompt_formats = self._validate_prompt_and_tokenize( - request, - prompt=prompt, - truncate_prompt_tokens=sampling_params. - truncate_prompt_tokens) + prompt_arg = "prompt_ids" if prompt_is_tokens else "prompt" + prompt_formats = await self._validate_prompt_and_tokenize( + request, + tokenizer, + truncate_prompt_tokens=sampling_params. + truncate_prompt_tokens, + **{prompt_arg: prompt}) prompt_ids, prompt_text = prompt_formats is_tracing_enabled = await self.engine.is_tracing_enabled() @@ -179,7 +177,8 @@ async def create_completion(self, request: CompletionRequest, request_id, created_time, model_name, - num_prompts=len(prompts)) + num_prompts=len(prompts), + tokenizer=tokenizer) # Non-streaming response final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts) @@ -191,7 +190,8 @@ async def create_completion(self, request: CompletionRequest, return self.create_error_response("Client disconnected") final_res_batch[i] = res response = self.request_output_to_completion_response( - final_res_batch, request, request_id, created_time, model_name) + final_res_batch, request, request_id, created_time, model_name, + tokenizer) except ValueError as e: # TODO: Use a vllm-specific Validation Error return self.create_error_response(str(e)) @@ -218,6 +218,7 @@ async def completion_stream_generator( created_time: int, model_name: str, num_prompts: int, + tokenizer: PreTrainedTokenizer, ) -> AsyncGenerator[str, None]: assert request.n is not None previous_texts = [""] * request.n * num_prompts @@ -268,6 +269,7 @@ async def completion_stream_generator( token_ids=delta_token_ids, top_logprobs=out_logprobs, num_output_top_logprobs=request.logprobs, + tokenizer=tokenizer, initial_text_offset=len(previous_texts[i]), ) else: @@ -336,6 +338,7 @@ def request_output_to_completion_response( request_id: str, created_time: int, model_name: str, + tokenizer: PreTrainedTokenizer, ) -> CompletionResponse: choices: List[CompletionResponseChoice] = [] num_prompt_tokens = 0 @@ -367,6 +370,7 @@ def request_output_to_completion_response( logprobs = self._create_completion_logprobs( token_ids=token_ids, top_logprobs=out_logprobs, + tokenizer=tokenizer, num_output_top_logprobs=request.logprobs, ) else: @@ -404,6 +408,7 @@ def _create_completion_logprobs( token_ids: GenericSequence[int], top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]], num_output_top_logprobs: int, + tokenizer: PreTrainedTokenizer, initial_text_offset: int = 0, ) -> CompletionLogProbs: """Create logprobs for OpenAI Completion API.""" @@ -417,13 +422,13 @@ def _create_completion_logprobs( for i, token_id in enumerate(token_ids): step_top_logprobs = top_logprobs[i] if step_top_logprobs is None: - token = self.tokenizer.decode(token_id) + token = tokenizer.decode(token_id) out_tokens.append(token) out_token_logprobs.append(None) out_top_logprobs.append(None) else: token = self._get_decoded_token(step_top_logprobs[token_id], - token_id) + token_id, tokenizer) token_logprob = max(step_top_logprobs[token_id].logprob, -9999.0) out_tokens.append(token) @@ -436,7 +441,7 @@ def _create_completion_logprobs( out_top_logprobs.append({ # Convert float("-inf") to the # JSON-serializable float that OpenAI uses - self._get_decoded_token(top_lp[1], top_lp[0]): + self._get_decoded_token(top_lp[1], top_lp[0], tokenizer): max(top_lp[1].logprob, -9999.0) for i, top_lp in enumerate(step_top_logprobs.items()) if num_output_top_logprobs >= i diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 4838cb7d0255a..19e4288f5aa1c 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -89,14 +89,11 @@ async def create_embedding(self, request: EmbeddingRequest, prompt_is_tokens, prompts = parse_prompt_format(request.input) pooling_params = request.to_pooling_params() + tokenizer = await self.engine.get_tokenizer() for i, prompt in enumerate(prompts): - if prompt_is_tokens: - prompt_formats = self._validate_prompt_and_tokenize( - request, prompt_ids=prompt) - else: - prompt_formats = self._validate_prompt_and_tokenize( - request, prompt=prompt) - + prompt_arg = "prompt_ids" if prompt_is_tokens else "prompt" + prompt_formats = await self._validate_prompt_and_tokenize( + request, tokenizer, **{prompt_arg: prompt}) prompt_ids, prompt_text = prompt_formats generator = self.engine.encode( diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 14c1df89e064f..4123ace36479e 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -5,6 +5,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union from pydantic import Field +from transformers import PreTrainedTokenizer from typing_extensions import Annotated from vllm.config import ModelConfig @@ -19,7 +20,6 @@ from vllm.lora.request import LoRARequest from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import Logprob -from vllm.transformers_utils.tokenizer import get_tokenizer logger = init_logger(__name__) @@ -52,14 +52,6 @@ def __init__( self.model_config = model_config self.max_model_len = model_config.max_model_len - # A separate tokenizer to map token IDs to strings. - self.tokenizer = get_tokenizer( - model_config.tokenizer, - tokenizer_mode=model_config.tokenizer_mode, - tokenizer_revision=model_config.tokenizer_revision, - trust_remote_code=model_config.trust_remote_code, - truncation_side="left") - self.served_model_names = served_model_names self.lora_requests = [] @@ -154,7 +146,8 @@ async def _check_model( def _maybe_get_adapter( self, request: Union[CompletionRequest, ChatCompletionRequest, - EmbeddingRequest] + EmbeddingRequest, TokenizeRequest, + DetokenizeRequest] ) -> Tuple[Optional[str], Optional[Union[LoRARequest, PromptAdapterRequest]]]: if request.model in self.served_model_names: @@ -168,11 +161,12 @@ def _maybe_get_adapter( # if _check_model has been called earlier, this will be unreachable raise ValueError(f"The model `{request.model}` does not exist.") - def _validate_prompt_and_tokenize( + async def _validate_prompt_and_tokenize( self, request: Union[ChatCompletionRequest, CompletionRequest, DetokenizeRequest, EmbeddingRequest, TokenizeRequest], + tokenizer: "PreTrainedTokenizer", prompt: Optional[str] = None, prompt_ids: Optional[List[int]] = None, truncate_prompt_tokens: Optional[Annotated[int, @@ -181,7 +175,7 @@ def _validate_prompt_and_tokenize( ) -> Tuple[List[int], str]: if not (prompt or prompt_ids): raise ValueError("Either prompt or prompt_ids should be provided.") - if (prompt and prompt_ids): + if prompt and prompt_ids: raise ValueError( "Only one of prompt or prompt_ids should be provided.") @@ -200,14 +194,14 @@ def _validate_prompt_and_tokenize( "truncation": True, "max_length": truncate_prompt_tokens, }) - input_ids = self.tokenizer(prompt, **tokenizer_kwargs).input_ids + input_ids = tokenizer(prompt, **tokenizer_kwargs).input_ids elif truncate_prompt_tokens is not None: input_ids = prompt_ids[-truncate_prompt_tokens:] else: input_ids = prompt_ids - input_text = prompt if prompt is not None else self.tokenizer.decode( - prompt_ids) + input_text = prompt if prompt is not None else tokenizer.decode( + input_ids) token_num = len(input_ids) # Note: EmbeddingRequest doesn't have max_tokens @@ -245,7 +239,9 @@ def _validate_prompt_and_tokenize( else: return input_ids, input_text - def _get_decoded_token(self, logprob: Logprob, token_id: int) -> str: + @staticmethod + def _get_decoded_token(logprob: Logprob, token_id: int, + tokenizer: PreTrainedTokenizer) -> str: if logprob.decoded_token is not None: return logprob.decoded_token - return self.tokenizer.decode(token_id) + return tokenizer.decode(token_id) diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py index f441e940c5e5f..94367bd3a6048 100644 --- a/vllm/entrypoints/openai/serving_tokenization.py +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -9,7 +9,8 @@ DetokenizeResponse, TokenizeRequest, TokenizeResponse) -from vllm.entrypoints.openai.serving_engine import OpenAIServing +from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, + OpenAIServing) class OpenAIServingTokenization(OpenAIServing): @@ -18,13 +19,15 @@ def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig, served_model_names: List[str], + lora_modules: Optional[List[LoRAModulePath]] = None, chat_template: Optional[str] = None): super().__init__(engine=engine, model_config=model_config, served_model_names=served_model_names, - lora_modules=None) + lora_modules=lora_modules) - load_chat_template(self, chat_template) + # If this is None we use the tokenizer's default chat template + self.chat_template = load_chat_template(chat_template) async def create_tokenize(self, request: TokenizeRequest) -> TokenizeResponse: @@ -40,20 +43,25 @@ async def create_tokenize(self, return self.create_error_response( "Only one of `prompt` or `messages` should be provided.") + _, lora_request = self._maybe_get_adapter(request) + tokenizer = await self.engine.get_tokenizer(lora_request) if request.messages: conversation: List[ConversationMessage] = [] for message in request.messages: - conversation.extend( - parse_chat_message_content(self, message).messages) + result = parse_chat_message_content(message, self.model_config, + tokenizer) + conversation.extend(result.messages) - request.prompt = self.tokenizer.apply_chat_template( + request.prompt = tokenizer.apply_chat_template( add_generation_prompt=request.add_generation_prompt, conversation=conversation, - tokenize=False) + tokenize=False, + chat_template=self.chat_template) - (input_ids, input_text) = self._validate_prompt_and_tokenize( + (input_ids, input_text) = await self._validate_prompt_and_tokenize( request, + tokenizer, prompt=request.prompt, add_special_tokens=request.add_special_tokens) @@ -67,7 +75,9 @@ async def create_detokenize( if error_check_ret is not None: return error_check_ret - (input_ids, input_text) = self._validate_prompt_and_tokenize( - request, prompt_ids=request.tokens) + _, lora_request = self._maybe_get_adapter(request) + tokenizer = await self.engine.get_tokenizer(lora_request) + (input_ids, input_text) = await self._validate_prompt_and_tokenize( + request, tokenizer, prompt_ids=request.tokens) return DetokenizeResponse(prompt=input_text) diff --git a/vllm/transformers_utils/detokenizer.py b/vllm/transformers_utils/detokenizer.py index cc9a971301afc..0a45028e7759b 100644 --- a/vllm/transformers_utils/detokenizer.py +++ b/vllm/transformers_utils/detokenizer.py @@ -165,6 +165,12 @@ def decode_sequence_inplace(self, seq: Sequence, return len(new_decoded_token_text) +def _replace_none_with_empty(tokens: List[Optional[str]]): + for i, token in enumerate(tokens): + if token is None: + tokens[i] = "" + + def _convert_tokens_to_string_with_added_encoders( tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], output_tokens: List[str], @@ -223,6 +229,8 @@ def convert_prompt_ids_to_tokens( read_offset = len(new_tokens) prefix_offset = max( read_offset - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET, 0) + # This is required to guard against out-of-vocab prompt token ids + _replace_none_with_empty(new_tokens) return new_tokens, prefix_offset, read_offset diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index f5684dbf1271c..7553249544211 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -88,6 +88,9 @@ def get_tokenizer( "Cannot use the fast tokenizer in slow tokenizer mode.") kwargs["use_fast"] = False + if "truncation_side" not in kwargs: + kwargs["truncation_side"] = "left" + try: tokenizer = AutoTokenizer.from_pretrained( tokenizer_name, From c8a7d51c4982b7b425debe5473867d9983e728fd Mon Sep 17 00:00:00 2001 From: Noam Gat Date: Thu, 18 Jul 2024 10:47:13 +0300 Subject: [PATCH 10/13] [Bugfix] Update flashinfer.py with PagedAttention forwards - Fixes Gemma2 OpenAI Server Crash (#6501) --- vllm/attention/backends/flashinfer.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index daff76051a956..9c25b2cc2ba97 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -20,6 +20,7 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping, compute_slot_mapping_start_idx, is_block_tables_empty) +from vllm.attention.ops.paged_attn import PagedAttention from vllm.sequence import SequenceGroupMetadata from vllm.utils import get_kv_cache_torch_dtype, make_tensor_with_pad @@ -61,14 +62,14 @@ def swap_blocks( dst_kv_cache: torch.Tensor, src_to_dst: torch.Tensor, ) -> None: - raise NotImplementedError + PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) @staticmethod def copy_blocks( kv_caches: List[torch.Tensor], src_to_dists: torch.Tensor, ) -> None: - raise NotImplementedError + PagedAttention.copy_blocks(kv_caches, src_to_dists) @staticmethod def get_supported_head_sizes() -> List[int]: From 4634c8728b0fe30a0b2da22dd4a4c8d8a1d9213b Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 18 Jul 2024 01:34:16 -0700 Subject: [PATCH 11/13] [TPU] Refactor TPU worker & model runner (#6506) --- vllm/worker/tpu_model_runner.py | 297 +++++++++++++++++++++----------- vllm/worker/tpu_worker.py | 141 +++++++-------- 2 files changed, 272 insertions(+), 166 deletions(-) diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index bbf0db31ee383..8a8b412db6731 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -1,5 +1,6 @@ import time -from typing import List, Optional, Tuple +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union import numpy as np import torch @@ -12,10 +13,16 @@ from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, - SamplerOutput, SequenceGroupMetadata, +from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors, + Logprob, SamplerOutput, SequenceGroupMetadata, SequenceOutput) -from vllm.utils import make_tensor_with_pad +from vllm.worker.model_runner_base import ( + ModelRunnerBase, ModelRunnerInputBase, + _add_attn_metadata_broadcastable_dict, + _init_attn_metadata_from_tensor_dict) + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend logger = init_logger(__name__) @@ -27,7 +34,44 @@ _MAX_NUM_SAMPLES = 128 -class TPUModelRunner: +@dataclass(frozen=True) +class ModelInputForTPU(ModelRunnerInputBase): + token_ids: torch.Tensor + position_ids: torch.Tensor + attn_metadata: AttentionMetadata + input_lens: torch.Tensor + t: torch.Tensor + p: torch.Tensor + num_samples: int + best_of: List[int] + seq_groups: List[List[int]] + + def as_broadcastable_tensor_dict( + self) -> Dict[str, Union[int, torch.Tensor]]: + tensor_dict = { + "token_ids": self.token_ids, + "position_ids": self.position_ids, + "input_lens": self.input_lens, + "t": self.t, + "p": self.p, + "num_samples": self.num_samples, + } + _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) + return tensor_dict + + @classmethod + def from_broadcasted_tensor_dict( + cls: Type["ModelInputForTPU"], + tensor_dict: Dict[str, Any], + attn_backend: Optional["AttentionBackend"] = None, + ) -> "ModelInputForTPU": + if attn_backend is not None: + tensor_dict = _init_attn_metadata_from_tensor_dict( + attn_backend, tensor_dict) + return cls(**tensor_dict) + + +class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]): def __init__( self, @@ -79,6 +123,7 @@ def load_model(self) -> None: multimodal_config=self.multimodal_config, lora_config=None, ) + model = model.eval() xm.wait_device_ops() model = ModelWrapper(model) @@ -147,8 +192,8 @@ def _dummy_run( # Dummy run. num_samples = _MAX_NUM_SAMPLES if is_prompt else 1 - self.model(token_ids, position_ids, kv_caches, attn_metadata, - input_lens, t, p, num_samples) + self.model(token_ids, position_ids, attn_metadata, input_lens, t, p, + num_samples, kv_caches) def warmup_model( self, @@ -177,7 +222,7 @@ def warmup_model( # Decode start = time.time() seq_len = 1 - batch_size = 1 + batch_size = 8 # Must be in sync with _get_padded_batch_size() while True: self._dummy_run(batch_size, seq_len, kv_caches, is_prompt=False) xm.wait_device_ops() @@ -195,10 +240,10 @@ def _prepare_prompt( seq_group_metadata_list: List[SequenceGroupMetadata], ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]: assert len(seq_group_metadata_list) > 0 - input_tokens: List[List[int]] = [] - input_positions: List[List[int]] = [] + input_tokens: List[int] = [] + input_positions: List[int] = [] prompt_lens: List[int] = [] - slot_mapping: List[List[int]] = [] + slot_mapping: List[int] = [] for seq_group_metadata in seq_group_metadata_list: assert seq_group_metadata.is_prompt @@ -212,50 +257,46 @@ def _prepare_prompt( prompt_len = len(prompt_tokens) prompt_lens.append(prompt_len) - input_tokens.append(prompt_tokens) - input_positions.append(list(range(prompt_len))) + input_tokens.extend(prompt_tokens) + input_positions.extend(list(range(prompt_len))) assert seq_group_metadata.block_tables is not None block_table = seq_group_metadata.block_tables[seq_id] - slot_mapping.append([]) for i in range(prompt_len): block_number = block_table[i // self.block_size] block_offset = i % self.block_size slot = block_number * self.block_size + block_offset - slot_mapping[-1].append(slot) + slot_mapping.append(slot) + + # Add paddings to EACH prompt to the smallest power of 2 that is + # greater than or equal to the prompt length. + # We pad the seq_len to reduce the compilation overhead. + # We execute each prompt individually (i.e., with batch_size 1) + # because the FlashAttention kernel does not support ragged inputs. + # TODO(woosuk): Use SplashAttention to support ragged inputs. + padded_prompt_len = _get_padded_prefill_len(prompt_len) + num_paddings = padded_prompt_len - prompt_len + input_tokens += [0] * num_paddings + input_positions += [0] * num_paddings + slot_mapping += [_PAD_SLOT_ID] * num_paddings assert len(prompt_lens) > 0 num_prefills = len(prompt_lens) - num_prefill_tokens = sum(prompt_lens) - - # Add paddings to make the shape [batch_size, max_prompt_len] where - # max_prompt_len is smallest power of 2 that is greater than or equal - # to the maximum prompt length. - # We need the 2D input shape because the Pallas FlashAttention kernel - # does not support packed 1D inputs. - # We pad the seq_len to powers of 2 to reduce the compilation overhead. - max_prompt_len = _get_padded_prefill_len(max(prompt_lens)) - input_tokens = make_tensor_with_pad(input_tokens, - max_prompt_len, - pad=0, - dtype=torch.int32, - device=self.device) - input_positions = make_tensor_with_pad(input_positions, - max_prompt_len, - pad=0, - dtype=torch.int32, - device=self.device) - slot_mapping = make_tensor_with_pad(slot_mapping, - max_prompt_len, - pad=_PAD_SLOT_ID, - dtype=torch.int64, - device=self.device) + input_tokens = torch.tensor(input_tokens, + dtype=torch.int32, + device="cpu") + input_positions = torch.tensor(input_positions, + dtype=torch.int32, + device="cpu") + slot_mapping = torch.tensor(slot_mapping, + dtype=torch.int64, + device="cpu") prompt_lens = torch.tensor(prompt_lens, dtype=torch.int32, - device=self.device) + device="cpu") attn_metadata = self.attn_backend.make_metadata( num_prefills=num_prefills, - num_prefill_tokens=num_prefill_tokens, # NOTE: This is not used. + num_prefill_tokens=0, # NOTE: This is not used. num_decode_tokens=0, slot_mapping=slot_mapping, block_tables=None, @@ -306,22 +347,22 @@ def _prepare_decode( input_tokens = torch.tensor(input_tokens, dtype=torch.int32, - device=self.device) + device="cpu") input_positions = torch.tensor(input_positions, dtype=torch.int32, - device=self.device) + device="cpu") slot_mapping = torch.tensor(slot_mapping, dtype=torch.int64, - device=self.device) + device="cpu") context_lens = torch.tensor(context_lens, dtype=torch.int32, - device=self.device) + device="cpu") block_tables = torch.tensor(self.block_tables[:batch_size], dtype=torch.int32, - device=self.device) + device="cpu") input_lens = torch.tensor([1] * batch_size, dtype=torch.int32, - device=self.device) + device="cpu") attn_metadata = self.attn_backend.make_metadata( num_prefills=0, num_prefill_tokens=0, @@ -382,16 +423,18 @@ def _prepare_sample( t += [1.0] * num_paddings p += [1.0] * num_paddings - t = torch.tensor(t, dtype=torch.float32, device=self.device) - p = torch.tensor(p, dtype=torch.float32, device=self.device) + t = torch.tensor(t, dtype=torch.float32, device="cpu") + p = torch.tensor(p, dtype=torch.float32, device="cpu") return t, p, best_of - def _execute_model( + def prepare_model_input( self, seq_group_metadata_list: List[SequenceGroupMetadata], - kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], - ) -> List[CompletionSequenceGroupOutput]: - # Prepare inputs. + virtual_engine: int = 0, + finished_requests_ids: Optional[List[str]] = None, + ) -> ModelInputForTPU: + del finished_requests_ids # Unused. + assert virtual_engine == 0 assert len(seq_group_metadata_list) > 0 # NOTE: We assume that all sequences in the group are all prompts or # all decodes. @@ -400,16 +443,104 @@ def _execute_model( inputs = self._prepare_prompt(seq_group_metadata_list) else: inputs = self._prepare_decode(seq_group_metadata_list) - padded_batch_size = inputs[0].shape[0] + input_tokens, input_positions, attn_metadata, input_lens = inputs + padded_batch_size = input_tokens.shape[0] t, p, best_of = self._prepare_sample(seq_group_metadata_list, padded_batch_size) num_samples = _MAX_NUM_SAMPLES if is_prompt else 1 - # Execute the model. - next_token_ids = self.model(inputs[0], inputs[1], kv_caches, - *inputs[2:], t, p, num_samples) - # Retrieve the outputs to CPU. - next_token_ids = next_token_ids.cpu().tolist() + seq_groups = [ + list(metadata.seq_data.keys()) + for metadata in seq_group_metadata_list + ] + return ModelInputForTPU(input_tokens, input_positions, attn_metadata, + input_lens, t, p, num_samples, best_of, + seq_groups) + + def make_model_input_from_broadcasted_tensor_dict( + self, tensor_dict: Dict[str, Any]) -> ModelInputForTPU: + model_input = ModelInputForTPU.from_broadcasted_tensor_dict( + tensor_dict, attn_backend=self.attn_backend) + return model_input + + def execute_model( + self, + model_input: ModelInputForTPU, + kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], + intermediate_tensors: Optional[IntermediateTensors] = None, + num_steps: int = 1, + ) -> List[SamplerOutput]: + assert intermediate_tensors is None + if num_steps > 1: + raise ValueError( + "TPUModelRunner does not support multi-step execution.") + + def _execute_model(*args, clone: bool = False) -> torch.Tensor: + """Move input args from CPU to device and execute the model.""" + + def _copy_to_device(x: torch.Tensor) -> torch.Tensor: + if clone: + # When x is a slice of a CPU tensor, XLA may copy the whole + # original tensor to TPU instead of only copying x. + # To avoid this, we copy x after cloning. + x = x.clone() + return x.to(self.device) + + new_args = [] + for arg in args: + if isinstance(arg, torch.Tensor): + arg = _copy_to_device(arg) + elif isinstance(arg, AttentionMetadata): + arg.slot_mapping = _copy_to_device(arg.slot_mapping) + if getattr(arg, "block_tables", None) is not None: + arg.block_tables = _copy_to_device(arg.block_tables) + if getattr(arg, "context_lens", None) is not None: + arg.context_lens = _copy_to_device(arg.context_lens) + new_args.append(arg) + return self.model(*new_args) + + num_prefills = model_input.attn_metadata.num_prefills + is_prompt = num_prefills > 0 + if is_prompt: + # NOTE(woosuk): Since the FlashAttention kernel does not support + # ragged inputs, we split the prompts into different batches and + # process them separately. This is a temporary hack that should be + # optimized by using SplashAttention. + next_token_ids = [] + orig_slot_mapping = model_input.attn_metadata.slot_mapping + batch_size = model_input.input_lens.shape[0] + start_idx = 0 + for i in range(batch_size): + # Get the actual prefill_len. + prefill_len = model_input.input_lens[i:i + 1].item() + prefill_len = _get_padded_prefill_len(prefill_len) + end_idx = start_idx + prefill_len + + model_input.attn_metadata.slot_mapping = orig_slot_mapping[ + None, start_idx:end_idx] + model_input.attn_metadata.num_prefills = 1 + output_token_ids = _execute_model( + model_input.token_ids[None, start_idx:end_idx], + model_input.position_ids[None, start_idx:end_idx], + model_input.attn_metadata, + model_input.input_lens[i:i + 1], + model_input.t[i:i + 1], + model_input.p[i:i + 1], + model_input.num_samples, + kv_caches, + clone=True) + # Retrieve the outputs to CPU. + next_token_ids += output_token_ids.cpu().tolist() + start_idx = end_idx + else: + # Execute the model. + output_token_ids = _execute_model( + model_input.token_ids, model_input.position_ids, + model_input.attn_metadata, model_input.input_lens, + model_input.t, model_input.p, model_input.num_samples, + kv_caches) + # Retrieve the outputs to CPU. + next_token_ids = output_token_ids.cpu().tolist() # NOTE(woosuk): Minimal code to construct the sampler outputs. # The TPU backend does not reuse the sampler, since the TPU backend @@ -417,13 +548,13 @@ def _execute_model( zero_logprob = Logprob(0.0) batch_idx = 0 sampler_outputs = [] - for seq_group_metadata in seq_group_metadata_list: + for seq_group in model_input.seq_groups: + seq_ids = seq_group seq_outputs = [] - seq_ids = list(seq_group_metadata.seq_data.keys()) if is_prompt: assert len(seq_ids) == 1 seq_id = seq_ids[0] - for i in range(best_of[batch_idx]): + for i in range(model_input.best_of[batch_idx]): next_token_id = next_token_ids[batch_idx][i] seq_outputs.append( SequenceOutput(seq_id, next_token_id, @@ -438,35 +569,6 @@ def _execute_model( batch_idx += 1 sampler_outputs.append( CompletionSequenceGroupOutput(seq_outputs, None)) - return sampler_outputs - - def execute_model( - self, - seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], - kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], - num_steps: int = 1, - ) -> List[SamplerOutput]: - if num_steps > 1: - raise ValueError( - "TPUModelRunner does not support multi-step execution.") - - assert seq_group_metadata_list is not None - assert len(seq_group_metadata_list) > 0 - if seq_group_metadata_list[0].is_prompt: - # NOTE(woosuk): To reduce the compilation time, we only compile the - # prefill inputs with batch size 1. Because the scheduler is not - # aware of this limitation, we need to handle batch size > 1 - # internally by calling the model multiple times and concatenating - # the outputs. - # FIXME(woosuk): This is a temporary hack to not change the existing - # scheduler. We need to fix this in the future. - sampler_outputs = [] - for seq_group_metadata in seq_group_metadata_list: - sampler_outputs += self._execute_model([seq_group_metadata], - kv_caches) - else: - sampler_outputs = self._execute_model(seq_group_metadata_list, - kv_caches) return [SamplerOutput(sampler_outputs)] @@ -474,36 +576,37 @@ class ModelWrapper(nn.Module): def __init__(self, model: nn.Module): super().__init__() - self.model = model.eval() + self.model = model def forward( self, token_ids: torch.Tensor, position_ids: torch.Tensor, - kv_caches: List[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]], attn_metadata: AttentionMetadata, input_lens: torch.Tensor, t: torch.Tensor, p: torch.Tensor, num_samples: int, + kv_caches: List[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]], ) -> torch.Tensor: """Executes the forward pass of the model and samples the next token. Args: token_ids: The input token IDs of shape [batch_size, seq_len]. position_ids: The input position IDs of shape [batch_size, seq_len]. - kv_caches: The key and value caches. They can be None during the - memory profiling at initialization. attn_metadata: The Pallas attention metadata. input_lens: The actual input lengths of shape [batch_size]. t: The sampling temperature of shape [batch_size]. p: The top-p probability of shape [batch_size]. + num_samples: Number of samples to draw from each logits vector. + kv_caches: The key and value caches. They can be None during the + memory profiling at initialization. """ batch_size, seq_len = token_ids.shape # Calculate the positions to sample from. - base_indicies = torch.arange( + start_indicies = torch.arange( batch_size, dtype=torch.int32, device=input_lens.device) * seq_len - logits_indices = base_indicies + input_lens - 1 + logits_indices = start_indicies + input_lens - 1 # FIXME(woosuk): This is a temporary hack to avoid using the existing # sampler and sampling metadata. diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py index 9bf764f0ff23a..03011e03058d8 100644 --- a/vllm/worker/tpu_worker.py +++ b/vllm/worker/tpu_worker.py @@ -13,15 +13,16 @@ init_distributed_environment) from vllm.logger import init_logger from vllm.model_executor import set_random_seed -from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.sequence import ExecuteModelRequest from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size from vllm.worker.tpu_model_runner import TPUModelRunner -from vllm.worker.worker_base import LoraNotSupportedWorkerBase +from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, + LoraNotSupportedWorkerBase, WorkerInput) logger = init_logger(__name__) -class TPUWorker(LoraNotSupportedWorkerBase): +class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): def __init__( self, @@ -57,14 +58,15 @@ def __init__( self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ self.cache_config.cache_dtype] - self.model_runner = TPUModelRunner(model_config, - parallel_config, - scheduler_config, - device_config, - cache_config, - load_config, - multimodal_config, - is_driver_worker=is_driver_worker) + self.model_runner: TPUModelRunner = TPUModelRunner( + model_config, + parallel_config, + scheduler_config, + device_config, + cache_config, + load_config, + multimodal_config, + is_driver_worker=is_driver_worker) def init_device(self) -> None: os.environ["PJRT_DEVICE"] = "TPU" @@ -196,69 +198,70 @@ def get_cache_block_size_bytes(self) -> int: dtype_size = get_dtype_size(self.cache_dtype) return dtype_size * total - def execute_model( + @property + def do_metadata_broadcast(self) -> bool: + # TODO(woosuk): Support TP. + return False + + @property + def kv_cache(self) -> Optional[List[List[torch.Tensor]]]: + # NOTE(woosuk): This assumes virtual_engine == 0, i.e., no pipeline + # parallelism. + return [self.tpu_cache] + + def prepare_worker_input( self, - execute_model_req: Optional[ExecuteModelRequest] = None, - ) -> List[SamplerOutput]: - if not self.is_driver_worker: - self._execute_model_non_driver() - return [] - assert execute_model_req is not None - # Issue cache operations. - self.cache_swap( - execute_model_req.blocks_to_swap_in, - execute_model_req.blocks_to_swap_out, - execute_model_req.blocks_to_copy, + execute_model_req: ExecuteModelRequest, + ) -> WorkerInput: + virtual_engine = execute_model_req.virtual_engine + num_seq_groups = len(execute_model_req.seq_group_metadata_list) + blocks_to_swap_in = _make_src_to_dst( + execute_model_req.blocks_to_swap_in, "cpu", self.device) + blocks_to_swap_out = _make_src_to_dst( + execute_model_req.blocks_to_swap_out, self.device, "cpu") + blocks_to_copy = _make_src_to_dst(execute_model_req.blocks_to_copy, + self.device, self.device) + return WorkerInput( + num_seq_groups=num_seq_groups, + blocks_to_swap_in=blocks_to_swap_in, + blocks_to_swap_out=blocks_to_swap_out, + blocks_to_copy=blocks_to_copy, + virtual_engine=virtual_engine, ) - # Run the model. - seq_group_metadata_list = execute_model_req.seq_group_metadata_list - assert len(seq_group_metadata_list) > 0 - output = self.model_runner.execute_model(seq_group_metadata_list, - self.tpu_cache) - return output - - def cache_swap( - self, - blocks_to_swap_in: List[Tuple[int, int]], - blocks_to_swap_out: List[Tuple[int, int]], - blocks_to_copy: List[Tuple[int, int]], - ) -> None: + + def execute_worker(self, worker_input: WorkerInput) -> None: + virtual_engine = worker_input.virtual_engine + assert virtual_engine == 0 attn_backend = self.model_runner.attn_backend num_layers = self.model_config.get_num_layers(self.parallel_config) - if blocks_to_swap_in: - # Swap from CPU to TPU. - src_indices, dst_indices = _make_src_to_dst( - blocks_to_swap_in, "cpu", self.device) - for i in range(num_layers): - tpu_k_cache, tpu_v_cache = self.tpu_cache[i] - cpu_k_cache, cpu_v_cache = self.cpu_cache[i] - k = cpu_k_cache[:, src_indices].to(self.device) - v = cpu_v_cache[:, src_indices].to(self.device) - _insert_kv(k, v, dst_indices, tpu_k_cache, tpu_v_cache) - - if blocks_to_swap_out: - # Swap from TPU to CPU. - src_indices, dst_indices = _make_src_to_dst( - blocks_to_swap_out, self.device, "cpu") - for i in range(num_layers): - tpu_k_cache, tpu_v_cache = self.tpu_cache[i] - cpu_k_cache, cpu_v_cache = self.cpu_cache[i] - cpu_k_cache[:, dst_indices] = tpu_k_cache[:, src_indices].cpu() - cpu_v_cache[:, dst_indices] = tpu_v_cache[:, src_indices].cpu() - - if blocks_to_copy: - src_to_dst = _make_src_to_dst(blocks_to_copy, self.device, - self.device) - attn_backend.copy_blocks(self.tpu_cache, src_to_dst) - - def start_worker_execution_loop(self) -> None: - while self._execute_model_non_driver(): - pass - - def _execute_model_non_driver(self) -> bool: - self.model_runner.execute_model(None, self.tpu_cache) - return True + # Issue cache operations. + if worker_input.blocks_to_swap_in is not None: + src_indices, dst_indices = worker_input.blocks_to_swap_in + if src_indices.numel() > 0: + # Swap from CPU to TPU. + for i in range(num_layers): + tpu_k_cache, tpu_v_cache = self.tpu_cache[i] + cpu_k_cache, cpu_v_cache = self.cpu_cache[i] + k = cpu_k_cache[:, src_indices].to(self.device) + v = cpu_v_cache[:, src_indices].to(self.device) + _insert_kv(k, v, dst_indices, tpu_k_cache, tpu_v_cache) + + if worker_input.blocks_to_swap_out is not None: + src_indices, dst_indices = worker_input.blocks_to_swap_out + if src_indices.numel() > 0: + # Swap from TPU to CPU. + for i in range(num_layers): + tpu_k_cache, tpu_v_cache = self.tpu_cache[i] + cpu_k_cache, cpu_v_cache = self.cpu_cache[i] + cpu_k_cache[:, dst_indices] = tpu_k_cache[:, src_indices] + cpu_v_cache[:, dst_indices] = tpu_v_cache[:, src_indices] + + if worker_input.blocks_to_copy is not None: + src_indices, dst_indices = worker_input.blocks_to_copy + if src_indices.numel() > 0: + attn_backend.copy_blocks(self.tpu_cache, + (src_indices, dst_indices)) def _make_src_to_dst( From 58ca6632247cb738d069a585e1ec9a9d5e66da68 Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> Date: Thu, 18 Jul 2024 10:39:12 -0400 Subject: [PATCH 12/13] [ Misc ] Improve Min Capability Checking in `compressed-tensors` (#6522) --- .../compressed_tensors/compressed_tensors.py | 22 ++++++++++++------- .../schemes/compressed_tensors_scheme.py | 7 ++++++ .../schemes/compressed_tensors_unquantized.py | 4 ++++ .../schemes/compressed_tensors_w4a16_24.py | 4 ++++ .../schemes/compressed_tensors_w8a8_fp8.py | 4 ++++ .../schemes/compressed_tensors_w8a8_int8.py | 4 ++++ .../schemes/compressed_tensors_wNa16.py | 4 ++++ 7 files changed, 41 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 1424c620ae675..659f5a599dc14 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -37,7 +37,7 @@ def get_supported_act_dtypes(cls) -> List[torch.dtype]: @classmethod def get_min_capability(cls) -> int: - return 75 + return 70 def get_name(self) -> str: return "compressed_tensors" @@ -85,13 +85,14 @@ def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig": def get_config_filenames(cls) -> List[str]: return [] - def _check_gptq_and_marlin_can_run(self): + def _check_scheme_supported(self, min_capability: int): capability = current_platform.get_device_capability() capability = capability[0] * 10 + capability[1] - if capability < 80: - raise RuntimeError("The quantization config is not supported for ", - "the current GPU. Minimum capability: 80. ", - f"Current capability: {capability}.") + if capability < min_capability: + raise RuntimeError( + "Quantization scheme is not supported for ", + f"the current GPU. Min capability: {min_capability}. ", + f"Current capability: {capability}.") def _is_static_tensor_w8a8(self, weight_quant: BaseModel, input_quant: BaseModel) -> bool: @@ -171,7 +172,6 @@ def _get_schema(self, weight_quant: BaseModel, # Detect If Mixed Precision if self._is_wNa16_group_channel(weight_quant, input_quant): - self._check_gptq_and_marlin_can_run() if (self.quant_format == CompressionFormat.marlin_24.value and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS): return CompressedTensorsW4A16Sparse24( @@ -222,10 +222,16 @@ def get_scheme(self, layer: torch.nn.Module) -> "CompressedTensorsScheme": raise ValueError( f"Could not find quantization details for {layer}.") - return self._get_schema( + scheme = self._get_schema( weight_quant=layer_quant_details["weights"], input_quant=layer_quant_details["input_activations"]) + # Raise error if device does not support the scheme + # (e.g. fp8 needs ada lovelace) + self._check_scheme_supported(scheme.get_min_capability()) + + return scheme + class CompressedTensorsLinearMethod(LinearMethodBase): diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py index 3aa9130782039..d5f37b47bb87e 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py @@ -12,6 +12,13 @@ class CompressedTensorsScheme(ABC): of different quantization schemes supported by CompressedTensors. """ + @abstractmethod + def get_min_capability(self) -> int: + """ + Get minimum device capability. + """ + raise NotImplementedError + @abstractmethod def create_weights(self, *args, **kwargs): """ diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py index 2c7fe3e0e4114..4350ff4e90ae8 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py @@ -18,6 +18,10 @@ class CompressedTensorsUnquantized(CompressedTensorsScheme): in a linear transformation. """ + def get_min_capability(self) -> int: + # volta and up + return 70 + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: pass diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py index 54bf85c096f2e..eec523d00372c 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py @@ -29,6 +29,10 @@ def __init__(self, raise ValueError( "group_size must be given when using strategy group") + def get_min_capability(self) -> int: + # ampere + up + return 80 + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: pass diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index f1ca9510d92aa..e842475e4f34b 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -33,6 +33,10 @@ def __init__(self, strategy: str, is_static_input_scheme: bool): "Consider quantizing with per tensor scales or upgrading " "to Hopper.") + def get_min_capability(self) -> int: + # lovelace and up + return 89 + def process_weights_after_loading(self, layer) -> None: # If per tensor, when we have a fused module (e.g. QKV) with per # tensor scales (thus N scales being passed to the kernel), diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py index 6fec5d01056d8..e81496c89ac7f 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py @@ -19,6 +19,10 @@ def __init__(self, strategy: str, is_static_input_scheme: bool): self.strategy = strategy self.is_static_input_scheme = is_static_input_scheme + def get_min_capability(self) -> int: + # turing and up + return 75 + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # WEIGHT # Cutlass kernels need transposed weight. diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py index 187a3f9877ccf..3f3febcad4f85 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py @@ -42,6 +42,10 @@ def __init__(self, group_size=self.group_size, is_sym=True) + def get_min_capability(self) -> int: + # ampere and up + return 80 + def create_weights(self, layer: torch.nn.Module, input_size: int, output_partition_sizes: List[int], input_size_per_partition: int, From ecdb462c2493cf9fd095e624cdfd8d62842e2097 Mon Sep 17 00:00:00 2001 From: "Kevin H. Luu" Date: Thu, 18 Jul 2024 08:01:45 -0700 Subject: [PATCH 13/13] [ci] Reword Github bot comment (#6534) --- .github/workflows/reminder_comment.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/reminder_comment.yml b/.github/workflows/reminder_comment.yml index d6924a30aa406..68d242e053be1 100644 --- a/.github/workflows/reminder_comment.yml +++ b/.github/workflows/reminder_comment.yml @@ -15,7 +15,7 @@ jobs: owner: context.repo.owner, repo: context.repo.repo, issue_number: context.issue.number, - body: 'šŸ‘‹ Hi! Thank you for contributing to the vLLM project.\n Just a reminder: PRs would not trigger full CI run by default. Instead, it would only trigger `fastcheck` CI to run, which consists only a small and essential subset of tests to quickly catch errors with the flexibility to run extra individual tests on top (you can do this by unblocking test steps in the Buildkite run). \n\nFull CI run is still required to merge this PR so once the PR is ready to go, please make sure to run it. If you need all test signals in between PR commits, you can trigger full CI as well.\n\n To run full CI, you can do one of these:\n- Comment `/ready` on the PR\n- Add `ready` label to the PR\n- Enable auto-merge.\n\nšŸš€' + body: 'šŸ‘‹ Hi! Thank you for contributing to the vLLM project.\n Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run `fastcheck` CI which consists a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of default ones by unblocking the steps in your `fast-check` build on Buildkite UI. \n\nOnce the PR is approved and ready to go, please make sure to run full CI as it's required (or just use auto-merge).\n\n To run full CI, you can do one of these:\n- Comment `/ready` on the PR\n- Add `ready` label to the PR\n- Enable auto-merge.\n\nšŸš€' }) env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}