
[Frontend][V1] Online serving performance improvements #12287

Merged
mgoin merged 5 commits into vllm-project:main from the v1-perf-smoothing branch on Jan 22, 2025

Conversation

njhill (Member) commented Jan 21, 2025

These help in particular with TTFT, ITL variance, and overall throughput.

  • Break up output processing (detokenization) to avoid blocking the event loop for too long (see the sketch below)
  • Freeze the heap after startup to reduce GC overhead/pauses
  • Optimize a couple of CPU hotspots seen during profiling
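
On the first point, the idea is to process each engine step's outputs in bounded chunks and yield control back to the asyncio event loop between chunks, so one large batch of detokenization work cannot stall in-flight streaming responses. A minimal sketch of that pattern, not the actual vLLM implementation; the chunk size value and the helper function are illustrative assumptions:

import asyncio

# Illustrative chunk size; the PR introduces a module-level constant for
# this (OUTPUT_PROCESSING_CHUNK_SIZE). The value 128 is an assumption.
OUTPUT_PROCESSING_CHUNK_SIZE = 128

async def process_outputs(outputs: list) -> None:
    # Walk the outputs in fixed-size chunks rather than all at once.
    for start in range(0, len(outputs), OUTPUT_PROCESSING_CHUNK_SIZE):
        for output in outputs[start:start + OUTPUT_PROCESSING_CHUNK_SIZE]:
            handle_output(output)  # hypothetical per-request detokenize/enqueue step
        # Yield to the event loop between chunks so queued coroutines
        # (e.g. streaming response writers) get a chance to run.
        await asyncio.sleep(0)

def handle_output(output) -> None:
    # Placeholder for per-request detokenization and response queueing.
    ...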

Benchmark on A100:

VLLM_USE_V1=1 vllm serve meta-llama/Llama-3.2-1B-Instruct --disable-log-requests --port 8001 --max-num-batched-tokens 8192 --no-enable-prefix-caching --uvicorn-log-level=error
python benchmarks/benchmark_serving.py \
    --backend vllm \
    --model meta-llama/Llama-3.2-1B-Instruct \
    --dataset-name sharegpt \
    --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json \
    --ignore-eos \
    --port 8001 \
    --save-result \
    --result-dir results \
    --result-filename test.json \
    --num-prompts 6000 \
    --request-rate inf \
    --max-concurrency=400

Before:

============ Serving Benchmark Result ============
Successful requests:                     6000      
Benchmark duration (s):                  94.31     
Total input tokens:                      1350511   
Total generated tokens:                  1211959   
Request throughput (req/s):              63.62     
Output token throughput (tok/s):         12850.45  
Total Token throughput (tok/s):          27169.98  
---------------Time to First Token----------------
Mean TTFT (ms):                          229.23    
Median TTFT (ms):                        158.08    
P99 TTFT (ms):                           1050.70   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          30.02     
Median TPOT (ms):                        29.64     
P99 TPOT (ms):                           68.90     
---------------Inter-token Latency----------------
Mean ITL (ms):                           28.77     
Median ITL (ms):                         23.19     
P99 ITL (ms):                            386.30    
==================================================

After:

============ Serving Benchmark Result ============
Successful requests:                     6000      
Benchmark duration (s):                  88.60     
Total input tokens:                      1350511   
Total generated tokens:                  1211959   
Request throughput (req/s):              67.72     
Output token throughput (tok/s):         13679.34  
Total Token throughput (tok/s):          28922.50  
---------------Time to First Token----------------
Mean TTFT (ms):                          197.34    
Median TTFT (ms):                        168.03    
P99 TTFT (ms):                           1059.55   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          28.30     
Median TPOT (ms):                        27.75     
P99 TPOT (ms):                           47.38     
---------------Inter-token Latency----------------
Mean ITL (ms):                           26.64     
Median ITL (ms):                         24.38     
P99 ITL (ms):                            65.19     
==================================================


👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small, essential subset of tests to catch errors quickly. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add the ready label to the PR
  • Enable auto-merge.

🚀

@mergify mergify bot added the frontend label Jan 21, 2025
These help in particular with TTFT, and ITL variance. Overall throughput doesn't change much.

- Break up output processing (detokenization) to avoid blocking the event loop for too long
- Freeze the heap after startup to reduce GC overhead/pauses
- Optimize a couple of CPU hotspots seen during profiling

Signed-off-by: Nick Hill <[email protected]>
@njhill njhill force-pushed the v1-perf-smoothing branch from cfc5705 to 55dd119 on January 21, 2025 23:39
@@ -42,23 +42,31 @@ class OpenAIBaseModel(BaseModel):
    # OpenAI API does allow extra fields
    model_config = ConfigDict(extra="allow")

    # Cache class field names
    field_names: ClassVar[Optional[Set[str]]] = None
njhill (Member, Author):
There was noticeable overhead creating this set every time one of these objects is instantiated.
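
For context, a minimal sketch of the class-level caching pattern described here, assuming pydantic v2; the helper method name _get_field_names is hypothetical, for illustration only:

from typing import ClassVar, Optional, Set

from pydantic import BaseModel, ConfigDict

class OpenAIBaseModel(BaseModel):
    # OpenAI API does allow extra fields
    model_config = ConfigDict(extra="allow")

    # Cache class field names so the set is built once per class,
    # not rebuilt for every instantiated request object.
    field_names: ClassVar[Optional[Set[str]]] = None

    @classmethod
    def _get_field_names(cls) -> Set[str]:
        # Hypothetical helper: build the set of declared field names
        # lazily on first use and store it on the class.
        if cls.field_names is None:
            cls.field_names = set(cls.model_fields.keys())
        return cls.field_names

A caller validating an incoming request dict can then find unrecognized keys with data.keys() - SomeRequest._get_field_names() without paying for set construction on every request.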

    @property
    def output_token_ids(self) -> ConstantList[int]:
        # Prevent directly appending to the output_token_ids since
        # all_token_ids should also be updated simultaneously.
        return ConstantList(self._output_token_ids)
njhill (Member, Author):
Avoid constructing these objects every time the properties are accessed.
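
A minimal sketch of that caching idea, with ConstantList standing in for vLLM's read-only list wrapper; the class and attribute names below are illustrative assumptions, not the exact change in this PR:

from collections.abc import Sequence

class ConstantList(Sequence):
    # Minimal stand-in for vLLM's read-only list wrapper.
    def __init__(self, items: list) -> None:
        self._items = items

    def __getitem__(self, index):
        return self._items[index]

    def __len__(self) -> int:
        return len(self._items)

class RequestState:
    def __init__(self, output_token_ids: list) -> None:
        self._output_token_ids = output_token_ids
        # Build the read-only view once and reuse it on every access.
        self._output_token_ids_view = ConstantList(self._output_token_ids)

    @property
    def output_token_ids(self) -> ConstantList:
        # Prevent direct appends to output_token_ids, since all_token_ids
        # must be updated at the same time; return the cached wrapper
        # instead of allocating a new one per property access.
        return self._output_token_ids_view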

Collaborator:
Nice catch!

Member:
I actually thought properties were cached after the first call, nice call

Member:
> I actually thought properties were cached after the first call, nice call

That would involve the use of cached_property.

robertgshaw2-redhat (Collaborator):
Wow, the impact on P99 ITL is crazy.

# Mark the startup heap as static so that it's ignored by GC.
# Reduces pause times of oldest generation collections.
gc.collect()
gc.freeze()
Collaborator:
Do we need to call unfreeze at some point?

njhill (Member, Author):
No, this is mostly static stuff that will be around for the lifetime of the process anyhow.

https://www.rippling.com/blog/the-garbage-collector-fights-back
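
For reference, a minimal sketch of this startup pattern; gc.collect(), gc.freeze(), and gc.get_freeze_count() are standard CPython APIs, and the placement shown is an illustrative assumption:

import gc

def freeze_startup_heap() -> None:
    # Collect first so anything already garbage is reclaimed, then move
    # every surviving object into the permanent generation.
    gc.collect()
    gc.freeze()
    # Frozen objects are skipped by later collections, which shortens
    # pauses for oldest-generation (gen 2) collections.
    print(f"objects frozen: {gc.get_freeze_count()}")

# Call once after imports, config parsing, and model/tokenizer setup,
# just before the server starts accepting requests.
freeze_startup_heap()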

njhill (Member, Author) commented Jan 22, 2025

Combining with #12298 and increasing the max output processing chunk size to 256 gets higher throughput at the cost of slightly more latency variance.

Since the benchmark I've been running uses 400 concurrent requests, the 256 chunk size essentially just means each step's outputs get split into two chunks of roughly 200. If I disable the chunking completely, the throughput increases to 80 req/sec (with the coalescing), but the inter-response latencies become larger and more uneven.

============ Serving Benchmark Result ============
Successful requests:                     6000      
Benchmark duration (s):                  84.70     
Total input tokens:                      1350511   
Total generated tokens:                  1211959   
Request throughput (req/s):              70.84     
Output token throughput (tok/s):         14308.94  
Total Token throughput (tok/s):          30253.69  
---------------Time to First Token----------------
Mean TTFT (ms):                          198.28    
Median TTFT (ms):                        166.40    
P99 TTFT (ms):                           1128.75   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          26.76     
Median TPOT (ms):                        26.05     
P99 TPOT (ms):                           50.04     
---------------Inter-token Latency----------------
Mean ITL (ms):                           29.41     
Median ITL (ms):                         26.83     
P99 ITL (ms):                            75.34     
==================================================

njhill (Member, Author) commented Jan 22, 2025

It would probably be good to also make OUTPUT_PROCESSING_CHUNK_SIZE overridable via an env var.
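
A minimal sketch of such an override; the environment variable name VLLM_OUTPUT_PROCESSING_CHUNK_SIZE and the default of 128 are assumptions for illustration, not a documented vLLM setting:

import os

# Default chunk size used when the (hypothetical) env var is unset.
_DEFAULT_OUTPUT_PROCESSING_CHUNK_SIZE = 128

OUTPUT_PROCESSING_CHUNK_SIZE = int(
    os.getenv("VLLM_OUTPUT_PROCESSING_CHUNK_SIZE",
              str(_DEFAULT_OUTPUT_PROCESSING_CHUNK_SIZE)))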

vllm/v1/engine/output_processor.py (review thread resolved)
mgoin (Member) left a comment:

LGTM!

I ran an lm-eval test with gsm8k as a smoke test and got the same result as v0

VLLM_USE_V1=1 vllm serve meta-llama/Llama-3.1-8B-Instruct --disable-log-requests --port 8000 --max-num-batched-tokens 8192 --no-enable-prefix-caching

lm_eval --model local-completions --model_args model=meta-llama/Llama-3.1-8B-Instruct,base_url=http://0.0.0.0:8000/v1/completions,num_concurrent=50,tokenized_requests=False --tasks gsm8k --num_fewshot 5
local-completions (model=meta-llama/Llama-3.1-8B-Instruct,base_url=http://0.0.0.0:8000/v1/completions,num_concurrent=50,tokenized_requests=False), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: 1
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.7718|±  |0.0116|
|     |       |strict-match    |     5|exact_match|↑  |0.6983|±  |0.0126|

@mgoin mgoin added the ready label Jan 22, 2025

mergify bot commented Jan 22, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @njhill.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jan 22, 2025
@mergify mergify bot removed the needs-rebase label Jan 22, 2025
@mgoin mgoin enabled auto-merge (squash) January 22, 2025 22:18
@mgoin mgoin merged commit aea9436 into vllm-project:main Jan 22, 2025
51 checks passed
@njhill njhill deleted the v1-perf-smoothing branch January 22, 2025 23:34
Labels
frontend, ready

6 participants