
[Performance]: Long AllReduce wait time on 1 device with tensor parallelism #5792

Closed
wenscarl opened this issue Jun 24, 2024 · 4 comments
Labels: performance (Performance-related issues), stale

wenscarl commented Jun 24, 2024

Proposal to improve performance

Proposal: make the broadcast of tensor_dict at the beginning of each decoding step synchronous, or block each process until the broadcast has completed.

Report of performance regression

In the decoding stage, each matrix multiplication that uses tensor parallelism is followed by an all-reduce, which implicitly synchronizes the processes. However, the asynchronous broadcast of the tensor dictionary (code available here) at the start of each decoding step causes CUDA kernels to launch at quite different times across the ranks, which leads to the scenario shown in the attached profiler screenshots.

[Attached profiler screenshots: image (12) and image (13), showing a long AllReduce wait on one device]

cc @youkaichao
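
A minimal sketch of what "block the process after broadcast" could look like, assuming a dict of CUDA tensors and an initialized NCCL process group; this is not vLLM's actual broadcast_tensor_dict implementation, just an illustration of the proposal:

```python
# Sketch of the proposal (not vLLM's actual code): make sure the per-step
# broadcast has completed on every rank before any decode kernels are
# launched, so all ranks start the step from a common point in time.
# `tensor_dict` (a dict of CUDA tensors) and `group` are assumptions.
import torch
import torch.distributed as dist

def broadcast_tensor_dict_blocking(tensor_dict, src=0, group=None):
    handles = []
    for key in sorted(tensor_dict):
        # Launch the broadcasts asynchronously but keep the work handles.
        handles.append(
            dist.broadcast(tensor_dict[key], src=src, group=group, async_op=True)
        )
    for work in handles:
        # Make the current CUDA stream wait for each broadcast.
        work.wait()
    # Block the host until the broadcasts have actually finished, so every
    # rank launches the subsequent decode kernels at roughly the same time.
    torch.cuda.synchronize()
    return tensor_dict
```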

Misc discussion on performance

No response

Your current environment (if you think it is necessary)

CUDA_VISIBLE_DEVICES=0,1,2,3 nsys profile -t cuda,nvtx python benchmarks/benchmark_throughput.py --model=meta-llama/Meta-Llama-3-70B-Instruct --quantization=fp8  --dataset=/workspace/sw3/vllm/ShareGPT_V3_unfiltered_cleaned_split.json --output-len=64 --num-prompts=50 --enforce-eager -tp=4
wenscarl added the performance label Jun 24, 2024
youkaichao (Member) commented:

Thanks for the report! We do plan to remove this broadcast call; you can track the progress at #6241. Once we solve that issue, the driver process will send a lightweight Python object to all processes, and each process will prepare its inputs itself, so we won't need to broadcast tensors.
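
An illustrative sketch of the direction described above, not the actual change tracked in #6241: broadcast a small, picklable Python object describing the step, and let each rank build its own input tensors locally. The metadata fields and helper names are hypothetical.

```python
# Hypothetical sketch: send cheap metadata instead of full tensors, and
# rebuild the input tensors on every rank from that metadata.
import torch
import torch.distributed as dist

def broadcast_step_metadata(metadata=None, src=0, group=None):
    """Driver rank passes a small dict; all other ranks pass None."""
    obj_list = [metadata]
    # broadcast_object_list pickles the object on the source rank and
    # unpickles it on the others; it stays cheap as long as the object is small.
    dist.broadcast_object_list(obj_list, src=src, group=group)
    return obj_list[0]

def prepare_inputs_locally(metadata, device="cuda"):
    # Each rank materializes its own tensors instead of receiving them over NCCL.
    return {
        "input_ids": torch.tensor(metadata["input_ids"], device=device),
        "positions": torch.tensor(metadata["positions"], device=device),
    }
```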

eileenzhujuan commented:

> (quoted the original issue description above)

Hi, I am curious about this proposal, as I have hit a similar problem. When I set tp_size=4, on one of the ranks (not the one with tp_rank=0) kernel launches become much slower, and as a result every attention layer, which ends with an all_reduce, also becomes slower. Do you mean that making the broadcast at the beginning of each decode step synchronize immediately would relieve this problem?
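
One way to check where the ranks diverge, sketched under the assumption that the nsys command from the issue is used and that PyTorch's NVTX bindings are available; the function and range names are made up for illustration:

```python
# Sketch: wrap the suspected regions in NVTX ranges so the per-rank skew
# between the broadcast and the following all-reduce shows up clearly in
# the `nsys profile -t cuda,nvtx` timeline.
import torch
import torch.distributed as dist

def decode_step(tensor_dict, hidden_states, group=None):
    torch.cuda.nvtx.range_push("broadcast_tensor_dict")
    for key in sorted(tensor_dict):
        dist.broadcast(tensor_dict[key], src=0, group=group)
    torch.cuda.nvtx.range_pop()

    torch.cuda.nvtx.range_push("tp_all_reduce")
    # Stand-in for the all-reduce that follows each tensor-parallel matmul.
    dist.all_reduce(hidden_states, group=group)
    torch.cuda.nvtx.range_pop()
    return hidden_states
```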

github-actions bot commented:

This issue has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this issue should remain open. Thank you!

github-actions bot added the stale label Oct 25, 2024
github-actions bot commented:

This issue has been automatically closed due to inactivity. Please feel free to reopen if you feel it is still relevant. Thank you!

github-actions bot closed this as not planned Nov 24, 2024