From a0550cbc80f504aa2da80b573c22204f686a0389 Mon Sep 17 00:00:00 2001 From: "Kevin H. Luu" Date: Tue, 9 Jul 2024 12:56:56 -0700 Subject: [PATCH 01/26] Add support for multi-node on CI (#5955) Signed-off-by: kevin --- .buildkite/run-multi-node-test.sh | 77 +++++++++++++++++++++++++++++++ 1 file changed, 77 insertions(+) create mode 100755 .buildkite/run-multi-node-test.sh diff --git a/.buildkite/run-multi-node-test.sh b/.buildkite/run-multi-node-test.sh new file mode 100755 index 0000000000000..0d94b2555f166 --- /dev/null +++ b/.buildkite/run-multi-node-test.sh @@ -0,0 +1,77 @@ +#!/bin/bash + +set -euox pipefail + +if [[ $# -lt 3 ]]; then + echo "Please provide the number of nodes and GPU per node." + exit 1 +fi + +NUM_NODES=$1 +NUM_GPUS=$2 +DOCKER_IMAGE=$3 + +shift 3 +COMMANDS=("$@") +if [ ${#COMMANDS[@]} -ne $NUM_NODES ]; then + echo "The number of commands must be equal to the number of nodes." + echo "Number of nodes: $NUM_NODES" + echo "Number of commands: ${#COMMANDS[@]}" + exit 1 +fi + +echo "List of commands" +for command in "${COMMANDS[@]}"; do + echo $command +done + +start_network() { + docker network create --subnet=192.168.10.0/24 docker-net +} + +start_nodes() { + for node in $(seq 0 $(($NUM_NODES-1))); do + GPU_DEVICES='"device=' + for node_gpu in $(seq 0 $(($NUM_GPUS - 1))); do + DEVICE_NUM=$(($node * $NUM_GPUS + $node_gpu)) + GPU_DEVICES+=$(($DEVICE_NUM)) + if [ $node_gpu -lt $(($NUM_GPUS - 1)) ]; then + GPU_DEVICES+=',' + fi + done + GPU_DEVICES+='"' + # echo "Starting node$node with GPU devices: $GPU_DEVICES" + docker run -d --gpus "$GPU_DEVICES" --name node$node --network docker-net --ip 192.168.10.$((10 + $node)) --rm $DOCKER_IMAGE tail -f /dev/null + done +} + +run_nodes() { + for node in $(seq 0 $(($NUM_NODES-1))); do + GPU_DEVICES='"device=' + for node_gpu in $(seq 0 $(($NUM_GPUS - 1))); do + DEVICE_NUM=$(($node * $NUM_GPUS + $node_gpu)) + GPU_DEVICES+=$(($DEVICE_NUM)) + if [ $node_gpu -lt $(($NUM_GPUS - 1)) ]; then + GPU_DEVICES+=',' + fi + done + GPU_DEVICES+='"' + echo "Running node$node with GPU devices: $GPU_DEVICES" + if [ $node -lt $(($NUM_NODES - 1)) ]; then + docker exec -d node$node /bin/bash -c "${COMMANDS[$node]}" + else + docker exec node$node /bin/bash -c "${COMMANDS[$node]}" + fi + done +} +cleanup() { + for node in $(seq 0 $(($NUM_NODES-1))); do + docker stop node$node + done + docker network rm docker-net +} +trap cleanup EXIT +start_network +start_nodes +run_nodes + From 4d6ada947c7e6379b6857bc9a9a1203679d32039 Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Tue, 9 Jul 2024 16:26:36 -0400 Subject: [PATCH 02/26] [CORE] Adding support for insertion of soft-tuned prompts (#4645) Co-authored-by: Swapnil Parekh Co-authored-by: Joe G Co-authored-by: Antoni Baum --- format.sh | 1 + tests/lora/test_long_context.py | 9 +- tests/lora/test_lora_manager.py | 326 ++++++++-------- tests/prompt_adapter/test_bloom.py | 45 +++ .../test_multi_adapter_inference.py | 53 +++ tests/prompt_adapter/test_pa_lora.py | 61 +++ tests/spec_decode/e2e/conftest.py | 2 + tests/worker/test_model_runner.py | 1 + vllm/adapter_commons/__init__.py | 0 vllm/adapter_commons/layers.py | 14 + vllm/adapter_commons/models.py | 104 +++++ vllm/adapter_commons/request.py | 25 ++ vllm/adapter_commons/utils.py | 90 +++++ vllm/adapter_commons/worker_manager.py | 36 ++ vllm/config.py | 37 ++ vllm/core/scheduler.py | 12 + vllm/engine/arg_utils.py | 24 +- vllm/engine/async_llm_engine.py | 38 +- vllm/engine/llm_engine.py | 65 +++- vllm/entrypoints/llm.py | 33 +- 
vllm/entrypoints/openai/api_server.py | 5 +- vllm/entrypoints/openai/cli_args.py | 21 +- vllm/entrypoints/openai/serving_chat.py | 2 +- vllm/entrypoints/openai/serving_completion.py | 17 +- vllm/entrypoints/openai/serving_engine.py | 61 ++- vllm/executor/cpu_executor.py | 15 + vllm/executor/executor_base.py | 27 +- vllm/executor/gpu_executor.py | 21 ++ vllm/executor/ray_xpu_executor.py | 5 +- vllm/executor/xpu_executor.py | 5 +- vllm/lora/layers.py | 12 +- vllm/lora/models.py | 175 ++++----- vllm/lora/request.py | 25 +- vllm/lora/worker_manager.py | 215 ++++------- vllm/prompt_adapter/__init__.py | 0 vllm/prompt_adapter/layers.py | 80 ++++ vllm/prompt_adapter/models.py | 355 ++++++++++++++++++ vllm/prompt_adapter/request.py | 30 ++ vllm/prompt_adapter/worker_manager.py | 176 +++++++++ vllm/sequence.py | 48 ++- vllm/spec_decode/draft_model_runner.py | 11 +- vllm/worker/cpu_model_runner.py | 4 +- vllm/worker/cpu_worker.py | 5 +- vllm/worker/embedding_model_runner.py | 11 +- vllm/worker/model_runner.py | 138 ++++++- vllm/worker/worker.py | 20 +- vllm/worker/xpu_model_runner.py | 4 +- vllm/worker/xpu_worker.py | 5 +- 48 files changed, 1951 insertions(+), 518 deletions(-) create mode 100644 tests/prompt_adapter/test_bloom.py create mode 100644 tests/prompt_adapter/test_multi_adapter_inference.py create mode 100644 tests/prompt_adapter/test_pa_lora.py create mode 100644 vllm/adapter_commons/__init__.py create mode 100644 vllm/adapter_commons/layers.py create mode 100644 vllm/adapter_commons/models.py create mode 100644 vllm/adapter_commons/request.py create mode 100644 vllm/adapter_commons/utils.py create mode 100644 vllm/adapter_commons/worker_manager.py create mode 100644 vllm/prompt_adapter/__init__.py create mode 100644 vllm/prompt_adapter/layers.py create mode 100644 vllm/prompt_adapter/models.py create mode 100644 vllm/prompt_adapter/request.py create mode 100644 vllm/prompt_adapter/worker_manager.py diff --git a/format.sh b/format.sh index 8c54b56302d5b..5edc868f9f70c 100755 --- a/format.sh +++ b/format.sh @@ -111,6 +111,7 @@ mypy vllm/spec_decode --config-file pyproject.toml mypy vllm/model_executor --config-file pyproject.toml mypy vllm/lora --config-file pyproject.toml mypy vllm/logging --config-file pyproject.toml +mypy vllm/prompt_adapter --config-file pyproject.toml mypy tests --config-file pyproject.toml diff --git a/tests/lora/test_long_context.py b/tests/lora/test_long_context.py index b50784a205af7..853fd9fb3ce7a 100644 --- a/tests/lora/test_long_context.py +++ b/tests/lora/test_long_context.py @@ -92,11 +92,10 @@ def batched_generate( for input in inputs: prompt, sampling_param, lora_req = input # Add requests to the engine and run the engine - llm._validate_and_add_requests( - prompt, - sampling_param, - lora_request=lora_req, - ) + llm._validate_and_add_requests(prompt, + sampling_param, + lora_request=lora_req, + prompt_adapter_request=None) outputs = llm._run_engine(use_tqdm=True) return [outputs[i].outputs[0].text.strip() for i in range(len(outputs))] diff --git a/tests/lora/test_lora_manager.py b/tests/lora/test_lora_manager.py index 2133bce14957b..7bff9e1fbcdcc 100644 --- a/tests/lora/test_lora_manager.py +++ b/tests/lora/test_lora_manager.py @@ -127,37 +127,37 @@ def test_lora_model_manager(dist_init, dummy_model): model, 2, 2, 2, LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2)) assert all(x is None for x in manager.lora_index_to_id) - assert manager.add_lora(model_lora1) - assert manager.activate_lora(1) + assert manager.add_adapter(model_lora1) + assert 
manager.activate_adapter(1) assert manager.lora_index_to_id[0] == 1 - assert not manager.add_lora(model_lora1) - assert not manager.activate_lora(1) - assert manager.add_lora(model_lora2) - assert manager.activate_lora(2) + assert not manager.add_adapter(model_lora1) + assert not manager.activate_adapter(1) + assert manager.add_adapter(model_lora2) + assert manager.activate_adapter(2) assert manager.lora_index_to_id[0] == 1 assert manager.lora_index_to_id[1] == 2 - assert not manager.add_lora(model_lora2) - assert not manager.activate_lora(2) - assert manager.add_lora(model_lora3) + assert not manager.add_adapter(model_lora2) + assert not manager.activate_adapter(2) + assert manager.add_adapter(model_lora3) assert manager.lora_index_to_id[0] == 1 assert manager.lora_index_to_id[1] == 2 with pytest.raises(ValueError): - assert manager.activate_lora(3) + assert manager.activate_adapter(3) assert manager.lora_index_to_id[0] == 1 assert manager.lora_index_to_id[1] == 2 - assert manager.remove_lora(model_lora2.id) + assert manager.remove_adapter(model_lora2.id) assert manager.lora_index_to_id[1] is None - assert not manager.remove_lora(model_lora2.id) - assert manager.remove_lora(model_lora1.id) - assert not manager.remove_lora(model_lora1.id) - assert manager.add_lora(model_lora1) + assert not manager.remove_adapter(model_lora2.id) + assert manager.remove_adapter(model_lora1.id) + assert not manager.remove_adapter(model_lora1.id) + assert manager.add_adapter(model_lora1) assert manager.lora_index_to_id[0] is None assert manager.lora_index_to_id[1] is None - assert manager.add_lora(model_lora2) - assert manager.activate_lora(3) + assert manager.add_adapter(model_lora2) + assert manager.activate_adapter(3) assert manager.lora_index_to_id[0] == 3 assert manager.lora_index_to_id[1] is None - assert manager.activate_lora(2) + assert manager.activate_adapter(2) assert manager.lora_index_to_id[0] == 3 assert manager.lora_index_to_id[1] == 2 @@ -173,70 +173,70 @@ def test_lora_lru_cache_model_manager(dist_init, dummy_model): model, 2, 2, 2, LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2)) assert all(x is None for x in manager.lora_index_to_id) - assert manager.add_lora(model_lora1) - assert manager.activate_lora(1) + assert manager.add_adapter(model_lora1) + assert manager.activate_adapter(1) assert manager.lora_index_to_id[0] == 1 - assert not manager.add_lora(model_lora1) - assert not manager.activate_lora(1) - assert manager.add_lora(model_lora2) - assert manager.activate_lora(2) + assert not manager.add_adapter(model_lora1) + assert not manager.activate_adapter(1) + assert manager.add_adapter(model_lora2) + assert manager.activate_adapter(2) assert manager.lora_index_to_id[0] == 1 assert manager.lora_index_to_id[1] == 2 - assert not manager.add_lora(model_lora2) - assert not manager.activate_lora(2) - assert manager.add_lora(model_lora3) + assert not manager.add_adapter(model_lora2) + assert not manager.activate_adapter(2) + assert manager.add_adapter(model_lora3) assert manager.lora_index_to_id[0] == 1 assert manager.lora_index_to_id[1] == 2 - assert manager.activate_lora(3) + assert manager.activate_adapter(3) assert manager.lora_index_to_id[0] == 3 assert manager.lora_index_to_id[1] == 2 - assert manager.remove_lora(model_lora2.id) + assert manager.remove_adapter(model_lora2.id) assert manager.lora_index_to_id[1] is None - assert not manager.remove_lora(model_lora2.id) - assert manager.remove_lora(model_lora1.id) - assert not manager.remove_lora(model_lora1.id) - assert 
manager.add_lora(model_lora1) - assert manager.activate_lora(1) + assert not manager.remove_adapter(model_lora2.id) + assert manager.remove_adapter(model_lora1.id) + assert not manager.remove_adapter(model_lora1.id) + assert manager.add_adapter(model_lora1) + assert manager.activate_adapter(1) assert manager.lora_index_to_id[0] == 3 assert manager.lora_index_to_id[1] == 1 - assert manager.add_lora(model_lora2) - assert manager.deactivate_lora(3) + assert manager.add_adapter(model_lora2) + assert manager.deactivate_adapter(3) assert manager.lora_index_to_id[0] is None assert manager.lora_index_to_id[1] == 1 - assert manager.activate_lora(2) + assert manager.activate_adapter(2) assert manager.lora_index_to_id[0] == 2 assert manager.lora_index_to_id[1] == 1 - assert manager.activate_lora(3) + assert manager.activate_adapter(3) assert manager.lora_index_to_id[0] == 2 assert manager.lora_index_to_id[1] == 3 - assert manager.pin_lora(2) + assert manager.pin_adapter(2) assert manager.lora_index_to_id[0] == 2 assert manager.lora_index_to_id[1] == 3 - assert manager.activate_lora(1) + assert manager.activate_adapter(1) assert manager.lora_index_to_id[0] == 2 assert manager.lora_index_to_id[1] == 1 - assert manager.deactivate_lora(2) + assert manager.deactivate_adapter(2) assert manager.lora_index_to_id[0] is None assert manager.lora_index_to_id[1] == 1 - assert manager.activate_lora(3) + assert manager.activate_adapter(3) assert manager.lora_index_to_id[0] == 3 assert manager.lora_index_to_id[1] == 1 - assert manager.pin_lora(3) - assert manager.pin_lora(1) + assert manager.pin_adapter(3) + assert manager.pin_adapter(1) with pytest.raises(RuntimeError): - assert manager.pin_lora(2) + assert manager.pin_adapter(2) assert manager.lora_index_to_id[0] == 3 assert manager.lora_index_to_id[1] == 1 with pytest.raises(RuntimeError): - assert manager.activate_lora(2) + assert manager.activate_adapter(2) - assert manager.deactivate_lora(3) - assert manager.pin_lora(2) + assert manager.deactivate_adapter(3) + assert manager.pin_adapter(2) assert manager.lora_index_to_id[0] == 2 assert manager.lora_index_to_id[1] == 1 - assert manager.remove_lora(3) + assert manager.remove_adapter(3) with pytest.raises(ValueError): - assert manager.pin_lora(3) + assert manager.pin_adapter(3) def test_lru_lora_model_manager(dist_init, dummy_model): @@ -256,168 +256,169 @@ def test_lru_lora_model_manager(dist_init, dummy_model): assert all(x is None for x in manager.lora_index_to_id) # Add up to capacity - assert manager.add_lora(model_lora1) - assert manager.add_lora(model_lora2) - assert manager.activate_lora(1) - assert manager.activate_lora(2) + assert manager.add_adapter(model_lora1) + assert manager.add_adapter(model_lora2) + assert manager.activate_adapter(1) + assert manager.activate_adapter(2) - assert set(manager.list_loras()) == {1, 2} + assert set(manager.list_adapters()) == {1, 2} assert manager.lora_index_to_id[0] == 1 assert manager.lora_index_to_id[1] == 2 # Add over capacity - assert manager.add_lora(model_lora3) - assert manager.add_lora(model_lora4) - assert manager.activate_lora(3) - assert manager.activate_lora(4) + assert manager.add_adapter(model_lora3) + assert manager.add_adapter(model_lora4) + assert manager.activate_adapter(3) + assert manager.activate_adapter(4) - assert set(manager.list_loras()) == {3, 4} + assert set(manager.list_adapters()) == {3, 4} assert manager.lora_index_to_id[0] == 3 assert manager.lora_index_to_id[1] == 4 # Add 3 again to move it to the top and then add 2 # should return 
false since it's in already - assert not manager.add_lora(model_lora3) - assert not manager.activate_lora(3) - assert manager.add_lora(model_lora2) - assert manager.activate_lora(2) + assert not manager.add_adapter(model_lora3) + assert not manager.activate_adapter(3) + assert manager.add_adapter(model_lora2) + assert manager.activate_adapter(2) - assert set(manager.list_loras()) == {3, 2} + assert set(manager.list_adapters()) == {3, 2} assert manager.lora_index_to_id[0] == 3 assert manager.lora_index_to_id[1] == 2 # Remove manually - assert manager.remove_lora(3) - assert not manager.remove_lora(3) + assert manager.remove_adapter(3) + assert not manager.remove_adapter(3) - assert set(manager.list_loras()) == {2} + assert set(manager.list_adapters()) == {2} assert manager.lora_index_to_id[0] is None assert manager.lora_index_to_id[1] == 2 - assert manager.add_lora(model_lora3) - assert manager.activate_lora(3) - assert manager.add_lora(model_lora4) - assert manager.activate_lora(4) + assert manager.add_adapter(model_lora3) + assert manager.activate_adapter(3) + assert manager.add_adapter(model_lora4) + assert manager.activate_adapter(4) - assert set(manager.list_loras()) == {3, 4} + assert set(manager.list_adapters()) == {3, 4} assert manager.lora_index_to_id[0] == 3 assert manager.lora_index_to_id[1] == 4 - assert manager.remove_oldest_lora() - assert set(manager.list_loras()) == {4} + assert manager.remove_oldest_adapter() + assert set(manager.list_adapters()) == {4} assert manager.lora_index_to_id[0] is None assert manager.lora_index_to_id[1] == 4 - assert manager.remove_oldest_lora() - assert set(manager.list_loras()) == set() + assert manager.remove_oldest_adapter() + assert set(manager.list_adapters()) == set() assert all(x is None for x in manager.lora_index_to_id) - assert not manager.remove_oldest_lora() - assert set(manager.list_loras()) == set() + assert not manager.remove_oldest_adapter() + assert set(manager.list_adapters()) == set() assert all(x is None for x in manager.lora_index_to_id) # pinning - assert manager.add_lora(model_lora3) - assert manager.activate_lora(3) - assert manager.add_lora(model_lora4) - assert manager.activate_lora(4) - assert set(manager.list_loras()) == {3, 4} + assert manager.add_adapter(model_lora3) + assert manager.activate_adapter(3) + assert manager.add_adapter(model_lora4) + assert manager.activate_adapter(4) + assert set(manager.list_adapters()) == {3, 4} with pytest.raises(ValueError): - assert manager.pin_lora(1) - assert manager.pin_lora(3) + assert manager.pin_adapter(1) + assert manager.pin_adapter(3) # Remove manually - assert manager.remove_lora(3) - assert not manager.remove_lora(3) + assert manager.remove_adapter(3) + assert not manager.remove_adapter(3) - assert set(manager.list_loras()) == {4} + assert set(manager.list_adapters()) == {4} assert manager.lora_index_to_id[0] is None assert manager.lora_index_to_id[1] == 4 - assert manager.add_lora(model_lora1) - assert manager.pin_lora(1) - assert manager.add_lora(model_lora2) - assert manager.activate_lora(2) + assert manager.add_adapter(model_lora1) + assert manager.pin_adapter(1) + assert manager.add_adapter(model_lora2) + assert manager.activate_adapter(2) - assert set(manager.list_loras()) == {1, 2} + assert set(manager.list_adapters()) == {1, 2} assert manager.lora_index_to_id[0] == 1 assert manager.lora_index_to_id[1] == 2 - assert manager.remove_oldest_lora() - assert set(manager.list_loras()) == {1} + assert manager.remove_oldest_adapter() + assert set(manager.list_adapters()) 
== {1} assert manager.lora_index_to_id[0] == 1 assert manager.lora_index_to_id[1] is None with pytest.raises(RuntimeError): - assert manager.remove_oldest_lora() + assert manager.remove_oldest_adapter() - assert set(manager.list_loras()) == {1} + assert set(manager.list_adapters()) == {1} -def test_lru_cache_worker_lora_manager(llama_2_7b_model_extra_embeddings, - sql_lora_files): +def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings, + sql_lora_files): lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4) - worker_lora_manager = LRUCacheWorkerLoRAManager( + worker_adapter_manager = LRUCacheWorkerLoRAManager( 4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size - lora_config.lora_extra_vocab_size, lora_config, torch.device("cuda"), EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES) - worker_lora_manager.create_lora_manager(llama_2_7b_model_extra_embeddings) + worker_adapter_manager.create_lora_manager( + llama_2_7b_model_extra_embeddings) mapping = LoRAMapping([], []) - worker_lora_manager.set_active_loras([ + worker_adapter_manager.set_active_adapters([ LoRARequest("1", 1, sql_lora_files), LoRARequest("2", 2, sql_lora_files) ], mapping) - assert worker_lora_manager.list_loras() == {1, 2} - assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 - assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2 + assert worker_adapter_manager.list_adapters() == {1, 2} + assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2 - worker_lora_manager.set_active_loras([ + worker_adapter_manager.set_active_adapters([ LoRARequest("1", 1, sql_lora_files), LoRARequest("3", 3, sql_lora_files), LoRARequest("4", 4, sql_lora_files) ], mapping) - assert worker_lora_manager.list_loras() == {1, 2, 3, 4} - assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 - assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2 - assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 3 - assert worker_lora_manager._lora_manager.lora_index_to_id[3] == 4 + assert worker_adapter_manager.list_adapters() == {1, 2, 3, 4} + assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 3 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4 - worker_lora_manager.set_active_loras([ + worker_adapter_manager.set_active_adapters([ LoRARequest("1", 1, sql_lora_files), LoRARequest("2", 2, sql_lora_files), LoRARequest("5", 5, sql_lora_files) ], mapping) - assert worker_lora_manager.list_loras() == {1, 2, 4, 5} - assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 - assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2 - assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 5 - assert worker_lora_manager._lora_manager.lora_index_to_id[3] == 4 + assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5} + assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4 - worker_lora_manager.set_active_loras([ + worker_adapter_manager.set_active_adapters([ LoRARequest("1", 1, sql_lora_files), LoRARequest("1", 1, sql_lora_files), 
LoRARequest("1", 1, sql_lora_files) ], mapping) - assert worker_lora_manager.list_loras() == {1, 2, 4, 5} - assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 - assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2 - assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 5 - assert worker_lora_manager._lora_manager.lora_index_to_id[3] == 4 + assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5} + assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4 - worker_lora_manager.set_active_loras([ + worker_adapter_manager.set_active_adapters([ LoRARequest("6", 6, sql_lora_files), LoRARequest("7", 7, sql_lora_files), LoRARequest("8", 8, sql_lora_files) ], mapping) - assert worker_lora_manager.list_loras() == {1, 6, 7, 8} - assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 - assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 7 - assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 8 - assert worker_lora_manager._lora_manager.lora_index_to_id[3] == 6 + assert worker_adapter_manager.list_adapters() == {1, 6, 7, 8} + assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 7 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 8 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 6 # Over capacity with pytest.raises(RuntimeError): - worker_lora_manager.set_active_loras([ + worker_adapter_manager.set_active_adapters([ LoRARequest("10", 10, sql_lora_files), LoRARequest("11", 11, sql_lora_files), LoRARequest("12", 12, sql_lora_files), @@ -426,68 +427,69 @@ def test_lru_cache_worker_lora_manager(llama_2_7b_model_extra_embeddings, ], mapping) -def test_worker_lora_manager(llama_2_7b_model_extra_embeddings, - sql_lora_files): +def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings, + sql_lora_files): # Should remove every LoRA not specified in the request. 
lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4) - worker_lora_manager = WorkerLoRAManager( + worker_adapter_manager = WorkerLoRAManager( 4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size - lora_config.lora_extra_vocab_size, lora_config, torch.device("cuda"), EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES) - worker_lora_manager.create_lora_manager(llama_2_7b_model_extra_embeddings) + worker_adapter_manager.create_lora_manager( + llama_2_7b_model_extra_embeddings) mapping = LoRAMapping([], []) - worker_lora_manager.set_active_loras([ + worker_adapter_manager.set_active_adapters([ LoRARequest("1", 1, sql_lora_files), LoRARequest("2", 2, sql_lora_files) ], mapping) - assert worker_lora_manager.list_loras() == {1, 2} - assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 - assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2 + assert worker_adapter_manager.list_adapters() == {1, 2} + assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2 - worker_lora_manager.set_active_loras([ + worker_adapter_manager.set_active_adapters([ LoRARequest("1", 1, sql_lora_files), LoRARequest("3", 3, sql_lora_files), LoRARequest("4", 4, sql_lora_files) ], mapping) - assert worker_lora_manager.list_loras() == {1, 3, 4} - assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 - assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 3 - assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 4 + assert worker_adapter_manager.list_adapters() == {1, 3, 4} + assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 3 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 4 - worker_lora_manager.set_active_loras([ + worker_adapter_manager.set_active_adapters([ LoRARequest("1", 1, sql_lora_files), LoRARequest("2", 2, sql_lora_files), LoRARequest("5", 5, sql_lora_files) ], mapping) - assert worker_lora_manager.list_loras() == {1, 2, 5} - assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 - assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2 - assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 5 + assert worker_adapter_manager.list_adapters() == {1, 2, 5} + assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5 - worker_lora_manager.set_active_loras([ + worker_adapter_manager.set_active_adapters([ LoRARequest("1", 1, sql_lora_files), LoRARequest("1", 1, sql_lora_files), LoRARequest("1", 1, sql_lora_files) ], mapping) - assert worker_lora_manager.list_loras() == {1} - assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 - assert worker_lora_manager._lora_manager.lora_index_to_id[1] is None - assert worker_lora_manager._lora_manager.lora_index_to_id[2] is None + assert worker_adapter_manager.list_adapters() == {1} + assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] is None + assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] is None - worker_lora_manager.set_active_loras([ + worker_adapter_manager.set_active_adapters([ LoRARequest("6", 6, sql_lora_files), LoRARequest("7", 7, sql_lora_files), LoRARequest("8", 8, sql_lora_files) ], 
mapping) - assert worker_lora_manager.list_loras() == {6, 7, 8} - assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 8 - assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 6 - assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 7 + assert worker_adapter_manager.list_adapters() == {6, 7, 8} + assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 8 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 6 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 7 # Over capacity with pytest.raises(RuntimeError): - worker_lora_manager.set_active_loras([ + worker_adapter_manager.set_active_adapters([ LoRARequest("10", 10, sql_lora_files), LoRARequest("11", 11, sql_lora_files), LoRARequest("12", 12, sql_lora_files), @@ -525,8 +527,8 @@ def test_packed_loras(dist_init, dummy_model_gate_up): assert isinstance(model.get_submodule("gate_up_proj"), MergedColumnParallelLinearWithLoRA) - assert manager.add_lora(model_lora) - assert manager.add_lora(model_lora1) + assert manager.add_adapter(model_lora) + assert manager.add_adapter(model_lora1) packed_lora = model_lora.get_lora("gate_up_proj") assert packed_lora and isinstance(packed_lora, PackedLoRALayerWeights) diff --git a/tests/prompt_adapter/test_bloom.py b/tests/prompt_adapter/test_bloom.py new file mode 100644 index 0000000000000..6528b3009b8c0 --- /dev/null +++ b/tests/prompt_adapter/test_bloom.py @@ -0,0 +1,45 @@ +import pytest + +import vllm +from vllm.prompt_adapter.request import PromptAdapterRequest + +MODEL_PATH = "bigscience/bloomz-560m" +PA_PATH = 'stevhliu/bloomz-560m_PROMPT_TUNING_CAUSAL_LM' + + +def do_sample(llm, pa_name: str, pa_id: int): + + prompts = [ + "Tweet text : @nationalgridus I have no water and the bill is \ + current and paid. Can you do something about this? Label : ", + "Tweet text : @nationalgridus Looks good thanks! Label : " + ] + sampling_params = vllm.SamplingParams(temperature=0.0, + max_tokens=3, + stop_token_ids=[3]) + + outputs = llm.generate(prompts, + sampling_params, + prompt_adapter_request=PromptAdapterRequest( + pa_name, pa_id, PA_PATH, 8) if pa_id else None) + + # Print the outputs. + generated_texts = [] + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text.strip() + generated_texts.append(generated_text) + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + return generated_texts + + +@pytest.mark.parametrize("enforce_eager", [True, False]) +def test_twitter_prompt_adapter(enforce_eager: bool): + llm = vllm.LLM(MODEL_PATH, + enforce_eager=enforce_eager, + enable_prompt_adapter=True, + max_prompt_adapter_token=8) + + expected_output = ['complaint', 'no complaint'] + + assert do_sample(llm, "twitter_pa", pa_id=1) == expected_output diff --git a/tests/prompt_adapter/test_multi_adapter_inference.py b/tests/prompt_adapter/test_multi_adapter_inference.py new file mode 100644 index 0000000000000..39a79becdfbb3 --- /dev/null +++ b/tests/prompt_adapter/test_multi_adapter_inference.py @@ -0,0 +1,53 @@ +from vllm import EngineArgs, LLMEngine, SamplingParams +from vllm.prompt_adapter.request import PromptAdapterRequest + +MODEL_PATH = "bigscience/bloomz-560m" +pa_path = 'stevhliu/bloomz-560m_PROMPT_TUNING_CAUSAL_LM' +pa_path2 = 'swapnilbp/angry_tweet_ptune' + + +def do_sample(engine): + + prompts = [ + ("Tweet text: I have complaints! 
Label: ", + SamplingParams(temperature=0.0, max_tokens=3, stop_token_ids=[3]), + PromptAdapterRequest("hate_speech", 1, pa_path2, 8)), + ("Tweet text: I have no problems Label: ", + SamplingParams(temperature=0.0, max_tokens=3, stop_token_ids=[3]), + PromptAdapterRequest("hate_speech2", 2, pa_path2, 8)), + ("Tweet text: I have complaints! Label: ", + SamplingParams(temperature=0.0, max_tokens=3), None), + ("Tweet text: I have no problems Label: ", + SamplingParams(temperature=0.0, max_tokens=3, stop_token_ids=[3]), + PromptAdapterRequest("complain", 3, pa_path, 8)), + ] + + request_id = 0 + results = set() + while prompts or engine.has_unfinished_requests(): + if prompts: + prompt, sampling_params, pa_request = prompts.pop(0) + engine.add_request(str(request_id), + prompt, + sampling_params, + prompt_adapter_request=pa_request) + request_id += 1 + + request_outputs = engine.step() + + for request_output in request_outputs: + if request_output.finished: + results.add(request_output.outputs[0].text) + return results + + +def test_multi_prompt_adapters(): + engine_args = EngineArgs(model=MODEL_PATH, + max_prompt_adapters=3, + enable_prompt_adapter=True, + max_prompt_adapter_token=8) + engine = LLMEngine.from_engine_args(engine_args) + expected_output = { + ' quot;I', 'hate speech', 'no complaint', 'not hate speech' + } + assert do_sample(engine) == expected_output diff --git a/tests/prompt_adapter/test_pa_lora.py b/tests/prompt_adapter/test_pa_lora.py new file mode 100644 index 0000000000000..2a5f23f7f92ec --- /dev/null +++ b/tests/prompt_adapter/test_pa_lora.py @@ -0,0 +1,61 @@ +from huggingface_hub import snapshot_download + +from vllm import EngineArgs, LLMEngine, SamplingParams +from vllm.lora.request import LoRARequest +from vllm.prompt_adapter.request import PromptAdapterRequest + +MODEL_PATH = "meta-llama/Llama-2-7b-hf" +pa_path = snapshot_download(repo_id="swapnilbp/llama_tweet_ptune") +lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test") + + +def do_sample(engine): + + prompt_text = "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]" # noqa: E501 + + # first prompt with a prompt adapter and second without adapter + prompts = [ + (prompt_text, + SamplingParams(temperature=0.0, max_tokens=100, + stop=["[/assistant]"]), + PromptAdapterRequest("hate_speech", 1, pa_path, + 8), LoRARequest("sql_test", 1, lora_path)), + (prompt_text, + SamplingParams(temperature=0.0, max_tokens=100, + stop=["[/assistant]"]), None, + LoRARequest("sql_test", 1, lora_path)), + ] + + request_id = 0 + results = set() + while prompts or engine.has_unfinished_requests(): + if prompts: + prompt, sampling_params, pa_request, lora_request = prompts.pop(0) + engine.add_request(str(request_id), + prompt, + sampling_params, + prompt_adapter_request=pa_request, + lora_request=lora_request) + request_id += 1 + + request_outputs = engine.step() + + for request_output in request_outputs: + if request_output.finished: + results.add(request_output.outputs[0].text) + return results + + +def test_lora_prompt_adapter(): + engine_args = EngineArgs(model=MODEL_PATH, + enable_prompt_adapter=True, + enable_lora=True, + max_num_seqs=60, + max_prompt_adapter_token=8) + engine = LLMEngine.from_engine_args(engine_args) + result = do_sample(engine) + + expected_output = { + " SELECT icao FROM table_name_74 WHERE airport = 'lilongwe 
international airport' " # noqa: E501 + } + assert result == expected_output diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py index 8ad8e9cb81ff8..fb3415b5db153 100644 --- a/tests/spec_decode/e2e/conftest.py +++ b/tests/spec_decode/e2e/conftest.py @@ -13,6 +13,7 @@ from vllm.model_executor.utils import set_random_seed from vllm.multimodal import MultiModalDataDict from vllm.outputs import RequestOutput +from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams from vllm.sequence import Logprob from vllm.usage.usage_lib import UsageContext @@ -92,6 +93,7 @@ def generate( use_tqdm: bool = True, lora_request: Optional[LoRARequest] = None, multi_modal_data: Optional[MultiModalDataDict] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None ) -> List[RequestOutput]: if prompts is None: diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index e1775790c0a03..b5742c4338616 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -23,6 +23,7 @@ def _create_model_runner(model: str, *args, **kwargs) -> ModelRunner: cache_config=engine_config.cache_config, load_config=engine_config.load_config, lora_config=engine_config.lora_config, + prompt_adapter_config=engine_config.prompt_adapter_config, is_driver_worker=True, ) return model_runner diff --git a/vllm/adapter_commons/__init__.py b/vllm/adapter_commons/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/adapter_commons/layers.py b/vllm/adapter_commons/layers.py new file mode 100644 index 0000000000000..3ed60678b52f5 --- /dev/null +++ b/vllm/adapter_commons/layers.py @@ -0,0 +1,14 @@ +from dataclasses import dataclass +from typing import Tuple + + +@dataclass +class AdapterMapping: + # Per every token in input_ids: + index_mapping: Tuple[int, ...] + # Per sampled token: + prompt_mapping: Tuple[int, ...] + + def __post_init__(self): + self.index_mapping = tuple(self.index_mapping) + self.prompt_mapping = tuple(self.prompt_mapping) \ No newline at end of file diff --git a/vllm/adapter_commons/models.py b/vllm/adapter_commons/models.py new file mode 100644 index 0000000000000..6939b1405f3e1 --- /dev/null +++ b/vllm/adapter_commons/models.py @@ -0,0 +1,104 @@ +from abc import ABC, abstractmethod +from typing import Any, Callable, Dict, Hashable, Optional, TypeVar + +from torch import nn + +from vllm.logger import init_logger +from vllm.utils import LRUCache + +logger = init_logger(__name__) + + +class AdapterModel(ABC): + + def __init__(self, model_id=None): + self.id = model_id + + @abstractmethod + def from_local_checkpoint(cls, model_dir, model_id=None, **kwargs): + # Common initialization code + # Load weights or embeddings from local checkpoint + raise NotImplementedError("Subclasses must implement this method.") + + +T = TypeVar('T') + + +class AdapterLRUCache(LRUCache[T]): + + def __init__(self, capacity: int, deactivate_fn: Callable[[Hashable], + None]): + super().__init__(capacity) + self.deactivate_fn = deactivate_fn + + def _on_remove(self, key: Hashable, value: T): + logger.debug("Removing adapter int id: %d", key) + self.deactivate_fn(key) + return super()._on_remove(key, value) + + +class AdapterModelManager(ABC): + + def __init__( + self, + model: nn.Module, + ): + """Create a AdapterModelManager and adapter for a given model. + Args: + model: the model to be adapted. 
+ """ + self.model: nn.Module = model + self._registered_adapters: Dict[int, Any] = {} + # Dict instead of a Set for compatibility with LRUCache. + self._active_adapters: Dict[int, None] = {} + self.adapter_type = 'Adapter' + self._last_mapping = None + + def __len__(self) -> int: + return len(self._registered_adapters) + + @property + @abstractmethod + def adapter_slots(self): + ... + + @property + @abstractmethod + def capacity(self): + ... + + @abstractmethod + def activate_adapter(self, adapter_id: int) -> bool: + ... + + @abstractmethod + def deactivate_adapter(self, adapter_id: int) -> bool: + ... + + @abstractmethod + def add_adapter(self, adapter: Any) -> bool: + ... + + @abstractmethod + def set_adapter_mapping(self, mapping: Any) -> None: + ... + + @abstractmethod + def remove_adapter(self, adapter_id: int) -> bool: + ... + + @abstractmethod + def remove_all_adapters(self): + ... + + @abstractmethod + def get_adapter(self, adapter_id: int) -> Optional[Any]: + ... + + @abstractmethod + def list_adapters(self) -> Dict[int, Any]: + ... + + @abstractmethod + def pin_adapter(self, adapter_id: int) -> bool: + ... diff --git a/vllm/adapter_commons/request.py b/vllm/adapter_commons/request.py new file mode 100644 index 0000000000000..69775ab7d4548 --- /dev/null +++ b/vllm/adapter_commons/request.py @@ -0,0 +1,25 @@ +from abc import abstractmethod +from dataclasses import dataclass + + +@dataclass +class AdapterRequest: + """ + Base class for adapter requests. + """ + + @property + @abstractmethod + def adapter_id(self): + ... + + def __post_init__(self): + if self.adapter_id < 1: + raise ValueError(f"id must be > 0, got {self.adapter_id}") + + def __eq__(self, value: object) -> bool: + return isinstance( + value, self.__class__) and self.adapter_id == value.adapter_id + + def __hash__(self) -> int: + return hash(self.adapter_id) diff --git a/vllm/adapter_commons/utils.py b/vllm/adapter_commons/utils.py new file mode 100644 index 0000000000000..6c5411f7d3d5c --- /dev/null +++ b/vllm/adapter_commons/utils.py @@ -0,0 +1,90 @@ +from typing import Any, Callable, Dict, Optional, Set + + +## model functions +def deactivate_adapter(adapter_id: int, active_adapters: Dict[int, None], + deactivate_func: Callable) -> bool: + if adapter_id in active_adapters: + deactivate_func(adapter_id) + active_adapters.pop(adapter_id) + return True + return False + + +def add_adapter(adapter: Any, registered_adapters: Dict[int, Any], + capacity: int, add_func: Callable) -> bool: + if adapter.id not in registered_adapters: + if len(registered_adapters) >= capacity: + raise RuntimeError('No free adapter slots.') + add_func(adapter) + registered_adapters[adapter.id] = adapter + return True + return False + + +def set_adapter_mapping(mapping: Any, last_mapping: Any, + set_mapping_func: Callable) -> Any: + if last_mapping != mapping: + set_mapping_func(mapping) + return mapping + return last_mapping + + +def remove_adapter(adapter_id: int, registered_adapters: Dict[int, Any], + deactivate_func: Callable) -> bool: + deactivate_func(adapter_id) + return bool(registered_adapters.pop(adapter_id, None)) + + +def list_adapters(registered_adapters: Dict[int, Any]) -> Dict[int, Any]: + return dict(registered_adapters) + + +def get_adapter(adapter_id: int, + registered_adapters: Dict[int, Any]) -> Optional[Any]: + return registered_adapters.get(adapter_id, None) + + +## worker functions +def set_active_adapters_worker(requests: Set[Any], mapping: Optional[Any], + apply_adapters_func, + set_adapter_mapping_func) -> None: + 
apply_adapters_func(requests) + set_adapter_mapping_func(mapping) + + +def add_adapter_worker(adapter_request: Any, list_adapters_func, + load_adapter_func, add_adapter_func, + activate_adapter_func) -> bool: + if adapter_request.adapter_id in list_adapters_func(): + return False + loaded_adapter = load_adapter_func(adapter_request) + loaded = add_adapter_func(loaded_adapter) + activate_adapter_func(loaded_adapter.id) + return loaded + + +def apply_adapters_worker(adapter_requests: Set[Any], list_adapters_func, + adapter_slots: int, remove_adapter_func, + add_adapter_func) -> None: + models_that_exist = list_adapters_func() + models_map = { + adapter_request.adapter_id: adapter_request + for adapter_request in adapter_requests if adapter_request + } + if len(models_map) > adapter_slots: + raise RuntimeError( + f"Number of requested models ({len(models_map)}) is greater " + f"than the number of GPU model slots " + f"({adapter_slots}).") + new_models = set(models_map) + models_to_add = new_models - models_that_exist + models_to_remove = models_that_exist - new_models + for adapter_id in models_to_remove: + remove_adapter_func(adapter_id) + for adapter_id in models_to_add: + add_adapter_func(models_map[adapter_id]) + + +def list_adapters_worker(adapter_manager_list_adapters_func) -> Set[int]: + return set(adapter_manager_list_adapters_func()) diff --git a/vllm/adapter_commons/worker_manager.py b/vllm/adapter_commons/worker_manager.py new file mode 100644 index 0000000000000..acf18993af6d7 --- /dev/null +++ b/vllm/adapter_commons/worker_manager.py @@ -0,0 +1,36 @@ +from abc import ABC, abstractmethod +from typing import Any, Optional, Set + +import torch + + +class AbstractWorkerManager(ABC): + + def __init__(self, device: torch.device): + self.device = device + + @property + @abstractmethod + def is_enabled(self) -> bool: + ... + + @abstractmethod + def set_active_adapters(self, requests: Set[Any], + mapping: Optional[Any]) -> None: + ... + + @abstractmethod + def add_adapter(self, adapter_request: Any) -> bool: + ... + + @abstractmethod + def remove_adapter(self, adapter_id: int) -> bool: + ... + + @abstractmethod + def remove_all_adapters(self): + ... + + @abstractmethod + def list_adapters(self) -> Set[int]: + ... diff --git a/vllm/config.py b/vllm/config.py index 1ea2888796808..68ca81a2ec4fe 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1285,6 +1285,39 @@ def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig): raise ValueError("LoRA is not supported with chunked prefill yet.") +@dataclass +class PromptAdapterConfig: + max_prompt_adapters: int + max_prompt_adapter_token: int + max_cpu_prompt_adapters: Optional[int] = None + prompt_adapter_dtype: Optional[torch.dtype] = None + + def __post_init__(self): + library_name = 'peft' + try: + __import__(library_name) + except ImportError as e: + raise ImportError( + f"'{library_name}' is not installed for prompt adapter support." + f"Please install it using 'pip install {library_name}'." 
+ ) from e + + if self.max_prompt_adapters < 1: + raise ValueError(f"max_prompt_adapters " + f"({self.max_prompt_adapters}) must be >= 1.") + if self.max_prompt_adapter_token == 0: + raise ValueError("max_prompt_adapter_token must be set.") + if self.max_cpu_prompt_adapters is None: + self.max_cpu_prompt_adapters = self.max_prompt_adapters + + def verify_with_model_config(self, model_config: ModelConfig): + if self.prompt_adapter_dtype in (None, "auto"): + self.prompt_adapter_dtype = model_config.dtype + elif isinstance(self.prompt_adapter_dtype, str): + self.prompt_adapter_dtype = getattr(torch, + self.prompt_adapter_dtype) + + @dataclass class MultiModalConfig: """Configs the input data format and how models should run for @@ -1518,6 +1551,7 @@ class EngineConfig: speculative_config: Optional[SpeculativeConfig] decoding_config: Optional[DecodingConfig] observability_config: Optional[ObservabilityConfig] + prompt_adapter_config: Optional[PromptAdapterConfig] def __post_init__(self): """Verify configs are valid & consistent with each other. @@ -1529,6 +1563,9 @@ def __post_init__(self): self.lora_config.verify_with_model_config(self.model_config) self.lora_config.verify_with_scheduler_config( self.scheduler_config) + if self.prompt_adapter_config: + self.prompt_adapter_config.verify_with_model_config( + self.model_config) def to_dict(self): """Return the configs as a dictionary, for use in **kwargs. diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 9e626b2883975..6bda18cd4f061 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -11,6 +11,7 @@ from vllm.core.policy import Policy, PolicyFactory from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import (Sequence, SequenceData, SequenceGroup, SequenceGroupMetadata, SequenceStatus) @@ -139,6 +140,8 @@ def __post_init__(self): if self.num_loras > 0: self._sort_by_lora_ids() + self.num_prompt_adapters: int = len(self.prompt_adapter_requests) + def is_empty(self) -> bool: # NOTE: We do not consider the ignored sequence groups. return (not self.scheduled_seq_groups and not self.blocks_to_swap_in @@ -157,6 +160,14 @@ def lora_requests(self) -> Set[LoRARequest]: if g.seq_group.lora_request is not None } + @property + def prompt_adapter_requests(self) -> Set[PromptAdapterRequest]: + return { + g.seq_group.prompt_adapter_request + for g in self.scheduled_seq_groups + if g.seq_group.prompt_adapter_request is not None + } + @dataclass class SchedulerRunningOutputs: @@ -1024,6 +1035,7 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: # `multi_modal_data` will be None. 
multi_modal_data=seq_group.multi_modal_data if scheduler_outputs.num_prefill_groups > 0 else None, + prompt_adapter_request=seq_group.prompt_adapter_request, ) seq_group_metadata_list.append(seq_group_metadata) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index afa6892d49eb8..b972573c0258e 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -7,8 +7,8 @@ from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, EngineConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ObservabilityConfig, ParallelConfig, - SchedulerConfig, SpeculativeConfig, - TokenizerPoolConfig) + PromptAdapterConfig, SchedulerConfig, + SpeculativeConfig, TokenizerPoolConfig) from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.utils import FlexibleArgumentParser @@ -66,6 +66,9 @@ class EngineArgs: enable_lora: bool = False max_loras: int = 1 max_lora_rank: int = 16 + enable_prompt_adapter: bool = False + max_prompt_adapters: int = 1 + max_prompt_adapter_token: int = 0 fully_sharded_loras: bool = False lora_extra_vocab_size: int = 256 long_lora_scaling_factors: Optional[Tuple[float]] = None @@ -449,6 +452,17 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: 'Enabling this will use the fully sharded layers. ' 'At high sequence length, max rank or ' 'tensor parallel size, this is likely faster.')) + parser.add_argument('--enable-prompt-adapter', + action='store_true', + help='If True, enable handling of PromptAdapters.') + parser.add_argument('--max-prompt-adapters', + type=int, + default=EngineArgs.max_prompt_adapters, + help='Max number of PromptAdapters in a batch.') + parser.add_argument('--max-prompt-adapter-token', + type=int, + default=EngineArgs.max_prompt_adapter_token, + help='Max number of PromptAdapters tokens') parser.add_argument("--device", type=str, default=EngineArgs.device, @@ -726,6 +740,11 @@ def create_engine_config(self, ) -> EngineConfig: model_loader_extra_config=self.model_loader_extra_config, ) + prompt_adapter_config = PromptAdapterConfig( + max_prompt_adapters=self.max_prompt_adapters, + max_prompt_adapter_token=self.max_prompt_adapter_token) \ + if self.enable_prompt_adapter else None + decoding_config = DecodingConfig( guided_decoding_backend=self.guided_decoding_backend) @@ -751,6 +770,7 @@ def create_engine_config(self, ) -> EngineConfig: load_config=load_config, decoding_config=decoding_config, observability_config=observability_config, + prompt_adapter_config=prompt_adapter_config, ) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 33e40c7b3624a..9b4ef48b0e47e 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -18,6 +18,7 @@ from vllm.lora.request import LoRARequest from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams +from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.usage.usage_lib import UsageContext @@ -264,6 +265,7 @@ async def process_model_inputs_async( request_id: str, inputs: PromptInputs, lora_request: Optional[LoRARequest] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> LLMInputs: if isinstance(inputs, str): inputs = {"prompt": inputs} @@ -279,6 +281,12 @@ async def process_model_inputs_async( else: prompt_token_ids = inputs["prompt_token_ids"] + if prompt_adapter_request: + 
prompt_token_ids = [ + 0 + ] * prompt_adapter_request.prompt_adapter_num_virtual_tokens + \ + prompt_token_ids + llm_inputs = LLMInputs(prompt_token_ids=prompt_token_ids, prompt=inputs.get("prompt"), multi_modal_data=inputs.get("multi_modal_data")) @@ -286,13 +294,14 @@ async def process_model_inputs_async( return self.input_processor(llm_inputs) async def add_request_async( - self, - request_id: str, - inputs: PromptInputs, - params: Union[SamplingParams, PoolingParams], - arrival_time: Optional[float] = None, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Dict[str, str]] = None, + self, + request_id: str, + inputs: PromptInputs, + params: Union[SamplingParams, PoolingParams], + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Dict[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None ) -> None: if lora_request is not None and not self.lora_config: raise ValueError(f"Got lora_request {lora_request} but LoRA is " @@ -301,7 +310,10 @@ async def add_request_async( arrival_time = time.time() processed_inputs = await self.process_model_inputs_async( - request_id=request_id, inputs=inputs, lora_request=lora_request) + request_id=request_id, + inputs=inputs, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) self._add_processed_request( request_id=request_id, @@ -309,6 +321,7 @@ async def add_request_async( params=params, arrival_time=arrival_time, lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, trace_headers=trace_headers, ) @@ -627,6 +640,7 @@ async def add_request( arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Dict[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None ) -> AsyncStream: if self.log_requests: if isinstance(inputs, str): @@ -669,7 +683,7 @@ async def add_request( arrival_time=arrival_time, lora_request=lora_request, trace_headers=trace_headers, - ) + prompt_adapter_request=prompt_adapter_request) return stream @@ -680,6 +694,7 @@ async def generate( request_id: str, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Dict[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None ) -> AsyncIterator[RequestOutput]: """Generate outputs for a request. @@ -695,6 +710,8 @@ async def generate( request_id: The unique id of the request. lora_request: LoRA request to use for generation, if any. trace_headers: OpenTelemetry trace headers. + prompt_adapter_request: Prompt Adapter request to use + for generation, if any. 
Yields: The output `RequestOutput` objects from the LLMEngine @@ -749,6 +766,7 @@ async def generate( sampling_params, lora_request=lora_request, trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request, ): yield LLMEngine.validate_output(output, RequestOutput) @@ -837,6 +855,7 @@ async def _process_request( *, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Dict[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]: """Common logic to process requests with SamplingParams or PoolingParams.""" @@ -849,6 +868,7 @@ async def _process_request( arrival_time=arrival_time, lora_request=lora_request, trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request, ) try: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index de7604ece7c31..b476594fc73f6 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -8,7 +8,8 @@ from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, - ObservabilityConfig, ParallelConfig, SchedulerConfig, + ObservabilityConfig, ParallelConfig, + PromptAdapterConfig, SchedulerConfig, SpeculativeConfig) from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler, SchedulerOutputs) @@ -27,6 +28,7 @@ from vllm.outputs import (EmbeddingRequestOutput, RequestOutput, RequestOutputFactory) from vllm.pooling_params import PoolingParams +from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest, PoolerOutput, SamplerOutput, Sequence, @@ -93,6 +95,8 @@ class LLMEngine: decoding. executor_class: The model executor class for managing distributed execution. + prompt_adapter_config (Optional): The configuration related to serving + prompt adapters. log_stats: Whether to log statistics. usage_context: Specified entry point, used for usage info collection. """ @@ -161,6 +165,7 @@ def __init__( speculative_config: Optional[SpeculativeConfig], decoding_config: Optional[DecodingConfig], observability_config: Optional[ObservabilityConfig], + prompt_adapter_config: Optional[PromptAdapterConfig], executor_class: Type[ExecutorBase], log_stats: bool, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, @@ -222,6 +227,7 @@ def __init__( self.speculative_config = speculative_config self.load_config = load_config self.decoding_config = decoding_config or DecodingConfig() + self.prompt_adapter_config = prompt_adapter_config self.observability_config = observability_config or ObservabilityConfig( ) self.log_stats = log_stats @@ -250,6 +256,7 @@ def __init__( multimodal_config=multimodal_config, speculative_config=speculative_config, load_config=load_config, + prompt_adapter_config=prompt_adapter_config, ) if not self.model_config.embedding_mode: @@ -282,6 +289,8 @@ def __init__( # Feature flags "enable_lora": bool(lora_config), + "enable_prompt_adapter": + bool(prompt_adapter_config), "enable_prefix_caching": cache_config.enable_prefix_caching, "enforce_eager": @@ -376,7 +385,6 @@ def from_engine_args( engine_config = engine_args.create_engine_config() distributed_executor_backend = ( engine_config.parallel_config.distributed_executor_backend) - # Initialize the cluster and specify the executor class. 
if engine_config.device_config.device_type == "neuron": from vllm.executor.neuron_executor import NeuronExecutor @@ -409,7 +417,6 @@ def from_engine_args( else: from vllm.executor.gpu_executor import GPUExecutor executor_class = GPUExecutor - # Create the LLM engine. engine = cls( **engine_config.to_dict(), @@ -470,6 +477,9 @@ def _verify_args(self) -> None: self.lora_config.verify_with_model_config(self.model_config) self.lora_config.verify_with_scheduler_config( self.scheduler_config) + if self.prompt_adapter_config: + self.prompt_adapter_config.verify_with_model_config( + self.model_config) def _get_eos_token_id( self, lora_request: Optional[LoRARequest]) -> Optional[int]: @@ -487,6 +497,7 @@ def _add_processed_request( params: Union[SamplingParams, PoolingParams], arrival_time: float, lora_request: Optional[LoRARequest], + prompt_adapter_request: Optional[PromptAdapterRequest], trace_headers: Optional[Dict[str, str]] = None, ) -> None: # Create the sequences. @@ -495,7 +506,7 @@ def _add_processed_request( eos_token_id = self._get_eos_token_id(lora_request) seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id, - lora_request) + lora_request, prompt_adapter_request) # Create a SequenceGroup based on SamplingParams or PoolingParams if isinstance(params, SamplingParams): @@ -506,7 +517,7 @@ def _add_processed_request( arrival_time=arrival_time, lora_request=lora_request, trace_headers=trace_headers, - ) + prompt_adapter_request=prompt_adapter_request) elif isinstance(params, PoolingParams): seq_group = self._create_sequence_group_with_pooling( request_id, @@ -514,7 +525,7 @@ def _add_processed_request( params, arrival_time=arrival_time, lora_request=lora_request, - ) + prompt_adapter_request=prompt_adapter_request) else: raise ValueError( "Either SamplingParams or PoolingParams must be provided.") @@ -535,6 +546,7 @@ def process_model_inputs( request_id: str, inputs: PromptInputs, lora_request: Optional[LoRARequest] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> LLMInputs: if isinstance(inputs, str): inputs = {"prompt": inputs} @@ -549,6 +561,11 @@ def process_model_inputs( else: prompt_token_ids = inputs["prompt_token_ids"] + if prompt_adapter_request: + prompt_token_ids = \ + [0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens\ + + prompt_token_ids + llm_inputs = LLMInputs(prompt_token_ids=prompt_token_ids, prompt=inputs.get("prompt"), multi_modal_data=inputs.get("multi_modal_data")) @@ -563,6 +580,7 @@ def add_request( arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Dict[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> None: """Add a request to the engine's request pool. 
@@ -612,9 +630,11 @@ def add_request( if arrival_time is None: arrival_time = time.time() - processed_inputs = self.process_model_inputs(request_id=request_id, - inputs=inputs, - lora_request=lora_request) + processed_inputs = self.process_model_inputs( + request_id=request_id, + inputs=inputs, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) self._add_processed_request( request_id=request_id, @@ -622,6 +642,7 @@ def add_request( params=params, arrival_time=arrival_time, lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, trace_headers=trace_headers, ) @@ -633,6 +654,7 @@ def _create_sequence_group_with_sampling( arrival_time: float, lora_request: Optional[LoRARequest], trace_headers: Optional[Dict[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> SequenceGroup: """Creates a SequenceGroup with SamplingParams.""" max_logprobs = self.get_model_config().max_logprobs @@ -658,7 +680,7 @@ def _create_sequence_group_with_sampling( sampling_params=sampling_params, lora_request=lora_request, trace_headers=trace_headers, - ) + prompt_adapter_request=prompt_adapter_request) return seq_group @@ -669,16 +691,19 @@ def _create_sequence_group_with_pooling( pooling_params: PoolingParams, arrival_time: float, lora_request: Optional[LoRARequest], + prompt_adapter_request: Optional[PromptAdapterRequest], ) -> SequenceGroup: """Creates a SequenceGroup with PoolingParams.""" # Defensive copy of PoolingParams, which are used by the pooler pooling_params = pooling_params.clone() # Create the sequence group. - seq_group = SequenceGroup(request_id=request_id, - seqs=[seq], - arrival_time=arrival_time, - lora_request=lora_request, - pooling_params=pooling_params) + seq_group = SequenceGroup( + request_id=request_id, + seqs=[seq], + arrival_time=arrival_time, + lora_request=lora_request, + pooling_params=pooling_params, + prompt_adapter_request=prompt_adapter_request) return seq_group def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: @@ -1082,6 +1107,16 @@ def list_loras(self) -> Set[int]: def pin_lora(self, lora_id: int) -> bool: return self.model_executor.pin_lora(lora_id) + def add_prompt_adapter( + self, prompt_adapter_request: PromptAdapterRequest) -> bool: + return self.model_executor.add_prompt_adapter(prompt_adapter_request) + + def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: + return self.model_executor.remove_prompt_adapter(prompt_adapter_id) + + def list_prompt_adapters(self) -> List[int]: + return self.model_executor.list_prompt_adapters() + def check_health(self) -> None: if self.tokenizer: self.tokenizer.check_health() diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index e3e506d496844..57e81a6317725 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -13,6 +13,7 @@ from vllm.lora.request import LoRARequest from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams +from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams from vllm.transformers_utils.tokenizer import get_cached_tokenizer from vllm.usage.usage_lib import UsageContext @@ -255,6 +256,7 @@ def generate( prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None, use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> List[RequestOutput]: """Generates the 
completions for the input prompts. @@ -271,6 +273,8 @@ def generate( prompts and it is paired one by one with the prompt. use_tqdm: Whether to use tqdm to display the progress bar. lora_request: LoRA request to use for generation, if any. + prompt_adapter_request: Prompt Adapter request to use for + generation, if any. Returns: A list of `RequestOutput` objects containing the @@ -304,7 +308,7 @@ def generate( inputs=inputs, params=sampling_params, lora_request=lora_request, - ) + prompt_adapter_request=prompt_adapter_request) outputs = self._run_engine(use_tqdm=use_tqdm) return LLMEngine.validate_outputs(outputs, RequestOutput) @@ -397,6 +401,7 @@ def encode( prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None, use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> List[EmbeddingRequestOutput]: """Generates the completions for the input prompts. @@ -412,6 +417,8 @@ def encode( use the default pooling parameters. use_tqdm: Whether to use tqdm to display the progress bar. lora_request: LoRA request to use for generation, if any. + prompt_adapter_request: Prompt Adapter request to use for + generation, if any. Returns: A list of `EmbeddingRequestOutput` objects containing the @@ -445,6 +452,7 @@ def encode( inputs=inputs, params=pooling_params, lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, ) outputs = self._run_engine(use_tqdm=use_tqdm) @@ -504,6 +512,7 @@ def _validate_and_add_requests( params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams, Sequence[PoolingParams]], lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]], + prompt_adapter_request: Optional[PromptAdapterRequest], ) -> None: if isinstance(inputs, (str, dict)): # Convert a single prompt to a list. 
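# A minimal offline-usage sketch of the new prompt_adapter_request argument on
# LLM.generate(). The base model, adapter path, and the enable_prompt_adapter /
# max_prompt_adapter_token engine flags are assumptions for illustration (the
# EngineArgs changes are not shown in this hunk); only the PromptAdapterRequest
# fields and the generate() keyword come from the code above.
from vllm import LLM, SamplingParams
from vllm.prompt_adapter.request import PromptAdapterRequest

llm = LLM(model="bigscience/bloomz-560m",
          enable_prompt_adapter=True,       # assumed engine flag
          max_prompt_adapter_token=8)       # assumed engine flag
pa_request = PromptAdapterRequest(
    prompt_adapter_name="tweet_sentiment",                   # placeholder name
    prompt_adapter_id=1,
    prompt_adapter_local_path="./tweet_sentiment_adapter",   # placeholder path
    prompt_adapter_num_virtual_tokens=8)
outputs = llm.generate(["Tweet text: I hate it. Label: "],
                       SamplingParams(temperature=0.0, max_tokens=3),
                       prompt_adapter_request=pa_request)
print(outputs[0].outputs[0].text)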
@@ -526,19 +535,23 @@ def _validate_and_add_requests( params[i] if isinstance(params, Sequence) else params, lora_request=lora_request[i] if isinstance( lora_request, Sequence) else lora_request, - ) + prompt_adapter_request=prompt_adapter_request) def _add_request( - self, - inputs: PromptInputs, - params: Union[SamplingParams, PoolingParams], - lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + self, + inputs: PromptInputs, + params: Union[SamplingParams, PoolingParams], + lora_request: Optional[Union[List[LoRARequest], + LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None ) -> None: request_id = str(next(self.request_counter)) - self.llm_engine.add_request(request_id, - inputs, - params, - lora_request=lora_request) + self.llm_engine.add_request( + request_id, + inputs, + params, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) def _run_engine( self, *, use_tqdm: bool diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index d3ed1ec7a15c5..6cba356c47063 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -116,7 +116,7 @@ async def detokenize(request: DetokenizeRequest): @app.get("/v1/models") async def show_available_models(): - models = await openai_serving_chat.show_available_models() + models = await openai_serving_completion.show_available_models() return JSONResponse(content=models.model_dump()) @@ -236,7 +236,8 @@ async def authentication(request: Request, call_next): args.lora_modules, args.chat_template) openai_serving_completion = OpenAIServingCompletion( - engine, model_config, served_model_names, args.lora_modules) + engine, model_config, served_model_names, args.lora_modules, + args.prompt_adapters) openai_serving_embedding = OpenAIServingEmbedding(engine, model_config, served_model_names) app.root_path = args.root_path diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index 59ad73bf097c8..81c474ecc808a 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -9,7 +9,8 @@ import ssl from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str -from vllm.entrypoints.openai.serving_engine import LoRAModulePath +from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, + PromptAdapterPath) from vllm.utils import FlexibleArgumentParser @@ -23,6 +24,16 @@ def __call__(self, parser, namespace, values, option_string=None): setattr(namespace, self.dest, lora_list) +class PromptAdapterParserAction(argparse.Action): + + def __call__(self, parser, namespace, values, option_string=None): + adapter_list = [] + for item in values: + name, path = item.split('=') + adapter_list.append(PromptAdapterPath(name, path)) + setattr(namespace, self.dest, adapter_list) + + def make_arg_parser(): parser = FlexibleArgumentParser( description="vLLM OpenAI-Compatible RESTful API server.") @@ -65,6 +76,14 @@ def make_arg_parser(): action=LoRAParserAction, help="LoRA module configurations in the format name=path. " "Multiple modules can be specified.") + parser.add_argument( + "--prompt-adapters", + type=nullable_str, + default=None, + nargs='+', + action=PromptAdapterParserAction, + help="Prompt adapter configurations in the format name=path. 
" + "Multiple adapters can be specified.") parser.add_argument("--chat-template", type=nullable_str, default=None, diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 415bdbbd7c455..010d6f2ebb909 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -258,7 +258,7 @@ async def create_chat_completion( prompt=prompt, add_special_tokens=request.add_special_tokens) sampling_params = request.to_sampling_params() - lora_request = self._maybe_get_lora(request) + _, lora_request = self._maybe_get_adapter(request) decoding_config = await self.engine.get_decoding_config() guided_decoding_backend = request.guided_decoding_backend \ or decoding_config.guided_decoding_backend diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 9c719d634ac7d..b53b058b52af3 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -22,7 +22,8 @@ TokenizeResponse, UsageInfo) # yapf: enable from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, - OpenAIServing) + OpenAIServing, + PromptAdapterPath) from vllm.logger import init_logger from vllm.model_executor.guided_decoding import ( get_guided_decoding_logits_processor) @@ -67,11 +68,13 @@ class OpenAIServingCompletion(OpenAIServing): def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig, served_model_names: List[str], - lora_modules: Optional[List[LoRAModulePath]]): + lora_modules: Optional[List[LoRAModulePath]], + prompt_adapters: Optional[List[PromptAdapterPath]]): super().__init__(engine=engine, model_config=model_config, served_model_names=served_model_names, - lora_modules=lora_modules) + lora_modules=lora_modules, + prompt_adapters=prompt_adapters) async def create_completion(self, request: CompletionRequest, raw_request: Request): @@ -101,7 +104,12 @@ async def create_completion(self, request: CompletionRequest, generators: List[AsyncIterator[RequestOutput]] = [] try: sampling_params = request.to_sampling_params() - lora_request = self._maybe_get_lora(request) + adapter_type, adapter_request = self._maybe_get_adapter(request) + lora_request, prompt_adapter_request = None, None + if adapter_type == 'LoRA': + lora_request, prompt_adapter_request = adapter_request, None + elif adapter_type == 'PromptAdapter': + lora_request, prompt_adapter_request = None, adapter_request decoding_config = await self.engine.get_decoding_config() guided_decoding_backend = request.guided_decoding_backend \ or decoding_config.guided_decoding_backend @@ -147,6 +155,7 @@ async def create_completion(self, request: CompletionRequest, sampling_params, f"{request_id}-{i}", lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, trace_headers=trace_headers, ) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 8d281c51f02bc..58e6571d310e6 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -16,12 +16,19 @@ ModelPermission, TokenizeRequest) from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import Logprob from vllm.transformers_utils.tokenizer import get_tokenizer logger = init_logger(__name__) +@dataclass +class PromptAdapterPath: + name: str + local_path: str + + @dataclass class LoRAModulePath: name: str @@ -30,9 +37,14 @@ 
class LoRAModulePath: class OpenAIServing: - def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig, - served_model_names: List[str], - lora_modules: Optional[List[LoRAModulePath]]): + def __init__( + self, + engine: AsyncLLMEngine, + model_config: ModelConfig, + served_model_names: List[str], + lora_modules: Optional[List[LoRAModulePath]], + prompt_adapters: Optional[List[PromptAdapterPath]] = None, + ): super().__init__() self.engine = engine @@ -49,9 +61,8 @@ def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig, self.served_model_names = served_model_names - if lora_modules is None: - self.lora_requests = [] - else: + self.lora_requests = [] + if lora_modules is not None: self.lora_requests = [ LoRARequest( lora_name=lora.name, @@ -60,6 +71,20 @@ def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig, ) for i, lora in enumerate(lora_modules, start=1) ] + self.prompt_adapter_requests = [] + if prompt_adapters is not None: + for i, prompt_adapter in enumerate(prompt_adapters, start=1): + with open(f"./{prompt_adapter.local_path}" + f"/adapter_config.json") as f: + adapter_config = json.load(f) + num_virtual_tokens = adapter_config["num_virtual_tokens"] + self.prompt_adapter_requests.append( + PromptAdapterRequest( + prompt_adapter_name=prompt_adapter.name, + prompt_adapter_id=i, + prompt_adapter_local_path=prompt_adapter.local_path, + prompt_adapter_num_virtual_tokens=num_virtual_tokens)) + async def show_available_models(self) -> ModelList: """Show available models. Right now we only have one model.""" model_cards = [ @@ -75,7 +100,14 @@ async def show_available_models(self) -> ModelList: permission=[ModelPermission()]) for lora in self.lora_requests ] + prompt_adapter_cards = [ + ModelCard(id=prompt_adapter.prompt_adapter_name, + root=self.served_model_names[0], + permission=[ModelPermission()]) + for prompt_adapter in self.prompt_adapter_requests + ] model_cards.extend(lora_cards) + model_cards.extend(prompt_adapter_cards) return ModelList(data=model_cards) def create_error_response( @@ -109,20 +141,29 @@ async def _check_model( return None if request.model in [lora.lora_name for lora in self.lora_requests]: return None + if request.model in [ + prompt_adapter.prompt_adapter_name + for prompt_adapter in self.prompt_adapter_requests + ]: + return None return self.create_error_response( message=f"The model `{request.model}` does not exist.", err_type="NotFoundError", status_code=HTTPStatus.NOT_FOUND) - def _maybe_get_lora( + def _maybe_get_adapter( self, request: Union[CompletionRequest, ChatCompletionRequest, EmbeddingRequest] - ) -> Optional[LoRARequest]: + ) -> Tuple[Optional[str], Optional[Union[LoRARequest, + PromptAdapterRequest]]]: if request.model in self.served_model_names: - return None + return None, None for lora in self.lora_requests: if request.model == lora.lora_name: - return lora + return 'LoRA', lora + for prompt_adapter in self.prompt_adapter_requests: + if request.model == prompt_adapter.prompt_adapter_name: + return 'PromptAdapter', prompt_adapter # if _check_model has been called earlier, this will be unreachable raise ValueError(f"The model `{request.model}` does not exist.") diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index 3b5621f70b92d..d3b60e3ff4260 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -7,6 +7,7 @@ from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.logger import init_logger from vllm.lora.request import 
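# Sketch of the on-disk layout assumed by the prompt_adapter_requests setup above:
# the server opens <local_path>/adapter_config.json and reads only the
# "num_virtual_tokens" field. The other keys shown are typical of a PEFT
# prompt-tuning checkpoint and are ignored by this code path.
example_adapter_config = {
    "peft_type": "PROMPT_TUNING",   # illustrative PEFT fields, not read above
    "token_dim": 1024,
    "num_virtual_tokens": 8,        # the only field consumed here
}
# With --prompt-adapters my_pa=./my_pa this yields, roughly:
#   PromptAdapterRequest(prompt_adapter_name="my_pa", prompt_adapter_id=1,
#                        prompt_adapter_local_path="./my_pa",
#                        prompt_adapter_num_virtual_tokens=8)
# and the adapter name is then listed by /v1/models alongside the LoRA modules.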
LoRARequest +from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, make_async) @@ -48,6 +49,7 @@ def _init_worker(self): lora_config=self.lora_config, multimodal_config=self.multimodal_config, kv_cache_dtype=self.cache_config.cache_dtype, + prompt_adapter_config=self.prompt_adapter_config, is_driver_worker=True, ) self.driver_worker.init_device() @@ -90,6 +92,19 @@ def pin_lora(self, lora_id: int) -> bool: def list_loras(self) -> Set[int]: return self.driver_worker.list_loras() + def add_prompt_adapter( + self, prompt_adapter_request: PromptAdapterRequest) -> bool: + return self.driver_worker.add_prompt_adapter(prompt_adapter_request) + + def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: + return self.driver_worker.remove_prompt_adapter(prompt_adapter_id) + + def list_prompt_adapters(self) -> Set[int]: + return self.driver_worker.list_prompt_adapters() + + def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: + return self.driver_worker.pin_prompt_adapter(prompt_adapter_id) + def check_health(self) -> None: # CPUExecutor will always be healthy as long as # it's running. diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index fc18dec0bca25..6f9e554459161 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -4,8 +4,10 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ParallelConfig, - SchedulerConfig, SpeculativeConfig) + PromptAdapterConfig, SchedulerConfig, + SpeculativeConfig) from vllm.lora.request import LoRARequest +from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import ExecuteModelRequest, SamplerOutput @@ -28,6 +30,7 @@ def __init__( lora_config: Optional[LoRAConfig], multimodal_config: Optional[MultiModalConfig], speculative_config: Optional[SpeculativeConfig], + prompt_adapter_config: Optional[PromptAdapterConfig], ) -> None: self.model_config = model_config self.cache_config = cache_config @@ -38,6 +41,7 @@ def __init__( self.device_config = device_config self.multimodal_config = multimodal_config self.speculative_config = speculative_config + self.prompt_adapter_config = prompt_adapter_config self._init_executor() @@ -95,6 +99,23 @@ def pin_lora(self, lora_id: int) -> bool: def list_loras(self) -> Set[int]: raise NotImplementedError + @abstractmethod + def add_prompt_adapter( + self, prompt_adapter_request: PromptAdapterRequest) -> bool: + raise NotImplementedError + + @abstractmethod + def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: + raise NotImplementedError + + @abstractmethod + def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: + raise NotImplementedError # type: ignore + + @abstractmethod + def list_prompt_adapters(self) -> Set[int]: + raise NotImplementedError + @abstractmethod def check_health(self) -> None: """Checks if the executor is healthy. 
If not, it should raise an @@ -122,12 +143,14 @@ def __init__( lora_config: Optional[LoRAConfig], multimodal_config: Optional[MultiModalConfig], speculative_config: Optional[SpeculativeConfig], + prompt_adapter_config: Optional[PromptAdapterConfig], ) -> None: self.pp_locks: Optional[List[asyncio.Lock]] = None super().__init__(model_config, cache_config, parallel_config, scheduler_config, device_config, load_config, - lora_config, multimodal_config, speculative_config) + lora_config, multimodal_config, speculative_config, + prompt_adapter_config) @abstractmethod async def execute_model_async( diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 7d3183a428a31..6ffc28d21be29 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -3,6 +3,7 @@ from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, make_async) @@ -45,6 +46,7 @@ def _get_worker_kwargs( lora_config=self.lora_config, multimodal_config=self.multimodal_config, speculative_config=self.speculative_config, + prompt_adapter_config=self.prompt_adapter_config, is_driver_worker=(not self.parallel_config) or (rank % self.parallel_config.tensor_parallel_size == 0), ) @@ -107,6 +109,25 @@ def pin_lora(self, lora_id: int) -> bool: def list_loras(self) -> Set[int]: return self.driver_worker.list_loras() + def add_prompt_adapter( + self, prompt_adapter_request: PromptAdapterRequest) -> bool: + assert prompt_adapter_request.prompt_adapter_id > 0, \ + "prompt_adapter_id must be greater than 0." + return self.driver_worker.add_prompt_adapter(prompt_adapter_request) + + def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: + assert prompt_adapter_id > 0, \ + "prompt_adapter_id must be greater than 0." + return self.driver_worker.remove_prompt_adapter(prompt_adapter_id) + + def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: + assert prompt_adapter_id > 0, \ + "prompt_adapter_id must be greater than 0." + return self.driver_worker.pin_prompt_adapter(prompt_adapter_id) + + def list_prompt_adapters(self) -> Set[int]: + return self.driver_worker.list_prompt_adapters() + def check_health(self) -> None: # GPUExecutor will always be healthy as long as # it's running. 
diff --git a/vllm/executor/ray_xpu_executor.py b/vllm/executor/ray_xpu_executor.py index f02d4978371a3..33f9321b5ff36 100644 --- a/vllm/executor/ray_xpu_executor.py +++ b/vllm/executor/ray_xpu_executor.py @@ -8,7 +8,8 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ParallelConfig, - SchedulerConfig, SpeculativeConfig) + PromptAdapterConfig, SchedulerConfig, + SpeculativeConfig) from vllm.executor.distributed_gpu_executor import ( # yapf: disable DistributedGPUExecutor, DistributedGPUExecutorAsync) from vllm.executor.ray_utils import RayWorkerWrapper, ray @@ -44,6 +45,7 @@ def __init__( load_config: LoadConfig, lora_config: Optional[LoRAConfig], multimodal_config: Optional[MultiModalConfig], + prompt_adapter_config: Optional[PromptAdapterConfig], speculative_config: Optional[SpeculativeConfig], ) -> None: assert device_config.device_type == "xpu" @@ -58,6 +60,7 @@ def __init__( self.scheduler_config = scheduler_config self.device_config = device_config self.multimodal_config = multimodal_config + self.prompt_adapter_config = prompt_adapter_config placement_group = self.parallel_config.placement_group diff --git a/vllm/executor/xpu_executor.py b/vllm/executor/xpu_executor.py index 29b246332ad55..f6550cce9ab1a 100644 --- a/vllm/executor/xpu_executor.py +++ b/vllm/executor/xpu_executor.py @@ -4,7 +4,8 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ParallelConfig, - SchedulerConfig, SpeculativeConfig) + PromptAdapterConfig, SchedulerConfig, + SpeculativeConfig) from vllm.executor.executor_base import ExecutorAsyncBase from vllm.executor.gpu_executor import GPUExecutor from vllm.logger import init_logger @@ -27,6 +28,7 @@ def __init__( load_config: LoadConfig, lora_config: Optional[LoRAConfig], multimodal_config: Optional[MultiModalConfig], + prompt_adapter_config: Optional[PromptAdapterConfig], speculative_config: Optional[SpeculativeConfig], ) -> None: assert device_config.device_type == "xpu" @@ -43,6 +45,7 @@ def __init__( self.scheduler_config = scheduler_config self.device_config = device_config self.multimodal_config = multimodal_config + self.prompt_adapter_config = prompt_adapter_config self.speculative_config = None # Instantiate the worker and load the model to GPU. diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 0a63f9ef012bc..40de134c0a5ee 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -8,6 +8,7 @@ import torch.nn.functional as F from transformers import PretrainedConfig +from vllm.adapter_commons.layers import AdapterMapping from vllm.config import LoRAConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, @@ -134,15 +135,8 @@ def _apply_lora_packed_nslice( @dataclass -class LoRAMapping: - # Per every token in input_ids: - index_mapping: Tuple[int, ...] - # Per sampled token: - prompt_mapping: Tuple[int, ...] 
- - def __post_init__(self): - self.index_mapping = tuple(self.index_mapping) - self.prompt_mapping = tuple(self.prompt_mapping) +class LoRAMapping(AdapterMapping): + pass class BaseLayerWithLoRA(nn.Module): diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 689835def83dd..e1ede7d4d710a 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -4,12 +4,17 @@ import os import re from dataclasses import dataclass, field -from typing import Callable, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union import safetensors.torch import torch from torch import nn +from vllm.adapter_commons.models import (AdapterLRUCache, AdapterModel, + AdapterModelManager) +from vllm.adapter_commons.utils import (add_adapter, deactivate_adapter, + get_adapter, list_adapters, + remove_adapter, set_adapter_mapping) from vllm.config import LoRAConfig from vllm.logger import init_logger from vllm.lora.layers import (BaseLayerWithLoRA, @@ -19,7 +24,7 @@ from vllm.lora.utils import (from_layer, from_layer_logits_processor, parse_fine_tuned_lora_name, replace_submodule) from vllm.model_executor.models.interfaces import SupportsLoRA -from vllm.utils import LRUCache, is_pin_memory_available +from vllm.utils import is_pin_memory_available logger = init_logger(__name__) @@ -153,7 +158,7 @@ def get_lora_id(): return _GLOBAL_LORA_ID -class LoRAModel: +class LoRAModel(AdapterModel): """A LoRA fine-tuned model.""" def __init__( @@ -388,7 +393,7 @@ def from_local_checkpoint( ) -class LoRAModelManager: +class LoRAModelManager(AdapterModelManager): """A manager that manages multiple LoRA-fine-tuned models.""" def __init__( @@ -440,8 +445,7 @@ def __init__( # base_indices, sampler_indices, sampler_indices_padded, # embeddings_indices self.indices_len: List[Optional[int]] = [None] * 4 - - self.model = model + super().__init__(model) if hasattr(self.model, "supported_lora_modules"): self.supported_lora_modules = copy.deepcopy( self.model.supported_lora_modules) @@ -453,11 +457,11 @@ def __init__( self.model.packed_modules_mapping) self.packed_modules: Dict[str, List[str]] = {} self.modules: Dict[str, "BaseLayerWithLoRA"] = {} - self._registered_loras: Dict[int, LoRAModel] = {} # Dict instead of a Set for compatibility with LRUCache. - self._active_loras: Dict[int, None] = {} self._last_mapping: Optional[LoRAMapping] = None self._create_lora_modules() + self.model.lora_manager = self + self.adapter_type = 'LoRa' @property def capacity(self) -> int: @@ -467,15 +471,16 @@ def capacity(self) -> int: def lora_slots(self) -> int: return self.lora_config.max_loras - def __len__(self) -> int: - return len(self._registered_loras) + @property + def adapter_slots(self) -> int: + return self.lora_slots - def activate_lora( + def activate_adapter( self, lora_id: int, ) -> bool: """Move LoRA into a GPU buffer to be used in the forward pass.""" - if lora_id in self._active_loras: + if lora_id in self._active_adapters: return False first_free_slot = next( ((i, lora_id) for i, lora_id in enumerate(self.lora_index_to_id) @@ -483,8 +488,8 @@ def activate_lora( if first_free_slot is None: raise ValueError("No free lora slots") index, _ = first_free_slot - self._active_loras[lora_id] = None - lora_model = self._registered_loras[lora_id] + self._active_adapters[lora_id] = None + lora_model = self._registered_adapters[lora_id] logger.debug("Activating LoRA. 
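# A minimal sketch of the shared AdapterMapping base that LoRAMapping (and
# PromptAdapterMapping) now inherit from. vllm/adapter_commons/layers.py is not
# shown in this excerpt, so the definition below is inferred from the fields
# removed from LoRAMapping above and may differ from the shipped code.
from dataclasses import dataclass
from typing import Tuple

@dataclass
class AdapterMapping:
    # Per every token in input_ids:
    index_mapping: Tuple[int, ...]
    # Per sampled token:
    prompt_mapping: Tuple[int, ...]

    def __post_init__(self):
        self.index_mapping = tuple(self.index_mapping)
        self.prompt_mapping = tuple(self.prompt_mapping)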
int id: %d, slot index: %d", lora_model.id, index) self.lora_index_to_id[index] = lora_model.id @@ -498,21 +503,13 @@ def activate_lora( module.reset_lora(index) return True - def _deactivate_lora(self, lora_id: int): + def _deactivate_adapter(self, lora_id: int): try: index = self.lora_index_to_id.index(lora_id) self.lora_index_to_id[index] = None except ValueError: pass - def deactivate_lora(self, lora_id: int) -> bool: - """Remove a LoRA from a GPU buffer.""" - if lora_id in self._active_loras: - self._deactivate_lora(lora_id) - self._active_loras.pop(lora_id) - return True - return False - def _set_long_lora_context(self, lora: LoRAModel): if self.long_lora_context is None: return @@ -528,40 +525,19 @@ def _set_long_lora_context(self, lora: LoRAModel): if offsets: self.long_lora_context.offsets_by_lora_id[lora.id] = offsets - def _add_lora(self, lora: LoRAModel): + def _add_adapter(self, lora: LoRAModel): self._create_merged_loras_inplace(lora) - self._registered_loras[lora.id] = lora + self._registered_adapters[lora.id] = lora self._set_long_lora_context(lora) - def add_lora(self, lora: LoRAModel) -> bool: - """Add a LoRAModel to the manager CPU cache.""" - logger.debug( - "Adding lora. Model id: %d, " - "int id: %d, " - "scaling factor: %s", lora.id, lora.id, lora.scaling_factor) - if lora.id not in self._registered_loras: - if len(self._registered_loras) >= self.capacity: - raise RuntimeError("No free LoRA slots.") - self._add_lora(lora) - return True - return False - - def remove_lora(self, lora_id: int) -> bool: - """Remove a LoRAModel from the manager CPU cache.""" - # TODO: should we check active lora? - self.deactivate_lora(lora_id) - if self.long_lora_context: - self.long_lora_context.offsets_by_lora_id.pop(lora_id, None) - return bool(self._registered_loras.pop(lora_id, None)) - - def pin_lora(self, lora_id: int) -> bool: + def pin_adapter(self, lora_id: int) -> bool: """Pin a LoRAModel in the manager cache.""" raise NotImplementedError( "Pinning is not supported in LoRAModelManager." 
"Use LRUCacheLoRAModelManager for pinning") # type: ignore # TODO see if this can be vectorized - def _set_lora_mapping(self, mapping: LoRAMapping) -> None: + def _set_adapter_mapping(self, mapping: LoRAMapping) -> None: (base_indices, sampler_indices, sampler_indices_padded, embeddings_indices, long_lora_offsets_tensor, indices_len) = convert_mapping(mapping, self.lora_index_to_id, @@ -583,23 +559,11 @@ def _set_lora_mapping(self, mapping: LoRAMapping) -> None: # Maintain the reference self.indices_len[:] = indices_len - def set_lora_mapping(self, lora_mapping: LoRAMapping) -> None: - if self._last_mapping != lora_mapping: - self._set_lora_mapping(lora_mapping) - self._last_mapping = lora_mapping - - def list_loras(self) -> Dict[int, LoRAModel]: - """List all registered LoRAModels.""" - return dict(self._registered_loras) - - def get_lora(self, lora_id: int) -> Optional[LoRAModel]: - return self._registered_loras.get(lora_id, None) - - def remove_all_loras(self): + def remove_all_adapters(self): """Remove all LoRAModels from the manager.""" - self._registered_loras.clear() + self._registered_adapters.clear() self.lora_index_to_id = [None] * self.lora_slots - self._active_loras.clear() + self._active_adapters.clear() def _create_lora_modules(self): for module_name, module in self.model.named_modules( @@ -743,18 +707,39 @@ def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None: lora_model.loras[module_name] = PackedLoRALayerWeights.pack( replacement_loras) + def deactivate_adapter(self, adapter_id: int) -> bool: + return deactivate_adapter(adapter_id, self._active_adapters, + self._deactivate_adapter) + + def add_adapter(self, adapter: LoRAModel) -> bool: + logger.debug( + "Adding lora. Model id: %d, " + "int id: %d, " + "scaling factor: %s", adapter.id, adapter.id, + adapter.scaling_factor) + return add_adapter(adapter, self._registered_adapters, self.capacity, + self._add_adapter) -class LoRALRUCache(LRUCache[LoRAModel]): + def set_adapter_mapping(self, mapping: LoRAMapping) -> None: + self._last_mapping = set_adapter_mapping(mapping, self._last_mapping, + self._set_adapter_mapping) + + def remove_adapter(self, adapter_id: int) -> bool: + return remove_adapter(adapter_id, self._registered_adapters, + self.deactivate_adapter) + + def list_adapters(self) -> Dict[int, Any]: + return list_adapters(self._registered_adapters) + + def get_adapter(self, adapter_id: int) -> Optional[Any]: + return get_adapter(adapter_id, self._registered_adapters) + + +class LoRALRUCache(AdapterLRUCache[LoRAModel]): def __init__(self, capacity: int, deactivate_lora_fn: Callable[[int], bool]): - super().__init__(capacity) - self.deactivate_lora_fn = deactivate_lora_fn - - def _on_remove(self, key: int, value: LoRAModel): - logger.debug("Removing LoRA. 
int id: %d", key) - self.deactivate_lora_fn(key) - return super()._on_remove(key, value) + super().__init__(capacity, deactivate_lora_fn) class LRUCacheLoRAModelManager(LoRAModelManager): @@ -770,49 +755,49 @@ def __init__( ): super().__init__(model, max_num_seqs, max_num_batched_tokens, vocab_size, lora_config) - self._registered_loras: LoRALRUCache = LoRALRUCache( - self.capacity, self.deactivate_lora) - self._active_loras: LoRALRUCache = LoRALRUCache( - self.lora_slots, self._deactivate_lora) + self._registered_adapters: LoRALRUCache = LoRALRUCache( + self.capacity, self.deactivate_adapter) + self._active_adapters: LoRALRUCache = LoRALRUCache( + self.lora_slots, self._deactivate_adapter) - def list_loras(self) -> Dict[int, LoRAModel]: + def list_adapters(self) -> Dict[int, LoRAModel]: """List all registered LoRAModels.""" - return dict(self._registered_loras.cache) + return dict(self._registered_adapters.cache) - def add_lora(self, lora: LoRAModel) -> bool: + def add_adapter(self, lora: LoRAModel) -> bool: """Add a LoRAModel to the manager.""" logger.debug( "Adding lora. Model id: %d, " "int id: %d, " "scaling factor: %s", lora.id, lora.id, lora.scaling_factor) - if lora.id not in self._registered_loras: - self._add_lora(lora) + if lora.id not in self._registered_adapters: + self._add_adapter(lora) was_added = True else: # We always touch to update the LRU cache order - self._registered_loras.touch(lora.id) + self._registered_adapters.touch(lora.id) was_added = False return was_added - def activate_lora( + def activate_adapter( self, lora_id: int, ) -> bool: - if lora_id not in self._active_loras and len( - self._active_loras) >= self.lora_slots: - self._active_loras.remove_oldest() - result = super().activate_lora(lora_id) + if lora_id not in self._active_adapters and len( + self._active_adapters) >= self.lora_slots: + self._active_adapters.remove_oldest() + result = super().activate_adapter(lora_id) # We always touch to update the LRU cache order - self._active_loras.touch(lora_id) + self._active_adapters.touch(lora_id) return result - def remove_oldest_lora(self) -> bool: - if len(self._registered_loras) > 0: - self._registered_loras.remove_oldest() + def remove_oldest_adapter(self) -> bool: + if len(self._registered_adapters) > 0: + self._registered_adapters.remove_oldest() return True return False - def pin_lora(self, lora_id: int) -> bool: + def pin_adapter(self, lora_id: int) -> bool: """Pin a LoRAModel in the manager cache.""" self._pin_lora_in_cpu_cache(lora_id) self._pin_lora_in_gpu_cache(lora_id) @@ -820,17 +805,17 @@ def pin_lora(self, lora_id: int) -> bool: def _pin_lora_in_cpu_cache(self, lora_id: int): try: - self._registered_loras.pin(lora_id) + self._registered_adapters.pin(lora_id) except ValueError as err: raise ValueError("Pinning failed. 
" f"LoRA {lora_id} is not registered.") from err def _pin_lora_in_gpu_cache(self, lora_id: int): - if lora_id not in self._active_loras: + if lora_id not in self._active_adapters: # move lora to gpu if not already active - self.activate_lora(lora_id) + self.activate_adapter(lora_id) - self._active_loras.pin(lora_id) + self._active_adapters.pin(lora_id) def create_lora_manager( diff --git a/vllm/lora/request.py b/vllm/lora/request.py index 662774ffe09ae..2d10d037760e2 100644 --- a/vllm/lora/request.py +++ b/vllm/lora/request.py @@ -1,13 +1,15 @@ from dataclasses import dataclass from typing import Optional +from vllm.adapter_commons.request import AdapterRequest + @dataclass -class LoRARequest: +class LoRARequest(AdapterRequest): """ Request for a LoRA adapter. - Note that this class should be be used internally. For online + Note that this class should be used internally. For online serving, it is recommended to not allow users to use this class but instead provide another layer of abstraction to prevent users from accessing unauthorized LoRA adapters. @@ -20,15 +22,16 @@ class LoRARequest: lora_int_id: int lora_local_path: str long_lora_max_len: Optional[int] = None + __hash__ = AdapterRequest.__hash__ - def __post_init__(self): - if self.lora_int_id < 1: - raise ValueError( - f"lora_int_id must be > 0, got {self.lora_int_id}") + @property + def adapter_id(self): + return self.lora_int_id - def __eq__(self, value: object) -> bool: - return isinstance( - value, LoRARequest) and self.lora_int_id == value.lora_int_id + @property + def name(self): + return self.lora_name - def __hash__(self) -> int: - return self.lora_int_id + @property + def local_path(self): + return self.lora_local_path diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py index ca4903c23bcaa..3d0ef4252b024 100644 --- a/vllm/lora/worker_manager.py +++ b/vllm/lora/worker_manager.py @@ -1,12 +1,15 @@ -from abc import ABC, abstractmethod from contextlib import contextmanager from typing import Any, Dict, List, Literal, Optional, Set, Type, Union import torch +from vllm.adapter_commons.utils import (add_adapter_worker, + apply_adapters_worker, + list_adapters_worker, + set_active_adapters_worker) +from vllm.adapter_commons.worker_manager import AbstractWorkerManager from vllm.config import LoRAConfig from vllm.logger import init_logger -from vllm.lora.layers import LoRAMapping from vllm.lora.models import (LoRAModel, LoRAModelManager, LRUCacheLoRAModelManager, create_lora_manager) from vllm.lora.request import LoRARequest @@ -14,79 +17,13 @@ logger = init_logger(__name__) -class AbstractWorkerLoRAManager(ABC): - """Abstract class for managing LoRA models on the worker side.""" - - def __init__(self, - max_num_seqs: int, - max_num_batched_tokens: int, - vocab_size: int, - lora_config: LoRAConfig, - device: torch.device, - max_position_embeddings: Optional[int] = None): - self.max_num_seqs = max_num_seqs - self.max_num_batched_tokens = max_num_batched_tokens - self.max_position_embeddings = max_position_embeddings - self.vocab_size = vocab_size - self.device = device - self.lora_config = lora_config - - # If False, do not cache. If None, cache is empty. 
- self._cached_dummy_lora: Union[None, Literal[False], LoRAModel] = False - - @contextmanager - def dummy_lora_cache(self): - """Use this context manager to reuse the dummy lora model - to avoid creating it repeatedly.""" - self._cached_dummy_lora = None - yield - self._cached_dummy_lora = False - - @property - @abstractmethod - def is_enabled(self) -> bool: - ... - - @abstractmethod - def create_lora_manager( - self, - model: torch.nn.Module, - ) -> Any: - ... - - @abstractmethod - def set_active_loras(self, lora_requests: Set[LoRARequest], - lora_mapping: LoRAMapping) -> None: - ... - - @abstractmethod - def add_lora(self, lora_request: LoRARequest) -> bool: - ... - - @abstractmethod - def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool: - ... - - @abstractmethod - def remove_lora(self, lora_id: int) -> bool: - ... - - @abstractmethod - def remove_all_loras(self): - ... - - @abstractmethod - def list_loras(self) -> Set[int]: - ... - - -class WorkerLoRAManager(AbstractWorkerLoRAManager): +class WorkerLoRAManager(AbstractWorkerManager): """WorkerLoRAManager that manages LoRA models on the worker side. Every request, the requested LoRAs will be loaded (unless they are already loaded), and every other LoRA will be unloaded.""" - _lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager + _manager_cls: Type[LoRAModelManager] = LoRAModelManager def __init__( self, @@ -103,16 +40,23 @@ def __init__( self._lora_model_cls = lora_model_cls self.embedding_modules = embedding_modules self.embedding_padding_modules = embedding_padding_modules + self._cached_dummy_lora: Union[None, Literal[False], LoRAModel] = False + self.max_num_seqs = max_num_seqs + self.max_num_batched_tokens = max_num_batched_tokens + self.vocab_size = vocab_size + self.lora_config = lora_config + self.max_position_embeddings = max_position_embeddings + super().__init__(device) # Lazily initialized by create_lora_manager. 
- self._lora_manager: LoRAModelManager - super().__init__( - max_num_seqs, - max_num_batched_tokens, - vocab_size, - lora_config, - device, - max_position_embeddings=max_position_embeddings, - ) + self._adapter_manager: LoRAModelManager + + @contextmanager + def dummy_lora_cache(self): + """Use this context manager to reuse the dummy lora model + to avoid creating it repeatedly.""" + self._cached_dummy_lora = None + yield + self._cached_dummy_lora = False @property def is_enabled(self) -> bool: @@ -128,41 +72,14 @@ def create_lora_manager( max_num_batched_tokens=self.max_num_batched_tokens, vocab_size=self.vocab_size, lora_config=self.lora_config, - lora_manager_cls=self._lora_manager_cls, + lora_manager_cls=self._manager_cls, ) - self._lora_manager = lora_manager + self._adapter_manager = lora_manager return lora_manager.model - def set_active_loras(self, lora_requests: Set[LoRARequest], - lora_mapping: LoRAMapping) -> None: - self._apply_loras(lora_requests) - self._lora_manager.set_lora_mapping(lora_mapping) - - def _apply_loras(self, lora_requests: Set[LoRARequest]) -> None: - loras_that_exist = self.list_loras() - loras_map = { - lora_request.lora_int_id: lora_request - for lora_request in lora_requests if lora_request - } - if len(loras_map) > self._lora_manager.lora_slots: - raise RuntimeError( - f"Number of requested LoRAs ({len(loras_map)}) is greater " - "than the number of GPU LoRA slots " - f"({self._lora_manager.lora_slots}).") - - new_loras = set(loras_map) - loras_to_add = new_loras - loras_that_exist - loras_to_remove = loras_that_exist - new_loras - - for lora_id in loras_to_remove: - self.remove_lora(lora_id) - - for lora_id in loras_to_add: - self.add_lora(loras_map[lora_id]) - - def _load_lora(self, lora_request: LoRARequest) -> LoRAModel: + def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel: try: - model = self._lora_manager.model + model = self._adapter_manager.model supported_lora_modules = model.supported_lora_modules packed_modules_mapping = model.packed_modules_mapping expected_lora_modules: List[str] = [] @@ -198,37 +115,45 @@ def _load_lora(self, lora_request: LoRARequest) -> LoRAModel: return lora def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool: - if lora_request.lora_int_id in self.list_loras(): + if lora_request.lora_int_id in self.list_adapters(): return False if isinstance(self._cached_dummy_lora, LoRAModel): dummy_lora = self._cached_dummy_lora.clone( lora_request.lora_int_id) else: - dummy_lora = self._lora_manager.create_dummy_lora( + dummy_lora = self._adapter_manager.create_dummy_lora( lora_request.lora_int_id, rank, 1, self.embedding_modules) if self._cached_dummy_lora is None: self._cached_dummy_lora = dummy_lora - return self._lora_manager.add_lora(dummy_lora) + return self._adapter_manager.add_adapter(dummy_lora) - def add_lora(self, lora_request: LoRARequest) -> bool: - if lora_request.lora_int_id in self.list_loras(): - return False - lora = self._load_lora(lora_request) - loaded = self._lora_manager.add_lora(lora) - self._lora_manager.activate_lora(lora.id) - return loaded + def pin_adapter(self, adapter_id: int) -> bool: + return self._adapter_manager.pin_adapter(adapter_id) + + def set_active_adapters(self, requests: Set[Any], + mapping: Optional[Any]) -> None: + set_active_adapters_worker(requests, mapping, self._apply_adapters, + self._adapter_manager.set_adapter_mapping) + + def _apply_adapters(self, adapter_requests: Set[Any]) -> None: + apply_adapters_worker(adapter_requests, self.list_adapters, + 
self._adapter_manager.adapter_slots, + self.remove_adapter, self.add_adapter) - def remove_lora(self, lora_id: int) -> bool: - return self._lora_manager.remove_lora(lora_id) + def add_adapter(self, adapter_request: Any) -> bool: + return add_adapter_worker(adapter_request, self.list_adapters, + self._load_adapter, + self._adapter_manager.add_adapter, + self._adapter_manager.activate_adapter) - def pin_lora(self, lora_id: int) -> bool: - return self._lora_manager.pin_lora(lora_id) + def remove_adapter(self, adapter_id: int) -> bool: + return self._adapter_manager.remove_adapter(adapter_id) - def remove_all_loras(self): - self._lora_manager.remove_all_loras() + def remove_all_adapters(self): + self._adapter_manager.remove_all_adapters() - def list_loras(self) -> Set[int]: - return set(self._lora_manager.list_loras()) + def list_adapters(self) -> Set[int]: + return list_adapters_worker(self._adapter_manager.list_adapters) class LRUCacheWorkerLoRAManager(WorkerLoRAManager): @@ -238,8 +163,7 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager): (unless they are already loaded) and least recently used LoRAs will be unloaded if the cache is above capacity.""" - _lora_manager_cls: Type[ - LRUCacheLoRAModelManager] = LRUCacheLoRAModelManager + _manager_cls: Type[LRUCacheLoRAModelManager] = LRUCacheLoRAModelManager def create_lora_manager( self, @@ -247,40 +171,41 @@ def create_lora_manager( ) -> Any: lora_manager = create_lora_manager( model, - lora_manager_cls=self._lora_manager_cls, + lora_manager_cls=self._manager_cls, max_num_seqs=self.max_num_seqs, vocab_size=self.vocab_size, lora_config=self.lora_config, max_num_batched_tokens=self.max_num_batched_tokens, ) - self._lora_manager = lora_manager + self._adapter_manager = lora_manager return lora_manager.model - def _apply_loras(self, lora_requests: Set[LoRARequest]) -> None: + def _apply_adapters(self, lora_requests: Set[LoRARequest]) -> None: loras_map = { lora_request.lora_int_id: lora_request for lora_request in lora_requests if lora_request } - if len(loras_map) > self._lora_manager.lora_slots: + if len(loras_map) > self._adapter_manager.lora_slots: raise RuntimeError( f"Number of requested LoRAs ({len(loras_map)}) is greater " "than the number of GPU LoRA slots " - f"({self._lora_manager.lora_slots}).") + f"({self._adapter_manager.lora_slots}).") for lora in loras_map.values(): - self.add_lora(lora) + self.add_adapter(lora) - def add_lora(self, lora_request: LoRARequest) -> bool: - if lora_request.lora_int_id not in self.list_loras(): + def add_adapter(self, lora_request: LoRARequest) -> bool: + if lora_request.lora_int_id not in self.list_adapters(): # Remove before we load the new lora to save memory - if len(self._lora_manager) + 1 > self._lora_manager.capacity: - assert isinstance(self._lora_manager, LRUCacheLoRAModelManager) - self._lora_manager.remove_oldest_lora() - lora = self._load_lora(lora_request) - loaded = self._lora_manager.add_lora(lora) + if len(self._adapter_manager) + 1 > self._adapter_manager.capacity: + assert isinstance(self._adapter_manager, + LRUCacheLoRAModelManager) + self._adapter_manager.remove_oldest_adapter() + lora = self._load_adapter(lora_request) + loaded = self._adapter_manager.add_adapter(lora) else: # If the lora is already loaded, just touch it to # update its position in the caches - loaded = self._lora_manager.get_lora( + loaded = self._adapter_manager.get_adapter( lora_request.lora_int_id) is not None - self._lora_manager.activate_lora(lora_request.lora_int_id) + 
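# A hedged sketch of the worker-side helpers from vllm.adapter_commons.utils used
# above (set_active_adapters_worker, apply_adapters_worker). Their bodies are not
# part of this excerpt; the versions below are inferred from the removed
# set_active_loras/_apply_loras logic and may differ from the shipped module.
from typing import Any, Callable, Optional, Set

def set_active_adapters_worker(requests: Set[Any], mapping: Optional[Any],
                               apply_adapters_fn: Callable[[Set[Any]], None],
                               set_adapter_mapping_fn: Callable[[Any], None]) -> None:
    # Load/unload adapters for this batch, then install the token mapping.
    apply_adapters_fn(requests)
    set_adapter_mapping_fn(mapping)

def apply_adapters_worker(adapter_requests: Set[Any],
                          list_adapters_fn: Callable[[], Set[int]],
                          adapter_slots: int,
                          remove_adapter_fn: Callable[[int], bool],
                          add_adapter_fn: Callable[[Any], bool]) -> None:
    # Keep only the adapters requested by the current batch resident on the GPU.
    models_map = {req.adapter_id: req for req in adapter_requests if req}
    if len(models_map) > adapter_slots:
        raise RuntimeError(
            f"Number of requested adapters ({len(models_map)}) is greater "
            f"than the number of adapter slots ({adapter_slots}).")
    existing = list_adapters_fn()
    requested = set(models_map)
    for adapter_id in existing - requested:
        remove_adapter_fn(adapter_id)
    for adapter_id in requested - existing:
        add_adapter_fn(models_map[adapter_id])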
self._adapter_manager.activate_adapter(lora_request.lora_int_id) return loaded diff --git a/vllm/prompt_adapter/__init__.py b/vllm/prompt_adapter/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/prompt_adapter/layers.py b/vllm/prompt_adapter/layers.py new file mode 100644 index 0000000000000..27a61e692e1b7 --- /dev/null +++ b/vllm/prompt_adapter/layers.py @@ -0,0 +1,80 @@ +from dataclasses import dataclass +from typing import Optional + +import torch +from torch import nn + +from vllm.adapter_commons.layers import AdapterMapping +from vllm.config import PromptAdapterConfig +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) + + +@dataclass +class PromptAdapterMapping(AdapterMapping): + pass + + +class VocabParallelEmbeddingWithPromptAdapter(nn.Module): + + def __init__(self, base_layer: VocabParallelEmbedding) -> None: + super().__init__() + self.base_layer = base_layer + self.emb_layer = self.base_layer + if 'LoRA' in base_layer.__class__.__name__: + self.emb_layer = self.base_layer.base_layer + + def create_prompt_adapter_weights( + self, prompt_adapter_config: PromptAdapterConfig): + self.embeddings_tensors = torch.zeros( + ( + prompt_adapter_config.max_prompt_adapters, + prompt_adapter_config.max_prompt_adapter_token, + self.emb_layer.embedding_dim, + ), + dtype=self.emb_layer.weight.dtype, + device=self.emb_layer.weight.device, + ) + self.adapter_lengths = torch.zeros( + prompt_adapter_config.max_prompt_adapters, + dtype=torch.long, + device=self.emb_layer.weight.device) + + self.indices_gpu: torch.Tensor + self.embedding_indices_gpu: torch.Tensor + + def reset_prompt_adapter(self, index: int): + self.embeddings_tensors[index] = 0 + + def set_prompt_adapter( + self, + index: int, + adapter_model: Optional[torch.Tensor], + ): + self.reset_prompt_adapter(index) + if adapter_model is not None: + length = adapter_model.shape[0] + self.embeddings_tensors[index, :length] = adapter_model + self.adapter_lengths[index] = length + + def set_mapping( + self, + prompt_indices: torch.Tensor, + prompt_embedding_indices: torch.Tensor, + ): + self.indices_gpu = prompt_indices.to( + device=self.emb_layer.weight.device) + self.embedding_indices_gpu = prompt_embedding_indices.to( + device=self.emb_layer.weight.device) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + hidden_states = self.base_layer(x) + if self.embedding_indices_gpu.ndim > 1: + valid_mask = self.indices_gpu != -1 + gathered_embeddings = self.embeddings_tensors[ + self.embedding_indices_gpu[:, 0], + self.embedding_indices_gpu[:, 1]] + + # Update hidden states + hidden_states[valid_mask] = gathered_embeddings + return hidden_states \ No newline at end of file diff --git a/vllm/prompt_adapter/models.py b/vllm/prompt_adapter/models.py new file mode 100644 index 0000000000000..93eb3bde646ac --- /dev/null +++ b/vllm/prompt_adapter/models.py @@ -0,0 +1,355 @@ +import logging +import math +from typing import Any, Callable, Dict, List, Optional, Type + +import torch +from torch import nn + +from vllm.adapter_commons.models import (AdapterLRUCache, AdapterModel, + AdapterModelManager) +from vllm.adapter_commons.utils import (add_adapter, deactivate_adapter, + get_adapter, list_adapters, + remove_adapter, set_adapter_mapping) +from vllm.config import PromptAdapterConfig +from vllm.prompt_adapter.layers import ( + VocabParallelEmbeddingWithPromptAdapter) # yapf: disable +from vllm.prompt_adapter.layers import PromptAdapterMapping + +logger = 
logging.getLogger(__name__) + +_GLOBAL_PROMPT_ADAPTER_ID = 0 + + +def get_prompt_adapter_id(): + global _GLOBAL_PROMPT_ADAPTER_ID + _GLOBAL_PROMPT_ADAPTER_ID += 1 + return _GLOBAL_PROMPT_ADAPTER_ID + + +def convert_to_embedding_indices(indices): + embedding_indices = [] + count = 0 + + for value in indices: + if value == -1: + count = 0 + else: + embedding_indices.append([value, count]) + count += 1 + + return torch.tensor(embedding_indices) + + +def convert_mapping( + mapping: PromptAdapterMapping, + prompt_adapter_index_to_id: List[Optional[int]], +) -> torch.Tensor: + """Converts PromptAdapterMapping to index tensors. + + Args: + mapping: PromptAdapterMapping mapping rows in a + batch to PromptAdapter ids. + prompt_adapter_index_to_id: List mapping PromptAdapter + ids to PromptAdapter indices. + + Returns: + pa_indices: Tensor of shape [batch_size] mapping batch rows to + PromptAdapter indices. + """ + id_to_index = { + id_: idx + for idx, id_ in enumerate(prompt_adapter_index_to_id) + if id_ is not None + } + pa_indices = ([ + id_to_index.get(id_, -1) if id_ > 0 else -1 + for id_ in mapping.index_mapping + ]) + + pa_embedding_mapping = convert_to_embedding_indices(pa_indices) + pa_indices = torch.tensor(pa_indices) + return pa_indices, pa_embedding_mapping + + +class PromptAdapterModel(AdapterModel): + + def __init__(self, + prompt_adapter_id=None, + num_virtual_tokens=None, + prompt_embedding=None) -> None: + self.id = prompt_adapter_id + self.prompt_embedding = prompt_embedding + self.num_virtual_tokens = num_virtual_tokens + + @classmethod + def from_local_checkpoint( + cls, + adapter_model_path: str, + prompt_adapter_id: int, + num_virtual_tokens: int, + config: PromptAdapterConfig, + device: str = "cuda", + ) -> "PromptAdapterModel": + from peft.utils import load_peft_weights + + if num_virtual_tokens > config.max_prompt_adapter_token: + raise ValueError( + f'num_virtual_tokens ({num_virtual_tokens}) should be <= ' + f'max_prompt_adapter_token({config.max_prompt_adapter_token})') + + adapters_weights = load_peft_weights(adapter_model_path, device) + prompt_embedding = adapters_weights["prompt_embeddings"].to( + config.prompt_adapter_dtype) + + return cls(prompt_adapter_id, num_virtual_tokens, prompt_embedding) + + +class PromptAdapterModelManager(AdapterModelManager): + """A manager that manages multiple Prompt Adapter models.""" + + def __init__( + self, + model: nn.Module, + max_num_seqs: int, + max_num_batched_tokens: int, + prompt_adapter_config: PromptAdapterConfig, + ): + """Create a PromptAdapterModel and adapter for a given model. + + Args: + model: the model to be adapted. + max_num_seqs: the maximum number of sequences model can run in a + single batch. + max_num_batched_tokens: the maximum number of tokens model can run + in a single batch. + prompt_adapter_config: the PromptAdapter config, + """ + self.model: nn.Module = model + # Dict instead of a Set for compatibility with LRUCache. 
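# Worked trace of convert_mapping / convert_to_embedding_indices above. Entries
# that map to no adapter become -1, and the per-adapter position counter resets
# at every such gap; slot ids 7 and 9 are placeholders for this illustration.
#
#   prompt_adapter_index_to_id = [7, 9]        # slot 0 holds id 7, slot 1 holds id 9
#   mapping.index_mapping      = (7, 7, 0, 9, 9, 0)
#   pa_indices                 -> [0, 0, -1, 1, 1, -1]
#   embedding indices          -> [[0, 0], [0, 1], [1, 0], [1, 1]]
#
# VocabParallelEmbeddingWithPromptAdapter.forward() then overwrites the hidden
# states at the non -1 positions with embeddings_tensors[slot, position].
from vllm.prompt_adapter.models import convert_to_embedding_indices

assert convert_to_embedding_indices([0, 0, -1, 1, 1, -1]).tolist() == [
    [0, 0], [0, 1], [1, 0], [1, 1]]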
+ self.prompt_adapter_index_to_id: List[ + Optional[int]] = [None] * self.prompt_adapter_slots + self.max_num_seqs = max_num_seqs + self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8 + self.prompt_adapter_config = prompt_adapter_config + self.model.prompt_adapter_manager = self + self.adapter_type = 'PromptAdapter' + + self.base_indices = torch.tensor([-1]) + self.base_embedding_indices = torch.tensor([]) + + self.modules: Dict[str, nn.Module] = {} + self._create_prompt_adapter_modules() + self._last_mapping: Optional[PromptAdapterMapping] = None + + @property + def prompt_adapter_slots(self) -> int: + return self.prompt_adapter_config.max_prompt_adapters + + @property + def adapter_slots(self) -> int: + return self.prompt_adapter_slots + + @property + def capacity(self) -> int: + return self.prompt_adapter_config.max_cpu_prompt_adapters + + def activate_adapter( + self, + prompt_adapter_id: int, + ) -> bool: + """Move PromptAdapter into a GPU buffer + to be used in the forward pass.""" + if prompt_adapter_id in self._active_adapters: + return False + first_free_slot = next( + ((i, prompt_adapter_id) for i, prompt_adapter_id in enumerate( + self.prompt_adapter_index_to_id) if prompt_adapter_id is None), + None) + if first_free_slot is None: + raise ValueError("No free prompt_adapter slots") + index, _ = first_free_slot + self._active_adapters[prompt_adapter_id] = None + prompt_adapter_model = (self._registered_adapters[prompt_adapter_id]) + logger.debug("Activating prompt_adapter. int id: %d, slot index: %d", + prompt_adapter_model.id, index) + self.prompt_adapter_index_to_id[index] = prompt_adapter_model.id + for _, v in self.modules.items(): + v.set_prompt_adapter(index, prompt_adapter_model.prompt_embedding) + return True + + def _deactivate_adapter(self, prompt_adapter_id: int): + try: + index = self.prompt_adapter_index_to_id.index(prompt_adapter_id) + self.prompt_adapter_index_to_id[index] = None + for _, v in self.modules.items(): + v.reset_prompt_adapter(index) + except ValueError: + pass + + def _add_adapter(self, prompt_adapter: PromptAdapterModel): + self._registered_adapters[prompt_adapter.id] = prompt_adapter + + def _set_adapter_mapping(self, mapping: PromptAdapterMapping) -> None: + base_indices, base_embedding_indices = convert_mapping( + mapping, self.prompt_adapter_index_to_id) + for k, v in self.modules.items(): + v.set_mapping(base_indices, base_embedding_indices) + + def _create_prompt_adapter_modules(self): + for module_name, module in self.model.named_modules( + remove_duplicate=False): + if "VocabParallel" in module.__class__.__name__: + new_module = VocabParallelEmbeddingWithPromptAdapter(module) + new_module.create_prompt_adapter_weights( + self.prompt_adapter_config) + replaced_module = self.replace_submodule( + self.model, module_name, new_module) + self.register_module(module.__class__.__name__, + replaced_module) + replaced_module.set_mapping(self.base_indices, + self.base_embedding_indices) + break + + def replace_submodule(self, model: nn.Module, module_name: str, + new_module: nn.Module) -> nn.Module: + """Replace a submodule in a model with a new module.""" + parent = model.get_submodule(".".join(module_name.split(".")[:-1])) + target_name = module_name.split(".")[-1] + setattr(parent, target_name, new_module) + return new_module + + def register_module(self, module_name: str, module: nn.Module): + self.modules[module_name] = module + + def pin_adapter(self, prompt_adapter_id: int) -> bool: + """Pin a PromptAdapterModel in the manager 
cache.""" + raise NotImplementedError( + "Pinning is not supported in PromptAdapterModelManager." + "Use LRUCachePromptAdapterModelManager for pinning" + ) # type: ignore + + def remove_all_adapters(self): + """Remove all PromptAdapterModel from the manager.""" + self._registered_adapters.clear() + self.prompt_adapter_index_to_id = [None] * self.prompt_adapter_slots + self._active_adapters.clear() + + def deactivate_adapter(self, adapter_id: int) -> bool: + return deactivate_adapter(adapter_id, self._active_adapters, + self._deactivate_adapter) + + def add_adapter(self, adapter: PromptAdapterModel) -> bool: + return add_adapter(adapter, self._registered_adapters, self.capacity, + self._add_adapter) + + def set_adapter_mapping(self, mapping: PromptAdapterMapping) -> None: + self._last_mapping = set_adapter_mapping(mapping, self._last_mapping, + self._set_adapter_mapping) + + def remove_adapter(self, adapter_id: int) -> bool: + return remove_adapter(adapter_id, self._registered_adapters, + self.deactivate_adapter) + + def list_adapters(self) -> Dict[int, Any]: + return list_adapters(self._registered_adapters) + + def get_adapter(self, adapter_id: int) -> Optional[Any]: + return get_adapter(adapter_id, self._registered_adapters) + + +class PromptAdapterLRUCache(AdapterLRUCache[PromptAdapterModel]): + + def __init__(self, capacity: int, + deactivate_prompt_adapter_fn: Callable[[int], bool]): + super().__init__(capacity, deactivate_prompt_adapter_fn) + + +class LRUCachePromptAdapterModelManager(PromptAdapterModelManager): + """A model manager that manages multiple prompt_adapters with LRU cache.""" + + def __init__( + self, + model: nn.Module, + max_num_seqs: int, + max_num_batched_tokens: int, + prompt_adapter_config: PromptAdapterConfig, + ): + self.prompt_adapter_config = prompt_adapter_config + super().__init__(model, max_num_seqs, max_num_batched_tokens, + prompt_adapter_config) + self._registered_adapters = PromptAdapterLRUCache( + self.capacity, self.deactivate_adapter) + self._active_adapters = PromptAdapterLRUCache( + self.prompt_adapter_slots, self._deactivate_adapter) + + def list_adapters(self) -> Dict[int, PromptAdapterModel]: + """List all registered PromptAdapterModel.""" + return dict(self._registered_adapters.cache) + + def add_adapter(self, prompt_adapter: PromptAdapterModel) -> bool: + """Add a PromptAdapterModel to the manager.""" + if prompt_adapter.id not in self._registered_adapters: + self._add_adapter(prompt_adapter) + was_added = True + else: + # We always touch to update the LRU cache order + self._registered_adapters.touch(prompt_adapter.id) + was_added = False + return was_added + + def activate_adapter( + self, + prompt_adapter_id: int, + ) -> bool: + if prompt_adapter_id not in self._active_adapters and len( + self._active_adapters) >= self.prompt_adapter_slots: + self._active_adapters.remove_oldest() + result = super().activate_adapter(prompt_adapter_id) + # We always touch to update the LRU cache order + self._active_adapters.touch(prompt_adapter_id) + return result + + def remove_oldest_adapter(self) -> bool: + if len(self._registered_adapters) > 0: + self._registered_adapters.remove_oldest() + return True + return False + + def pin_adapter(self, prompt_adapter_id: int) -> bool: + """Pin a PromptAdapterModel in the manager cache.""" + self._pin_prompt_adapter_in_cpu_cache(prompt_adapter_id) + self._pin_prompt_adapter_in_gpu_cache(prompt_adapter_id) + return True + + def _pin_prompt_adapter_in_cpu_cache(self, prompt_adapter_id: int): + try: + 
self._registered_adapters.pin(prompt_adapter_id) + except ValueError as err: + raise ValueError( + "Pinning failed. " + f"Prompt Adapter {prompt_adapter_id} is not registered." + ) from err + + def _pin_prompt_adapter_in_gpu_cache(self, prompt_adapter_id: int): + if prompt_adapter_id not in self._active_adapters: + # move adapter to gpu if not already active + self.activate_adapter(prompt_adapter_id) + self._active_adapters.pin(prompt_adapter_id) + + +def create_prompt_adapter_manager( + model: nn.Module, + max_num_seqs: int, + max_num_batched_tokens: int, + prompt_adapter_config: PromptAdapterConfig, + prompt_adapter_manager_cls: Type[ + PromptAdapterModelManager] = PromptAdapterModelManager, + **kwargs) -> PromptAdapterModelManager: + """Create a PromptAdapterModel for a given model.""" + prompt_adapter_manager = prompt_adapter_manager_cls( + model=model, + max_num_seqs=max_num_seqs, + max_num_batched_tokens=max_num_batched_tokens, + prompt_adapter_config=prompt_adapter_config, + **kwargs) + return prompt_adapter_manager diff --git a/vllm/prompt_adapter/request.py b/vllm/prompt_adapter/request.py new file mode 100644 index 0000000000000..c0c98cf72bbae --- /dev/null +++ b/vllm/prompt_adapter/request.py @@ -0,0 +1,30 @@ +from dataclasses import dataclass + +from vllm.adapter_commons.request import AdapterRequest + + +@dataclass +class PromptAdapterRequest(AdapterRequest): + """ + Request for a Prompt adapter. + """ + + prompt_adapter_name: str + prompt_adapter_id: int + prompt_adapter_local_path: str + prompt_adapter_num_virtual_tokens: int + + def __hash__(self): + return super().__hash__() + + @property + def adapter_id(self): + return self.prompt_adapter_id + + @property + def name(self): + return self.prompt_adapter_name + + @property + def local_path(self): + return self.prompt_adapter_local_path diff --git a/vllm/prompt_adapter/worker_manager.py b/vllm/prompt_adapter/worker_manager.py new file mode 100644 index 0000000000000..ddc1ef893c6f2 --- /dev/null +++ b/vllm/prompt_adapter/worker_manager.py @@ -0,0 +1,176 @@ +import logging +from typing import Any, Optional, Set, Type + +import torch + +from vllm.adapter_commons.utils import (add_adapter_worker, + apply_adapters_worker, + list_adapters_worker, + set_active_adapters_worker) +from vllm.adapter_commons.worker_manager import AbstractWorkerManager +from vllm.config import PromptAdapterConfig +from vllm.prompt_adapter.models import (LRUCachePromptAdapterModelManager, + PromptAdapterModel, + PromptAdapterModelManager, + create_prompt_adapter_manager) +from vllm.prompt_adapter.request import PromptAdapterRequest + +logger = logging.getLogger(__name__) + + +class WorkerPromptAdapterManager(AbstractWorkerManager): + """WorkerPromptAdapterManager that manages + prompt_adapter models on the worker side. 
+ + Every request, the requested prompt_adapters will be + loaded (unless they are already loaded), + and every other prompt_adapter will be unloaded.""" + + _manager_cls: Type[PromptAdapterModelManager] = PromptAdapterModelManager + + def __init__( + self, + max_num_seqs: int, + max_num_batched_tokens: int, + device: torch.device, + prompt_adapter_config: PromptAdapterConfig, + prompt_adapter_model_cls: Type[PromptAdapterModel] = PromptAdapterModel + ): + self._adapter_manager: PromptAdapterModelManager + self.max_num_seqs = max_num_seqs + self.max_num_batched_tokens = max_num_batched_tokens + self._prompt_adapter_model_cls = prompt_adapter_model_cls + self.prompt_adapter_config = prompt_adapter_config + super().__init__(device) + + @property + def is_enabled(self) -> bool: + return True + + def create_prompt_adapter_manager( + self, + model: torch.nn.Module, + ) -> Any: + prompt_adapter_manager = create_prompt_adapter_manager( + model, + max_num_seqs=self.max_num_seqs, + max_num_batched_tokens=self.max_num_batched_tokens, + prompt_adapter_config=self.prompt_adapter_config, + prompt_adapter_manager_cls=self._manager_cls, + ) + self._adapter_manager = prompt_adapter_manager + return prompt_adapter_manager.model + + def _load_adapter( + self, prompt_adapter_request: PromptAdapterRequest + ) -> PromptAdapterModel: + try: + prompt_adapter = ( + self._prompt_adapter_model_cls.from_local_checkpoint( + prompt_adapter_request.prompt_adapter_local_path, + prompt_adapter_id=prompt_adapter_request.prompt_adapter_id, + num_virtual_tokens=prompt_adapter_request. + prompt_adapter_num_virtual_tokens, + config=self.prompt_adapter_config, + device=str(self.device), + )) + except Exception as e: + raise RuntimeError( + f"Loading prompt_adapter " + f"{prompt_adapter_request.prompt_adapter_local_path}" + f" failed") from e + return prompt_adapter + + def add_dummy_prompt_adapter( + self, prompt_adapter_request: PromptAdapterRequest) -> bool: + return True + + def pin_adapter(self, adapter_id: int) -> bool: + return self._adapter_manager.pin_adapter(adapter_id) + + def set_active_adapters(self, requests: Set[Any], + mapping: Optional[Any]) -> None: + set_active_adapters_worker(requests, mapping, self._apply_adapters, + self._adapter_manager.set_adapter_mapping) + + def add_adapter(self, adapter_request: Any) -> bool: + return add_adapter_worker(adapter_request, self.list_adapters, + self._load_adapter, + self._adapter_manager.add_adapter, + self._adapter_manager.activate_adapter) + + def _apply_adapters(self, adapter_requests: Set[Any]) -> None: + apply_adapters_worker(adapter_requests, self.list_adapters, + self._adapter_manager.adapter_slots, + self.remove_adapter, self.add_adapter) + + def remove_adapter(self, adapter_id: int) -> bool: + return self._adapter_manager.remove_adapter(adapter_id) + + def remove_all_adapters(self): + self._adapter_manager.remove_all_adapters() + + def list_adapters(self) -> Set[int]: + return list_adapters_worker(self._adapter_manager.list_adapters) + + +class LRUCacheWorkerPromptAdapterManager(WorkerPromptAdapterManager): + """WorkerPromptAdapterManager that manages + prompt_adapter models on the worker side. + + Uses an LRU Cache. 
Every request, the requested + prompt_adapters will be loaded (unless they are already loaded) + and least recently used prompt_adapters will + be unloaded if the cache is above capacity.""" + + _prompt_adapter_manager_cls: Type[ + LRUCachePromptAdapterModelManager] = LRUCachePromptAdapterModelManager + + def create_prompt_adapter_manager( + self, + model: torch.nn.Module, + ) -> Any: + prompt_adapter_manager = create_prompt_adapter_manager( + model, + max_num_seqs=self.max_num_seqs, + max_num_batched_tokens=self.max_num_batched_tokens, + prompt_adapter_config=self.prompt_adapter_config, + prompt_adapter_manager_cls=self._prompt_adapter_manager_cls) + self._adapter_manager: LRUCachePromptAdapterModelManager = ( + prompt_adapter_manager) + return prompt_adapter_manager.model + + def _apply_adapters( + self, prompt_adapter_requests: Set[PromptAdapterRequest]) -> None: + prompt_adapters_map = { + prompt_adapter_request.prompt_adapter_id: prompt_adapter_request + for prompt_adapter_request in prompt_adapter_requests + if prompt_adapter_request + } + if len(prompt_adapters_map + ) > self._adapter_manager.prompt_adapter_slots: + raise RuntimeError( + f"Number of requested prompt_adapters " + f"({len(prompt_adapters_map)}) is greater " + "than the number of GPU prompt_adapter slots " + f"({self._adapter_manager.prompt_adapter_slots}).") + for prompt_adapter in prompt_adapters_map.values(): + self.add_adapter(prompt_adapter) + + def add_adapter(self, + prompt_adapter_request: PromptAdapterRequest) -> bool: + if prompt_adapter_request.prompt_adapter_id not in self.list_adapters( + ): + # Remove before we load the new prompt_adapter to save memory + if len(self._adapter_manager) + 1 > self._adapter_manager.capacity: + self._adapter_manager.remove_oldest_adapter() + prompt_adapter = self._load_adapter(prompt_adapter_request) + loaded = self._adapter_manager.add_adapter(prompt_adapter) + else: + # If the prompt_adapter is already loaded, just touch it to + # update its position in the caches + loaded = self._adapter_manager.get_adapter( + prompt_adapter_request.prompt_adapter_id) is not None + self._adapter_manager.activate_adapter( + prompt_adapter_request.prompt_adapter_id) + return loaded diff --git a/vllm/sequence.py b/vllm/sequence.py index d200115aa0921..a3f998b94d795 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -10,6 +10,7 @@ from vllm.lora.request import LoRARequest from vllm.pooling_params import PoolingParams +from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams if TYPE_CHECKING: @@ -238,21 +239,25 @@ class Sequence: block_size: The block size of the sequence. Should be the same as the block size used by the block manager and cache engine. lora_request: LoRA request. + prompt_adapter_request: Prompt Adapter request. 
+ """ def __init__( - self, - seq_id: int, - inputs: "LLMInputs", - block_size: int, - eos_token_id: Optional[int] = None, - lora_request: Optional[LoRARequest] = None, + self, + seq_id: int, + inputs: "LLMInputs", + block_size: int, + eos_token_id: Optional[int] = None, + lora_request: Optional[LoRARequest] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None ) -> None: self.seq_id = seq_id self.inputs = inputs self.block_size = block_size self.eos_token_id = eos_token_id self.lora_request = lora_request + self.prompt_adapter_request = prompt_adapter_request self.data = SequenceData(self.prompt_token_ids) self.output_logprobs: SampleLogprobs = [] @@ -287,6 +292,11 @@ def multi_modal_data(self) -> "MultiModalDataDict": def lora_int_id(self) -> int: return self.lora_request.lora_int_id if self.lora_request else 0 + @property + def prompt_adapter_id(self) -> int: + return self.prompt_adapter_request.prompt_adapter_id \ + if self.prompt_adapter_request else 0 + def get_output_text_to_return(self, buffer_length: int): # We return the full output text if the sequence is finished. truncate = buffer_length and not self.is_finished() @@ -414,6 +424,7 @@ class SequenceGroup: encoder_seq: Optional, the single encoder sequence. Should be None unless you are working with an encoder/decoder model. trace_headers: OpenTelemetry trace headers. + prompt_adapter_request: Prompt Adapter request. """ def __init__( @@ -427,6 +438,7 @@ def __init__( pooling_params: Optional[PoolingParams] = None, encoder_seq: Optional[Sequence] = None, trace_headers: Optional[Dict[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> None: self.request_id = request_id self.seqs_dict = {seq.seq_id: seq for seq in seqs} @@ -441,6 +453,7 @@ def __init__( self.state = SequenceGroupState() self.embeddings = embeddings self.pooling_params = pooling_params + self.prompt_adapter_request = prompt_adapter_request self.encoder_seq = encoder_seq self.trace_headers = trace_headers @@ -466,6 +479,16 @@ def multi_modal_data(self) -> "MultiModalDataDict": def lora_int_id(self) -> int: return self.lora_request.lora_int_id if self.lora_request else 0 + @property + def prompt_adapter_id(self) -> int: + return self.prompt_adapter_request.prompt_adapter_id \ + if self.prompt_adapter_request else 0 + + @property + def prompt_adapter_num_virtual_tokens(self) -> int: + return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens\ + if self.prompt_adapter_request else 0 + def get_last_latency(self, now: float) -> Optional[float]: """Sets the last token time for Request level timings.""" # If still in prefill phase, raise Error. @@ -624,6 +647,7 @@ class SequenceGroupMetadata: (SequenceGroup.encoder_seq). Should be None unless you are working with an encoder/decoder model. + prompt_adapter_request: Prompt Adapter request. 
""" def __init__( @@ -642,6 +666,7 @@ def __init__( multi_modal_data: Optional["MultiModalDataDict"] = None, encoder_seq_data: Optional[SequenceData] = None, cross_block_table: Optional[List[int]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> None: self.request_id = request_id self.is_prompt = is_prompt @@ -650,6 +675,7 @@ def __init__( self.block_tables = block_tables self.pooling_params = pooling_params self.lora_request = lora_request + self.prompt_adapter_request = prompt_adapter_request self.computed_block_nums = computed_block_nums self.multi_modal_data = multi_modal_data self.state = SequenceGroupState() if state is None else state @@ -674,6 +700,16 @@ def __init__( def lora_int_id(self) -> int: return self.lora_request.lora_int_id if self.lora_request else 0 + @property + def prompt_adapter_id(self) -> int: + return self.prompt_adapter_request.prompt_adapter_id \ + if self.prompt_adapter_request else 0 + + @property + def prompt_adapter_num_virtual_tokens(self) -> int: + return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens \ + if self.prompt_adapter_request else 0 + @property def token_chunk_size(self) -> int: """Return the number of tokens to be processed (chunk size).""" diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index 6a2cfc819d8d2..90bba96ee8acb 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -4,7 +4,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ParallelConfig, - SchedulerConfig) + PromptAdapterConfig, SchedulerConfig) from vllm.logger import init_logger from vllm.sequence import (IntermediateTensors, SamplerOutput, SequenceGroupMetadata) @@ -48,6 +48,7 @@ def __init__( kv_cache_dtype: Optional[str] = "auto", is_driver_worker: bool = False, multimodal_config: Optional[MultiModalConfig] = None, + prompt_adapter_config: Optional[PromptAdapterConfig] = None, return_hidden_states: bool = False, ): if return_hidden_states: @@ -66,6 +67,7 @@ def __init__( kv_cache_dtype=kv_cache_dtype, is_driver_worker=is_driver_worker, multimodal_config=multimodal_config, + prompt_adapter_config=prompt_adapter_config, return_hidden_states=return_hidden_states, ) @@ -136,6 +138,13 @@ def execute_model( self.set_active_loras(model_input.lora_requests, model_input.lora_mapping) + if self.prompt_adapter_config: + assert model_input.prompt_adapter_requests is not None + assert model_input.prompt_adapter_mapping is not None + self.set_active_prompt_adapters( + model_input.prompt_adapter_requests, + model_input.prompt_adapter_mapping) + virtual_engine = model_input.virtual_engine outputs: List[SamplerOutput] = [] for step in range(num_steps): diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index b4277ae827c02..db0e178e45f4e 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -8,7 +8,7 @@ from vllm.attention import AttentionMetadata, get_attn_backend from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ParallelConfig, - SchedulerConfig) + PromptAdapterConfig, SchedulerConfig) from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.model_loader import get_model @@ -81,6 +81,7 @@ def __init__( lora_config: Optional[LoRAConfig], multimodal_config: Optional[MultiModalConfig], kv_cache_dtype: Optional[str] = "auto", + 
prompt_adapter_config: Optional[PromptAdapterConfig] = None, is_driver_worker: bool = False, *args, **kwargs, @@ -94,6 +95,7 @@ def __init__( self.cache_config = cache_config self.lora_config = lora_config self.multimodal_config = multimodal_config + self.prompt_adapter_config = prompt_adapter_config self.load_config = load_config self.is_driver_worker = is_driver_worker diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index 657505739e236..3c22c73267b7f 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -7,7 +7,7 @@ from vllm.attention import get_attn_backend from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ParallelConfig, - SchedulerConfig) + PromptAdapterConfig, SchedulerConfig) from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) from vllm.logger import init_logger @@ -133,6 +133,7 @@ def __init__( lora_config: Optional[LoRAConfig] = None, multimodal_config: Optional[MultiModalConfig] = None, kv_cache_dtype: Optional[str] = "auto", + prompt_adapter_config: Optional[PromptAdapterConfig] = None, is_driver_worker: bool = False, ) -> None: self.model_config = model_config @@ -145,6 +146,7 @@ def __init__( self.rank = rank self.distributed_init_method = distributed_init_method self.lora_config = lora_config + self.prompt_adapter_config = prompt_adapter_config self.multimodal_config = multimodal_config self.is_driver_worker = is_driver_worker if self.is_driver_worker: @@ -167,6 +169,7 @@ def __init__( lora_config=self.lora_config, multimodal_config=self.multimodal_config, kv_cache_dtype=kv_cache_dtype, + prompt_adapter_config=self.prompt_adapter_config, is_driver_worker=is_driver_worker) # Uninitialized cache engine. Will be initialized by # initialize_cache. diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index a3b31a1c0ac8a..a333e6634a41f 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -5,7 +5,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ParallelConfig, - SchedulerConfig) + PromptAdapterConfig, SchedulerConfig) from vllm.logger import init_logger from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.pooling_params import PoolingParams @@ -40,6 +40,7 @@ def __init__( lora_config: Optional[LoRAConfig], kv_cache_dtype: Optional[str] = "auto", is_driver_worker: bool = False, + prompt_adapter_config: Optional[PromptAdapterConfig] = None, multimodal_config: Optional[MultiModalConfig] = None, ): super().__init__(model_config, @@ -51,6 +52,7 @@ def __init__( lora_config=lora_config, kv_cache_dtype=kv_cache_dtype, is_driver_worker=is_driver_worker, + prompt_adapter_config=prompt_adapter_config, multimodal_config=multimodal_config) @torch.inference_mode() @@ -71,6 +73,13 @@ def execute_model( self.set_active_loras(model_input.lora_requests, model_input.lora_mapping) + if self.prompt_adapter_config: + assert model_input.prompt_adapter_requests is not None + assert model_input.prompt_adapter_mapping is not None + self.set_active_prompt_adapters( + model_input.prompt_adapter_requests, + model_input.prompt_adapter_mapping) + # Currently cuda graph is only supported by the decode phase. 
assert model_input.attn_metadata is not None prefill_meta = model_input.attn_metadata.prefill_metadata diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index d0c82d6bbedf3..205b4f58f7a83 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -25,7 +25,7 @@ from vllm.attention import AttentionMetadata, get_attn_backend from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ParallelConfig, - SchedulerConfig) + PromptAdapterConfig, SchedulerConfig) from vllm.distributed import get_pp_group from vllm.distributed.parallel_state import graph_capture from vllm.inputs import INPUT_REGISTRY @@ -40,6 +40,10 @@ supports_vision) from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors, MultiModalInputs) +from vllm.prompt_adapter.layers import PromptAdapterMapping +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.prompt_adapter.worker_manager import ( + LRUCacheWorkerPromptAdapterManager) from vllm.sampling_params import SamplingParams from vllm.sequence import (IntermediateTensors, SamplerOutput, SequenceGroupMetadata) @@ -85,6 +89,8 @@ class ModelInputForGPU(ModelRunnerInputBase): lora_mapping: Optional["LoRAMapping"] = None lora_requests: Optional[Set[LoRARequest]] = None attn_metadata: Optional["AttentionMetadata"] = None + prompt_adapter_mapping: Optional[PromptAdapterMapping] = None + prompt_adapter_requests: Optional[Set[PromptAdapterRequest]] = None multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]] = None request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None finished_requests_ids: Optional[List[str]] = None @@ -97,6 +103,8 @@ def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: "lora_requests": self.lora_requests, "lora_mapping": self.lora_mapping, "multi_modal_kwargs": self.multi_modal_kwargs, + "prompt_adapter_mapping": self.prompt_adapter_mapping, + "prompt_adapter_requests": self.prompt_adapter_requests, "virtual_engine": self.virtual_engine, "request_ids_to_seq_ids": self.request_ids_to_seq_ids, "finished_requests_ids": self.finished_requests_ids, @@ -133,6 +141,8 @@ def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: "lora_requests": self.lora_requests, "lora_mapping": self.lora_mapping, "multi_modal_kwargs": self.multi_modal_kwargs, + "prompt_adapter_mapping": self.prompt_adapter_mapping, + "prompt_adapter_requests": self.prompt_adapter_requests, "virtual_engine": self.virtual_engine, "request_ids_to_seq_ids": self.request_ids_to_seq_ids, "finished_requests_ids": self.finished_requests_ids, @@ -172,6 +182,7 @@ def __init__( lora_config: Optional[LoRAConfig], kv_cache_dtype: Optional[str] = "auto", is_driver_worker: bool = False, + prompt_adapter_config: Optional[PromptAdapterConfig] = None, multimodal_config: Optional[MultiModalConfig] = None, return_hidden_states: bool = False, ): @@ -183,6 +194,7 @@ def __init__( self.lora_config = lora_config self.load_config = load_config self.is_driver_worker = is_driver_worker + self.prompt_adapter_config = prompt_adapter_config self.multimodal_config = multimodal_config self.return_hidden_states = return_hidden_states @@ -232,6 +244,7 @@ def __init__( self.model: nn.Module # Set after load_model # Set after load_model. 
self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None + self.prompt_adapter_manager: LRUCacheWorkerPromptAdapterManager = None self.flashinfer_decode_workspace_buffer = None self.flashinfer_decode_wrapper = None @@ -240,16 +253,14 @@ def __init__( def load_model(self) -> None: with CudaMemoryProfiler() as m: - self.model = get_model( - model_config=self.model_config, - device_config=self.device_config, - load_config=self.load_config, - lora_config=self.lora_config, - multimodal_config=self.multimodal_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config, - cache_config=self.cache_config, - ) + self.model = get_model(model_config=self.model_config, + device_config=self.device_config, + load_config=self.load_config, + lora_config=self.lora_config, + multimodal_config=self.multimodal_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + cache_config=self.cache_config) self.model_memory_usage = m.consumed_memory logger.info("Loading model weights took %.4f GB", @@ -274,6 +285,15 @@ def load_model(self) -> None: ) self.model = self.lora_manager.create_lora_manager(self.model) + if self.prompt_adapter_config: + self.prompt_adapter_manager = LRUCacheWorkerPromptAdapterManager( + self.scheduler_config.max_num_seqs, + self.scheduler_config.max_num_batched_tokens, self.device, + self.prompt_adapter_config) + self.model = ( + self.prompt_adapter_manager.create_prompt_adapter_manager( + self.model)) + if self.kv_cache_dtype == "fp8" and is_hip(): # Currently only ROCm accepts kv-cache scaling factors # via quantization_param_path and this will be deprecated @@ -354,6 +374,9 @@ def _prepare_model_input_tensors( lora_index_mapping: List[int] = [] lora_prompt_mapping: List[int] = [] lora_requests: Set[LoRARequest] = set() + prompt_adapter_index_mapping: List[int] = [] + prompt_adapter_prompt_mapping: List[int] = [] + prompt_adapter_requests: Set[PromptAdapterRequest] = set() seq_lens: List[int] = [] prefill_seq_lens: List[int] = [] @@ -504,6 +527,7 @@ def _prepare_model_input_tensors( input_tokens.extend(tokens) input_positions.extend(list(range(context_len, seq_len))) lora_id = seq_group_metadata.lora_int_id + prompt_adapter_id = seq_group_metadata.prompt_adapter_id if is_prompt: assert len(seq_ids) == 1 @@ -534,6 +558,21 @@ def _prepare_model_input_tensors( mm_kwargs = self.multi_modal_input_mapper(mm_data) multi_modal_inputs_list.append(mm_kwargs) + if prompt_adapter_id > 0 and is_prompt: + prompt_adapter_requests.add( + seq_group_metadata.prompt_adapter_request) + + num_tokens = seq_group_metadata.\ + prompt_adapter_num_virtual_tokens + pm = [prompt_adapter_id + ] * num_tokens + [0] * (query_len - num_tokens) + prompt_adapter_index_mapping += pm + prompt_adapter_prompt_mapping.extend( + [prompt_adapter_id] * + (query_len if seq_group_metadata.sampling_params + and seq_group_metadata.sampling_params.prompt_logprobs + else 1)) + is_profile_run = _is_block_tables_empty( seq_group_metadata.block_tables) if is_profile_run: @@ -618,12 +657,11 @@ def _prepare_model_input_tensors( seq_lens.append(1) block_tables.append([]) lora_index_mapping.append(0) - + prompt_adapter_index_mapping.append(0) if self.attn_backend.get_name() == "flashinfer": last_paged_kv_indptr = paged_kv_indptr[-1] paged_kv_indptr.append(last_paged_kv_indptr) paged_kv_last_page_len.append(0) - batch_size = graph_batch_size num_decode_tokens = batch_size @@ -759,6 +797,14 @@ def _prepare_model_input_tensors( else: lora_mapping = None + if 
self.prompt_adapter_config: + prompt_adapter_mapping = PromptAdapterMapping( + prompt_adapter_index_mapping, + prompt_adapter_prompt_mapping, + ) + else: + prompt_adapter_mapping = None + multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list, device=self.device) request_ids_to_seq_ids = { @@ -776,7 +822,10 @@ def _prepare_model_input_tensors( lora_requests=lora_requests, multi_modal_kwargs=multi_modal_kwargs, request_ids_to_seq_ids=request_ids_to_seq_ids, - finished_requests_ids=finished_requests_ids) + finished_requests_ids=finished_requests_ids, + prompt_adapter_mapping=prompt_adapter_mapping, + prompt_adapter_requests=prompt_adapter_requests, + ) @torch.inference_mode() def profile_run(self) -> None: @@ -878,33 +927,67 @@ def profile_run(self) -> None: def remove_all_loras(self): if not self.lora_manager: raise RuntimeError("LoRA is not enabled.") - self.lora_manager.remove_all_loras() + self.lora_manager.remove_all_adapters() def set_active_loras(self, lora_requests: Set[LoRARequest], lora_mapping: LoRAMapping) -> None: if not self.lora_manager: raise RuntimeError("LoRA is not enabled.") - self.lora_manager.set_active_loras(lora_requests, lora_mapping) + self.lora_manager.set_active_adapters(lora_requests, lora_mapping) def add_lora(self, lora_request: LoRARequest) -> bool: if not self.lora_manager: raise RuntimeError("LoRA is not enabled.") - return self.lora_manager.add_lora(lora_request) + return self.lora_manager.add_adapter(lora_request) def remove_lora(self, lora_id: int) -> bool: if not self.lora_manager: raise RuntimeError("LoRA is not enabled.") - return self.lora_manager.remove_lora(lora_id) + return self.lora_manager.remove_adapter(lora_id) def pin_lora(self, lora_id: int) -> bool: if not self.lora_manager: raise RuntimeError("LoRA is not enabled.") - return self.lora_manager.pin_lora(lora_id) + return self.lora_manager.pin_adapter(lora_id) def list_loras(self) -> Set[int]: if not self.lora_manager: raise RuntimeError("LoRA is not enabled.") - return self.lora_manager.list_loras() + return self.lora_manager.list_adapters() + + def remove_all_prompt_adapters(self): + if not self.prompt_adapter_manager: + raise RuntimeError("PromptAdapter is not enabled.") + self.prompt_adapter_manager.remove_all_adapters() + + def set_active_prompt_adapters( + self, prompt_adapter_requests: Set[PromptAdapterRequest], + prompt_adapter_mapping: PromptAdapterMapping) -> None: + if not self.prompt_adapter_manager: + raise RuntimeError("PromptAdapter is not enabled.") + self.prompt_adapter_manager.set_active_adapters( + prompt_adapter_requests, prompt_adapter_mapping) + + def add_prompt_adapter( + self, prompt_adapter_request: PromptAdapterRequest) -> bool: + if not self.prompt_adapter_manager: + raise RuntimeError("PromptAdapter is not enabled.") + return self.prompt_adapter_manager.add_adapter(prompt_adapter_request) + + def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: + if not self.prompt_adapter_manager: + raise RuntimeError("PromptAdapter is not enabled.") + return self.prompt_adapter_manager.remove_adapter(prompt_adapter_id) + + def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: + if not self.prompt_adapter_manager: + raise RuntimeError("PromptAdapter is not enabled.") + return self.prompt_adapter_manager.pin_adapter(prompt_adapter_id) + + def list_prompt_adapters(self) -> Set[int]: + if not self.prompt_adapter_manager: + raise RuntimeError("PromptAdapter is not enabled.") + return self.prompt_adapter_manager.list_adapters() @torch.inference_mode() 
def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: @@ -1063,6 +1146,14 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: ) self.set_active_loras(set(), lora_mapping) + if self.prompt_adapter_config: + prompt_adapter_mapping = PromptAdapterMapping( + [-1] * batch_size, + [-1] * batch_size, + ) + self.set_active_prompt_adapters( + set(), prompt_adapter_mapping) + graph_runner = CUDAGraphRunner( self.model, self.attn_backend.get_name()) @@ -1189,6 +1280,13 @@ def execute_model( self.set_active_loras(model_input.lora_requests, model_input.lora_mapping) + if self.prompt_adapter_config: + assert model_input.prompt_adapter_requests is not None + assert model_input.prompt_adapter_mapping is not None + self.set_active_prompt_adapters( + model_input.prompt_adapter_requests, + model_input.prompt_adapter_mapping) + if self.attn_backend.get_name() == "flashinfer": assert model_input.attn_metadata is not None assert model_input.input_tokens is not None diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 58707269bd68c..857cd86beff92 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -8,7 +8,8 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ParallelConfig, - SchedulerConfig, SpeculativeConfig) + PromptAdapterConfig, SchedulerConfig, + SpeculativeConfig) from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment, set_custom_all_reduce) @@ -16,6 +17,7 @@ from vllm.model_executor import set_random_seed from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.platforms import current_platform +from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import ExecuteModelRequest from vllm.worker.cache_engine import CacheEngine from vllm.worker.embedding_model_runner import EmbeddingModelRunner @@ -45,6 +47,7 @@ def __init__( lora_config: Optional[LoRAConfig] = None, multimodal_config: Optional[MultiModalConfig] = None, speculative_config: Optional[SpeculativeConfig] = None, + prompt_adapter_config: Optional[PromptAdapterConfig] = None, is_driver_worker: bool = False, model_runner_cls: Optional[Type[GPUModelRunnerBase]] = None, ) -> None: @@ -59,6 +62,7 @@ def __init__( self.distributed_init_method = distributed_init_method self.lora_config = lora_config self.load_config = load_config + self.prompt_adapter_config = prompt_adapter_config self.is_driver_worker = is_driver_worker if parallel_config and is_driver_worker: assert rank % parallel_config.tensor_parallel_size == 0, \ @@ -92,6 +96,7 @@ def __init__( lora_config=self.lora_config, kv_cache_dtype=self.cache_config.cache_dtype, is_driver_worker=is_driver_worker, + prompt_adapter_config=prompt_adapter_config, multimodal_config=multimodal_config, **speculative_args, ) @@ -296,6 +301,19 @@ def pin_lora(self, lora_id: int) -> bool: def list_loras(self) -> Set[int]: return self.model_runner.list_loras() + def add_prompt_adapter( + self, prompt_adapter_request: PromptAdapterRequest) -> bool: + return self.model_runner.add_prompt_adapter(prompt_adapter_request) + + def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: + return self.model_runner.remove_lora(prompt_adapter_id) + + def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: + return self.model_runner.pin_prompt_adapter(prompt_adapter_id) + + def list_prompt_adapters(self) -> Set[int]: + return self.model_runner.list_prompt_adapters() + @property def max_model_len(self) -> int: 
return self.model_config.max_model_len diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index 03b9cce5ae792..e03f24fdfc41a 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -8,7 +8,7 @@ from vllm.attention import get_attn_backend from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ParallelConfig, - SchedulerConfig) + PromptAdapterConfig, SchedulerConfig) from vllm.distributed import broadcast_tensor_dict from vllm.inputs import INPUT_REGISTRY from vllm.logger import init_logger @@ -88,6 +88,7 @@ def __init__( lora_config: Optional[LoRAConfig], multimodal_config: Optional[MultiModalConfig], kv_cache_dtype: Optional[str] = "auto", + prompt_adapter_config: Optional[PromptAdapterConfig] = None, is_driver_worker: bool = False, *args, **kwargs, @@ -98,6 +99,7 @@ def __init__( self.lora_config = lora_config self.load_config = load_config self.cache_config = cache_config + self.prompt_adapter_config = prompt_adapter_config self.multimodal_config = multimodal_config self.is_driver_worker = is_driver_worker diff --git a/vllm/worker/xpu_worker.py b/vllm/worker/xpu_worker.py index 94dfcfec37757..6a822c2ba3e7a 100644 --- a/vllm/worker/xpu_worker.py +++ b/vllm/worker/xpu_worker.py @@ -10,7 +10,8 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ParallelConfig, - SchedulerConfig, SpeculativeConfig) + PromptAdapterConfig, SchedulerConfig, + SpeculativeConfig) from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) from vllm.logger import init_logger @@ -47,6 +48,7 @@ def __init__( lora_config: Optional[LoRAConfig] = None, multimodal_config: Optional[MultiModalConfig] = None, speculative_config: Optional[SpeculativeConfig] = None, + prompt_adapter_config: Optional[PromptAdapterConfig] = None, is_driver_worker: bool = False, ) -> None: assert device_config.device_type == "xpu" @@ -63,6 +65,7 @@ def __init__( self.rank = rank self.distributed_init_method = distributed_init_method self.lora_config = lora_config + self.prompt_adapter_config = prompt_adapter_config self.is_driver_worker = is_driver_worker if self.is_driver_worker: assert self.rank == 0, "The driver worker must have rank 0." From 673dd4cae9340e78dd5c05843e41c38133aa29a6 Mon Sep 17 00:00:00 2001 From: Murali Andoorveedu <37849411+andoorve@users.noreply.github.com> Date: Tue, 9 Jul 2024 16:24:58 -0700 Subject: [PATCH 03/26] [Docs] Docs update for Pipeline Parallel (#6222) Signed-off-by: Muralidhar Andoorveedu Co-authored-by: Simon Mo --- docs/source/serving/distributed_serving.rst | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/docs/source/serving/distributed_serving.rst b/docs/source/serving/distributed_serving.rst index 91f64ad2e9515..3c58ed295fba6 100644 --- a/docs/source/serving/distributed_serving.rst +++ b/docs/source/serving/distributed_serving.rst @@ -3,7 +3,7 @@ Distributed Inference and Serving ================================= -vLLM supports distributed tensor-parallel inference and serving. Currently, we support `Megatron-LM's tensor parallel algorithm `_. We manage the distributed runtime with either `Ray `_ or python native multiprocessing. Multiprocessing can be used when deploying on a single node, multi-node inferencing currently requires Ray. +vLLM supports distributed tensor-parallel inference and serving. Currently, we support `Megatron-LM's tensor parallel algorithm `_. 
We also support pipeline parallel as a beta feature for online serving. We manage the distributed runtime with either `Ray `_ or python native multiprocessing. Multiprocessing can be used when deploying on a single node, multi-node inferencing currently requires Ray. Multiprocessing will be used by default when not running in a Ray placement group and if there are sufficient GPUs available on the same node for the configured :code:`tensor_parallel_size`, otherwise Ray will be used. This default can be overridden via the :code:`LLM` class :code:`distributed-executor-backend` argument or :code:`--distributed-executor-backend` API server argument. Set it to :code:`mp` for multiprocessing or :code:`ray` for Ray. It's not required for Ray to be installed for the multiprocessing case. @@ -23,6 +23,19 @@ To run multi-GPU serving, pass in the :code:`--tensor-parallel-size` argument wh $ --model facebook/opt-13b \ $ --tensor-parallel-size 4 +You can also additionally specify :code:`--pipeline-parallel-size` to enable pipeline parallelism. For example, to run API server on 8 GPUs with pipeline parallelism and tensor parallelism: + +.. code-block:: console + + $ python -m vllm.entrypoints.openai.api_server \ + $ --model gpt2 \ + $ --tensor-parallel-size 4 \ + $ --pipeline-parallel-size 2 \ + $ --distributed-executor-backend ray + +.. note:: + Pipeline parallel is a beta feature. It is only supported for online serving and the ray backend for now, as well as LLaMa and GPT2 style models. + To scale vLLM beyond a single machine, install and start a `Ray runtime `_ via CLI before running vLLM: .. code-block:: console @@ -35,7 +48,7 @@ To scale vLLM beyond a single machine, install and start a `Ray runtime -After that, you can run inference and serving on multiple machines by launching the vLLM process on the head node by setting :code:`tensor_parallel_size` to the number of GPUs to be the total number of GPUs across all machines. +After that, you can run inference and serving on multiple machines by launching the vLLM process on the head node by setting :code:`tensor_parallel_size` multiplied by :code:`pipeline_parallel_size` to the number of GPUs to be the total number of GPUs across all machines. .. warning:: Please make sure you downloaded the model to all the nodes, or the model is downloaded to some distributed file system that is accessible by all nodes. From d3a245138acb358c7e1e5c5dcf4dcb3c2b48c8ff Mon Sep 17 00:00:00 2001 From: Baoyuan Qi Date: Wed, 10 Jul 2024 07:43:24 +0800 Subject: [PATCH 04/26] [Bugfix]fix and needs_scalar_to_array logic check (#6238) Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> --- vllm/model_executor/layers/linear.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 3cc257834033a..1dda5d3740a8b 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -387,7 +387,7 @@ def weight_loader(self, if loaded_shard_id is None: # Loaded weight is already fused on disk (qkv/mlp). if output_dim is None: - if needs_scalar_to_array is not None: + if needs_scalar_to_array: param_data, loaded_weight = adjust_scalar_to_fused_array( param_data, loaded_weight, 0) @@ -549,7 +549,7 @@ def weight_loader(self, if loaded_shard_id is None: # Loaded weight is already fused on disk (qkv/mlp). 
if output_dim is None: - if needs_scalar_to_array is not None: + if needs_scalar_to_array: param_data, loaded_weight = adjust_scalar_to_fused_array( param_data, loaded_weight, 0) From 2416b26e119b9d1932ba30790ecaddfac1ae4143 Mon Sep 17 00:00:00 2001 From: Abhinav Goyal Date: Wed, 10 Jul 2024 07:04:02 +0530 Subject: [PATCH 05/26] [Speculative Decoding] Medusa Implementation with Top-1 proposer (#4978) --- .../e2e/test_medusa_correctness.py | 226 ++++++++++++++++++ vllm/model_executor/models/__init__.py | 1 + vllm/model_executor/models/medusa.py | 159 ++++++++++++ vllm/spec_decode/medusa_worker.py | 127 ++++++++++ vllm/spec_decode/spec_decode_worker.py | 5 + vllm/transformers_utils/config.py | 6 +- vllm/transformers_utils/configs/__init__.py | 2 + vllm/transformers_utils/configs/medusa.py | 60 +++++ vllm/worker/worker.py | 5 +- 9 files changed, 587 insertions(+), 4 deletions(-) create mode 100644 tests/spec_decode/e2e/test_medusa_correctness.py create mode 100644 vllm/model_executor/models/medusa.py create mode 100644 vllm/spec_decode/medusa_worker.py create mode 100644 vllm/transformers_utils/configs/medusa.py diff --git a/tests/spec_decode/e2e/test_medusa_correctness.py b/tests/spec_decode/e2e/test_medusa_correctness.py new file mode 100644 index 0000000000000..7e4a6cc62d02b --- /dev/null +++ b/tests/spec_decode/e2e/test_medusa_correctness.py @@ -0,0 +1,226 @@ +"""This docstring details important information on the testing methodology. + +Most of the tests rely on "greedy equality", where we expect the output of +speculative decoding on a sequence to exactly match the output of normal non- +speculative decoding. + +Since speculative decoding with rejection sampling guarantees that the output +distribution matches the target model's output distribution (up to hardware +numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy +equality. + +However, we still need to verify below scenario could be passed: + * Batch size 1 greedy equality + * Batch size >1 greedy equality + * Test greedy equality under preemption + * Test greedy equality under various number of speculative tokens. + +With those tests, we can say at least, Medusa would not break the +correctess for the target model outputs. +""" + +import pytest + +from .conftest import run_greedy_equality_correctness_test + +# main model +# lmsys/vicuna-7b-v1.3 was to be used but it's causing +# OOM in CI pipeline, so using a smaller model. +MAIN_MODEL = "JackFram/llama-68m" + +# speculative model +SPEC_MODEL = "abhigoyal/vllm-medusa-llama-68m-random" + +# max. number of speculative tokens: this corresponds to +# num_heads in the config.json of the speculator model. +MAX_SPEC_TOKENS = 5 + +# precision +PRECISION = "float32" + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + + # Print spec metrics. 
+ "disable_log_stats": False, + + # Precision + "dtype": PRECISION, + + # Main model + "model": MAIN_MODEL, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + }, +]) +@pytest.mark.parametrize("output_len", [ + 128, +]) +@pytest.mark.parametrize("batch_size", [1, 32]) +@pytest.mark.parametrize("seed", [1]) +def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator, + batch_size: int, output_len: int): + """Verify greedy equality with different batch size.""" + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "block_size": 8, + # 2 for small prompt, 256//8 for generated. + "num_gpu_blocks_override": 2 + 256 // 8, + "max_model_len": (2 + 256 // 8) * 8, + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + + # Precision + "dtype": PRECISION, + + # Main model + "model": MAIN_MODEL, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + }, +]) +@pytest.mark.parametrize( + "output_len", + [ + # Use small output len for fast test. + 128, + ]) +@pytest.mark.parametrize("batch_size", [4]) +@pytest.mark.parametrize("seed", [1]) +def test_mlp_e2e_greedy_correctness_with_preemption(baseline_llm_generator, + test_llm_generator, + batch_size: int, + output_len: int): + """Verify greedy equality, even when some sequences are preempted mid- + generation. + """ + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + + # Precision + "dtype": PRECISION, + + # Main model + "model": MAIN_MODEL, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize( + "test_llm_kwargs", + [ + { + "speculative_model": SPEC_MODEL, + "num_speculative_tokens": k, + } + # Try a range of num. speculative tokens + for k in range(1, 1 + MAX_SPEC_TOKENS) + ]) +@pytest.mark.parametrize("batch_size", [2]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 32, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_mlp_different_k(baseline_llm_generator, test_llm_generator, + batch_size: int, output_len: int): + """Verify that mlp speculative decoding produces exact equality + to without spec decode with different values of num_speculative_tokens. + """ + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. 
+ "use_v2_block_manager": True, + + # Precision + "dtype": PRECISION, + + # Main model + "model": MAIN_MODEL, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", + [{ + "speculative_model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + "speculative_disable_by_batch_size": 4 + }]) +@pytest.mark.parametrize("batch_size", [1, 5]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 32, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_mlp_disable_queue(baseline_llm_generator, test_llm_generator, + batch_size: int, output_len: int): + """Verify that mlp speculative decoding produces exact equality + to without spec decode when speculation is disabled for large + batch sizes. + """ + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) + + +if __name__ == "__main__": + import pytest + pytest.main([__file__]) diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 644b95aae3656..096e3f4724014 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -64,6 +64,7 @@ "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"), "XverseForCausalLM": ("xverse", "XverseForCausalLM"), "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"), + "MedusaModel": ("medusa", "Medusa"), "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"), "JambaForCausalLM": ("jamba", "JambaForCausalLM") } diff --git a/vllm/model_executor/models/medusa.py b/vllm/model_executor/models/medusa.py new file mode 100644 index 0000000000000..6453d0cb25c91 --- /dev/null +++ b/vllm/model_executor/models/medusa.py @@ -0,0 +1,159 @@ +from typing import Iterable, List, Optional, Tuple + +import torch +import torch.nn as nn + +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import SamplerOutput +from vllm.transformers_utils.configs.medusa import MedusaConfig + + +class ResidualBlock(nn.Module): + + def __init__(self, hidden_size: int, num_layers: int) -> None: + super().__init__() + + self.layers = nn.ModuleList([ + nn.Linear(hidden_size, hidden_size, bias=False) + for _ in range(num_layers) + ]) + self.act = nn.SiLU() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for layer in self.layers: + x = x + self.act(layer(x)) + return x + + +class Medusa(nn.Module): + + def __init__(self, config: MedusaConfig, **_) -> None: + super().__init__() + self.config = config + self.blocks = nn.ModuleList([ + ResidualBlock(hidden_size=self.config.hidden_size, + num_layers=self.config.num_hidden_layers) + for _ in range(self.config.num_heads) + ]) + self.orig_vocab_size = config.vocab_size + self.truncated_vocab_size = config.truncated_vocab_size + self.unpadded_vocab_size = self.truncated_vocab_size + + self.lm_heads = nn.ModuleList([ + ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=self.truncated_vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE, + ) for _ in range(self.config.num_heads) + ]) + + logit_scale = getattr(config, 
"logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + self.truncated_vocab_size, + logit_scale) + + self.token_map = None + + def forward(self, hidden_states: torch.Tensor) -> List[torch.Tensor]: + return [block(hidden_states) for block in self.blocks] + + def compute_logits( + self, hidden_states: List[torch.Tensor], + sampling_metadata: SamplingMetadata) -> List[torch.Tensor]: + logits = [] + + for hs, lm_head in zip(hidden_states, self.lm_heads): + _logits = self.logits_processor(lm_head, hs, sampling_metadata) + + if self.token_map is None: + logits.append(_logits) + else: + logits.append(-torch.inf * torch.ones( + size=(*_logits.shape[:-1], self.orig_vocab_size), + device=_logits.device, + dtype=_logits.dtype)) + + logits[-1][..., self.token_map] = _logits + + return logits + + def sample( + self, + logits: List[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> List[SamplerOutput]: + logits = torch.stack(logits, dim=0).float() + logprobs = torch.log_softmax(logits, dim=-1) + token_ids = logits.argmax(-1) # support only top-1 for now + probs = torch.softmax(logits, dim=-1) + + token_id_list = [] + token_prob_list = [] + token_logprob_list = [] + + for idx, seq_group in enumerate(sampling_metadata.seq_groups): + token_id_list.append(token_ids[:, seq_group.sample_indices]) + token_prob_list.append(probs[:, seq_group.sample_indices]) + token_logprob_list.append(logprobs[:, seq_group.sample_indices]) + + outputs: List[Optional[SamplerOutput]] = [] + for idx in range(len(sampling_metadata.seq_groups)): + outputs.append( + SamplerOutput( + outputs=None, + sampled_token_probs=token_prob_list[idx].squeeze(1), + logprobs=token_logprob_list[idx].squeeze(1), + sampled_token_ids=token_id_list[idx].squeeze(1), + )) + + return outputs + + def generate_proposals( + self, + previous_hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> List[SamplerOutput]: + return self.sample( + logits=self.compute_logits( + hidden_states=self.forward(previous_hidden_states), + sampling_metadata=sampling_metadata, + ), + sampling_metadata=sampling_metadata, + ) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + params_dict = dict(self.named_parameters()) + + weights_map = {} + + for name, loaded_weight in weights: + name = name.replace("medusa_heads.", "") + + if name == "token_map": + if self.truncated_vocab_size < self.orig_vocab_size: + self.token_map = nn.Parameter(loaded_weight, + requires_grad=False) + elif name in params_dict: + weights_map[name] = loaded_weight + + for name, loaded_weight in weights_map.items(): + if "lm_head" in name and self.token_map is not None and\ + loaded_weight.shape[0] > self.token_map.shape[0]: + + loaded_weight = loaded_weight[self.token_map] + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + if self.token_map is not None: + self.token_map.to(device=self.lm_heads[0].weight.device) + + assert (self.truncated_vocab_size + == self.orig_vocab_size) or (self.token_map is not None) diff --git a/vllm/spec_decode/medusa_worker.py b/vllm/spec_decode/medusa_worker.py new file mode 100644 index 0000000000000..b72740fc3961c --- /dev/null +++ b/vllm/spec_decode/medusa_worker.py @@ -0,0 +1,127 @@ +import weakref +from typing import List, Optional, Tuple + +import torch + +from vllm.model_executor import SamplingMetadata +from vllm.sequence import (ExecuteModelRequest, SamplerOutput, + SequenceGroupMetadata) 
+from vllm.spec_decode.interfaces import SpeculativeProposals +from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase +from vllm.spec_decode.top1_proposer import Top1Proposer +from vllm.worker.worker import Worker + + +class MedusaWorker(NonLLMProposerWorkerBase, Worker): + """Worker for Medusa. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # Lazy initialization list. + self._proposer: Top1Proposer + + def init_device(self): + super().init_device() + + self._proposer = Top1Proposer( + weakref.proxy(self), # type: ignore[arg-type] + self.device, + self.vocab_size, + max_proposal_len=self.max_model_len, + ) + + def set_include_gpu_probs_tensor(self): + pass + + @torch.inference_mode() + def sampler_output( + self, + execute_model_req: ExecuteModelRequest, + sample_len: int, + ) -> Tuple[List[SamplerOutput], bool]: + """Run the model forward pass to generate sample_len future tokens. + Returns the list of sampler output, one per layer, along with indicator + of whether torch tensor in sampler output need to be transposed in + latter sampler_output_to_torch logic. + + For medusa worker, this indicator shall be False. + """ + self._raise_if_unsupported(execute_model_req) + + seq_group_metadata_list = execute_model_req.seq_group_metadata_list + + seq_lens, query_lens = self._prepare_input_tensors( + seq_group_metadata_list) + + sampling_metadata = SamplingMetadata.prepare( + seq_group_metadata_list, seq_lens, query_lens, self.device, + self.model_runner.pin_memory) + + model_outputs = self.model_runner.model.generate_proposals( + previous_hidden_states=execute_model_req.previous_hidden_states. + hidden_states, + sampling_metadata=sampling_metadata) + + return model_outputs, False + + def _prepare_input_tensors( + self, + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + ) -> Tuple[List[int], List[int]]: + if not seq_group_metadata_list: + return [], [] + + seq_lens: List[int] = [] + query_lens: List[int] = [] + + for seq_group_metadata in seq_group_metadata_list: + is_prompt = seq_group_metadata.is_prompt + + for seq_data in seq_group_metadata.seq_data.values(): + seq_data_len = seq_data.get_len() + if is_prompt: + context_len = seq_data.get_num_computed_tokens() + seq_len = min( + seq_data_len, + context_len + seq_group_metadata.token_chunk_size) + seq_lens.append(seq_len) + query_lens.append(seq_len - context_len) + else: + seq_lens.append(seq_data_len) + query_lens.append(1) + + return seq_lens, query_lens + + def get_spec_proposals( + self, + execute_model_req: ExecuteModelRequest, + ) -> SpeculativeProposals: + """Produce speculations given an input batch of sequences. The number of + speculative tokens per sequence is determined by max_proposal_len. + """ + + return self._proposer.get_spec_proposals(execute_model_req) + + def _raise_if_unsupported( + self, + execute_model_req: ExecuteModelRequest, + ) -> None: + """MedusaWorker does not yet implement support for cache swap + operations or beam search. 
+ """ + if any([ + execute_model_req.blocks_to_swap_in, + execute_model_req.blocks_to_swap_out, + execute_model_req.blocks_to_copy + ]): + raise NotImplementedError( + "MedusaWorker does not support cache operations") + + if any( + len(seq_group_metadata.seq_data.keys()) != 1 + for seq_group_metadata in + execute_model_req.seq_group_metadata_list): + raise NotImplementedError( + "MedusaWorker does not support beam search.") diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 43ce987de1e16..60a7dab68b7fd 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -18,6 +18,7 @@ from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner from vllm.spec_decode.interfaces import (SpeculativeProposals, SpeculativeScorer, SpeculativeScores) +from vllm.spec_decode.medusa_worker import MedusaWorker from vllm.spec_decode.metrics import AsyncMetricsCollector from vllm.spec_decode.mlp_speculator_worker import MLPSpeculatorWorker from vllm.spec_decode.multi_step_worker import MultiStepWorker @@ -129,6 +130,10 @@ def create_worker( "model_config"].hf_config.model_type == "mlp_speculator": disable_bonus_tokens = False proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs) + elif draft_worker_kwargs[ + "model_config"].hf_config.model_type == "medusa": + disable_bonus_tokens = False + proposer_worker = MedusaWorker(**draft_worker_kwargs) else: if draft_tp == 1: draft_worker_kwargs[ diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 5e2fe116db9c6..652505a892142 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -6,8 +6,9 @@ from vllm.envs import VLLM_USE_MODELSCOPE from vllm.logger import init_logger from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig, - JAISConfig, MLPSpeculatorConfig, - MPTConfig, RWConfig) + JAISConfig, MedusaConfig, + MLPSpeculatorConfig, MPTConfig, + RWConfig) if VLLM_USE_MODELSCOPE: from modelscope import AutoConfig @@ -24,6 +25,7 @@ "RefinedWebModel": RWConfig, # For tiiuae/falcon-7b(-instruct) "jais": JAISConfig, "mlp_speculator": MLPSpeculatorConfig, + "medusa": MedusaConfig, } for name, cls in _CONFIG_REGISTRY.items(): diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index d8170858c2a9a..51de11ca3e42a 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -5,6 +5,7 @@ # `FalconConfig` class from the official HuggingFace transformers library. 
from vllm.transformers_utils.configs.falcon import RWConfig from vllm.transformers_utils.configs.jais import JAISConfig +from vllm.transformers_utils.configs.medusa import MedusaConfig from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig from vllm.transformers_utils.configs.mpt import MPTConfig @@ -14,5 +15,6 @@ "MPTConfig", "RWConfig", "JAISConfig", + "MedusaConfig", "MLPSpeculatorConfig", ] diff --git a/vllm/transformers_utils/configs/medusa.py b/vllm/transformers_utils/configs/medusa.py new file mode 100644 index 0000000000000..d71a08343be2a --- /dev/null +++ b/vllm/transformers_utils/configs/medusa.py @@ -0,0 +1,60 @@ +import os +from typing import Optional, Union + +from transformers import PretrainedConfig + + +class MedusaConfig(PretrainedConfig): + model_type = "medusa" + + def __init__(self, + hidden_size: int = 4096, + vocab_size: int = 32001, + num_heads: int = 5, + num_hidden_layers: int = 1, + max_paths: int = 64, + topk: int = 10, + truncated_vocab_size: Optional[int] = None, + **kwargs): + + self.hidden_size = hidden_size + self.vocab_size = vocab_size + self.num_heads = num_heads + self.num_hidden_layers = num_hidden_layers + self.max_paths = max_paths + self.topk = topk + self.max_seq_len = int(2**20) + self.truncated_vocab_size = vocab_size if truncated_vocab_size is None\ + else truncated_vocab_size + if "architectures" not in kwargs: + kwargs["architectures"] = ["MedusaModel"] + + super().__init__(**kwargs) + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Union[str, os.PathLike], + **kwargs, + ) -> "MedusaConfig": + config_dict, kwargs = cls.get_config_dict( + pretrained_model_name_or_path, **kwargs) + for k in list(config_dict.keys()): + if 'num' in k: + if 'heads' in k: + config_dict["num_heads"] = config_dict.pop(k) + elif 'layers' in k: + config_dict["num_hidden_layers"] = config_dict.pop(k) + return cls.from_dict(config_dict, **kwargs) + + @property + def num_attention_heads(self): + return 0 + + @property + def num_lookahead_tokens(self): + return self.num_heads + + @num_lookahead_tokens.setter + def num_lookahead_tokens(self, num_lookahead_tokens: int): + self.num_heads = num_lookahead_tokens diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 857cd86beff92..56d8587f8f010 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -78,8 +78,9 @@ def __init__( speculative_args = {} if speculative_config is None \ or (speculative_config.draft_model_config.model == model_config.model) \ - or (speculative_config.draft_model_config.hf_config.model_type != - "mlp_speculator") else {"return_hidden_states": True} + or (speculative_config.draft_model_config.hf_config.model_type + not in ["medusa", "mlp_speculator"]) \ + else {"return_hidden_states": True} ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner if model_runner_cls is not None: From da78caecfa7f6137efc3e08388f4db102650ac45 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 9 Jul 2024 18:49:11 -0700 Subject: [PATCH 06/26] [core][distributed] zmq fallback for broadcasting large objects (#6183) [core][distributed] add zmq fallback for broadcasting large objects (#6183) --- requirements-common.txt | 1 + tests/distributed/test_same_node.py | 5 +- tests/distributed/test_shm_broadcast.py | 17 +- .../device_communicators/custom_all_reduce.py | 4 +- .../device_communicators/shm_broadcast.py | 269 +++++++++++++++--- vllm/distributed/parallel_state.py | 58 ++-- 6 files changed, 274 insertions(+), 80 deletions(-) diff --git 
a/requirements-common.txt b/requirements-common.txt index 765568b033bfb..e874c4af49d66 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -21,3 +21,4 @@ lm-format-enforcer == 0.10.1 outlines >= 0.0.43 # Requires torch >= 2.1.0 typing_extensions filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4 +pyzmq diff --git a/tests/distributed/test_same_node.py b/tests/distributed/test_same_node.py index 4880bab79069c..2d886eb566d5d 100644 --- a/tests/distributed/test_same_node.py +++ b/tests/distributed/test_same_node.py @@ -2,10 +2,11 @@ import torch -from vllm.distributed.parallel_state import is_in_the_same_node +from vllm.distributed.parallel_state import in_the_same_node_as torch.distributed.init_process_group(backend="gloo") -test_result = is_in_the_same_node(torch.distributed.group.WORLD) +test_result = all( + in_the_same_node_as(torch.distributed.group.WORLD, source_rank=0)) expected = os.environ.get("VLLM_TEST_SAME_HOST", "1") == "1" assert test_result == expected, f"Expected {expected}, got {test_result}" diff --git a/tests/distributed/test_shm_broadcast.py b/tests/distributed/test_shm_broadcast.py index 2c2466f81bb8a..2761b7f6c0644 100644 --- a/tests/distributed/test_shm_broadcast.py +++ b/tests/distributed/test_shm_broadcast.py @@ -6,8 +6,7 @@ import numpy as np import torch.distributed as dist -from vllm.distributed.device_communicators.shm_broadcast import ( - ShmRingBuffer, ShmRingBufferIO) +from vllm.distributed.device_communicators.shm_broadcast import MessageQueue from vllm.utils import update_environment_variables @@ -56,8 +55,8 @@ def wrapped_fn(env): @worker_fn_wrapper def worker_fn(): writer_rank = 2 - broadcaster = ShmRingBufferIO.create_from_process_group( - dist.group.WORLD, 1024 * 1024, 2, writer_rank) + broadcaster = MessageQueue.create_from_process_group( + dist.group.WORLD, 40 * 1024, 2, writer_rank) if dist.get_rank() == writer_rank: seed = random.randint(0, 1000) dist.broadcast_object_list([seed], writer_rank) @@ -87,13 +86,3 @@ def worker_fn(): def test_shm_broadcast(): distributed_run(worker_fn, 4) - - -def test_singe_process(): - buffer = ShmRingBuffer(1, 1024, 4) - reader = ShmRingBufferIO(buffer, reader_rank=0) - writer = ShmRingBufferIO(buffer, reader_rank=-1) - writer.enqueue([0]) - writer.enqueue([1]) - assert reader.dequeue() == [0] - assert reader.dequeue() == [1] diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index a303d0bd25a3c..a4f30808d32e1 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -9,7 +9,7 @@ from vllm import _custom_ops as ops from vllm.distributed.device_communicators.custom_all_reduce_utils import ( gpu_p2p_access_check) -from vllm.distributed.parallel_state import is_in_the_same_node +from vllm.distributed.parallel_state import in_the_same_node_as from vllm.logger import init_logger from vllm.utils import cuda_device_count_stateless, is_full_nvlink @@ -64,7 +64,7 @@ def __init__(self, assert dist.get_backend(group) != dist.Backend.NCCL, ( "CustomAllreduce should be attached to a non-NCCL group.") - if not is_in_the_same_node(group): + if not all(in_the_same_node_as(group, source_rank=0)): # No need to initialize custom allreduce for multi-node case. 
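The old is_in_the_same_node boolean is replaced throughout by in_the_same_node_as, which returns one flag per rank, so each caller reduces the list as it needs. A small usage sketch of the two patterns used in this patch (it assumes a non-NCCL process group is already initialized, as the helper requires):

    import torch.distributed as dist

    from vllm.distributed.parallel_state import in_the_same_node_as

    group = dist.group.WORLD

    # Pattern 1: is everyone on one host? (used by CustomAllreduce)
    single_host = all(in_the_same_node_as(group, source_rank=0))

    # Pattern 2: which ranks share the writer's host? (used by MessageQueue
    # to split shared-memory readers from remote readers)
    status = in_the_same_node_as(group, source_rank=0)
    local_ranks = [rank for rank, same in enumerate(status) if same]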
logger.warning( "Custom allreduce is disabled because this process group" diff --git a/vllm/distributed/device_communicators/shm_broadcast.py b/vllm/distributed/device_communicators/shm_broadcast.py index bea205882d9d8..db0064951cd1b 100644 --- a/vllm/distributed/device_communicators/shm_broadcast.py +++ b/vllm/distributed/device_communicators/shm_broadcast.py @@ -1,16 +1,19 @@ import pickle import time from contextlib import contextmanager +from dataclasses import dataclass, field from multiprocessing import shared_memory -from typing import Optional +from typing import List, Optional from unittest.mock import patch import torch import torch.distributed as dist from torch.distributed import ProcessGroup +from zmq import PUB, REP, REQ, SUB, SUBSCRIBE, Context # type: ignore import vllm.envs as envs from vllm.logger import init_logger +from vllm.utils import get_ip, get_open_port VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL @@ -135,18 +138,183 @@ def get_metadata(self, current_idx: int): yield buf -class ShmRingBufferIO: +@dataclass +class Handle: + connect_ip: str + local_reader_ranks: List[int] = field(default_factory=list) - def __init__(self, buffer: ShmRingBuffer, reader_rank: int): - self.buffer = buffer - self.reader_rank = reader_rank - self._is_writer = self.reader_rank == -1 - self._is_reader = not self._is_writer - if self._is_reader: - assert 0 <= self.reader_rank < buffer.n_reader, \ - (f"Invalid reader rank {self.reader_rank} for buffer" - f" created with {buffer.n_reader} readers") - self.current_idx = 0 + buffer: Optional[ShmRingBuffer] = None + local_subscribe_port: Optional[int] = None + local_sync_port: Optional[int] = None + remote_subscribe_port: Optional[int] = None + remote_sync_port: Optional[int] = None + + +class MessageQueue: + + def __init__( + self, + n_reader, # number of all readers + n_local_reader, # number of local readers through shared memory + local_reader_ranks: Optional[List[int]] = None, + max_chunk_bytes: int = 1024 * 1024 * 10, + max_chunks: int = 10, + connect_ip: Optional[str] = None, + ): + if local_reader_ranks is None: + local_reader_ranks = list(range(n_local_reader)) + else: + assert len(local_reader_ranks) == n_local_reader + self.n_local_reader = n_local_reader + n_remote_reader = n_reader - n_local_reader + self.n_remote_reader = n_remote_reader + + if connect_ip is None: + connect_ip = get_ip() + + context = Context() + + if n_local_reader > 0: + # for local readers, we will: + # 1. create a shared memory ring buffer to communicate small data + # 2. 
create a publish-subscribe socket to communicate large data + self.buffer = ShmRingBuffer(n_local_reader, max_chunk_bytes, + max_chunks) + + self.local_socket = context.socket(PUB) + local_subscribe_port = get_open_port() + self.local_socket.bind(f"tcp://*:{local_subscribe_port}") + + self.local_sync_socket = context.socket(REP) + local_sync_port = get_open_port() + self.local_sync_socket.bind(f"tcp://*:{local_sync_port}") + self.current_idx = 0 + + else: + self.buffer = None # type: ignore + local_subscribe_port = None + local_sync_port = None + self.local_socket = None + self.local_sync_socket = None + self.current_idx = -1 + + if n_remote_reader > 0: + # for remote readers, we will: + # create a publish-subscribe socket to communicate large data + self.remote_socket = context.socket(PUB) + remote_subscribe_port = get_open_port() + self.remote_socket.bind(f"tcp://*:{remote_subscribe_port}") + + self.remote_sync_socket = context.socket(REP) + remote_sync_port = get_open_port() + self.remote_sync_socket.bind(f"tcp://*:{remote_sync_port}") + else: + remote_subscribe_port = None + remote_sync_port = None + self.remote_socket = None + self.remote_sync_socket = None + + self._is_writer = True + self._is_local_reader = False + self.local_reader_rank = -1 + # rank does not matter for remote readers + self._is_remote_reader = False + + self.handle = Handle( + connect_ip=connect_ip, + local_reader_ranks=local_reader_ranks, + buffer=self.buffer, + local_subscribe_port=local_subscribe_port, + local_sync_port=local_sync_port, + remote_subscribe_port=remote_subscribe_port, + remote_sync_port=remote_sync_port, + ) + + def export_handle(self) -> Handle: + return self.handle + + @staticmethod + def create_from_handle(handle: Handle, rank) -> "MessageQueue": + self = MessageQueue.__new__(MessageQueue) + self.handle = handle + self._is_writer = False + + context = Context() + + if rank in handle.local_reader_ranks: + assert handle.buffer is not None + self.buffer = handle.buffer + self.current_idx = 0 + self.local_reader_rank = handle.local_reader_ranks.index(rank) + self._is_local_reader = True + self._is_remote_reader = False + + self.local_socket = context.socket(SUB) + self.local_socket.setsockopt_string(SUBSCRIBE, "") + self.local_socket.connect( + f"tcp://{handle.connect_ip}:{handle.local_subscribe_port}") + + self.local_sync_socket = context.socket(REQ) + self.local_sync_socket.connect( + f"tcp://{handle.connect_ip}:{handle.local_sync_port}") + + self.remote_socket = None + self.remote_sync_socket = None + else: + self.buffer = None # type: ignore + self.current_idx = -1 + self.local_reader_rank = -1 + self._is_local_reader = False + self._is_remote_reader = True + + self.local_socket = None + self.local_sync_socket = None + + self.remote_socket = context.socket(SUB) + self.remote_socket.setsockopt_string(SUBSCRIBE, "") + self.remote_socket.connect( + f"tcp://{handle.connect_ip}:{handle.remote_subscribe_port}") + + self.remote_sync_socket = context.socket(REQ) + self.remote_sync_socket.connect( + f"tcp://{handle.connect_ip}:{handle.remote_sync_port}") + + return self + + def wait_until_ready(self): + """This is a collective operation. All processes (including the + readers and the writer) should call this function. 
+ """ + if self._is_writer: + # wait for all readers to connect + + # local readers + for i in range(self.n_local_reader): + recv = self.local_sync_socket.recv() + assert recv == b"READY" + self.local_sync_socket.send(b"READY") + if self.n_local_reader > 0: + self.local_socket.send(b"READY") + + # remote readers + for i in range(self.n_remote_reader): + recv = self.remote_sync_socket.recv() + assert recv == b"READY" + self.remote_sync_socket.send(b"READY") + if self.n_remote_reader > 0: + self.remote_socket.send(b"READY") + elif self._is_local_reader: + self.local_sync_socket.send(b"READY") + recv = self.local_sync_socket.recv() + assert recv == b"READY" + recv = self.local_socket.recv() + assert recv == b"READY" + elif self._is_remote_reader: + self.remote_sync_socket.send(b"READY") + recv = self.remote_sync_socket.recv() + assert recv == b"READY" + recv = self.remote_socket.recv() + assert recv == b"READY" @contextmanager def acquire_write(self): @@ -201,12 +369,12 @@ def acquire_write(self): @contextmanager def acquire_read(self): - assert self._is_reader, "Only readers can acquire read" + assert self._is_local_reader, "Only readers can acquire read" start_time = time.monotonic() n_warning = 1 while True: with self.buffer.get_metadata(self.current_idx) as metadata_buffer: - read_flag = metadata_buffer[self.reader_rank + 1] + read_flag = metadata_buffer[self.local_reader_rank + 1] written_flag = metadata_buffer[0] if not written_flag or read_flag: # this block is either @@ -236,7 +404,7 @@ def acquire_read(self): # caller has read from the buffer # set the read flag - metadata_buffer[self.reader_rank + 1] = 1 + metadata_buffer[self.local_reader_rank + 1] = 1 self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks break @@ -244,21 +412,36 @@ def acquire_read(self): def enqueue(self, obj): assert self._is_writer, "Only writers can enqueue" serialized_obj = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL) - if len(serialized_obj) > self.buffer.max_chunk_bytes: - raise RuntimeError( - f"{len(serialized_obj)=} larger than the allowed value " - f"{self.buffer.max_chunk_bytes}," - "Please increase the max_chunk_bytes parameter.") - with self.acquire_write() as buf: - buf[:len(serialized_obj)] = serialized_obj + if self.n_local_reader > 0: + if len(serialized_obj) >= self.buffer.max_chunk_bytes: + with self.acquire_write() as buf: + buf[0] = 1 # overflow + self.local_socket.send(serialized_obj) + else: + with self.acquire_write() as buf: + buf[0] = 0 # not overflow + buf[1:len(serialized_obj) + 1] = serialized_obj + if self.n_remote_reader > 0: + self.remote_socket.send(serialized_obj) def dequeue(self): - assert self._is_reader, "Only readers can dequeue" - with self.acquire_read() as buf: - # no need to know the size of serialized object - # pickle format itself contains the size information internally - # see https://docs.python.org/3/library/pickle.html - obj = pickle.loads(buf) + if self._is_local_reader: + overflow = False + with self.acquire_read() as buf: + overflow = buf[0] == 1 + if not overflow: + # no need to know the size of serialized object + # pickle format contains the size information internally + # see https://docs.python.org/3/library/pickle.html + obj = pickle.loads(buf[1:]) + if overflow: + recv = self.local_socket.recv() + obj = pickle.loads(recv) + elif self._is_remote_reader: + recv = self.remote_socket.recv() + obj = pickle.loads(recv) + else: + raise RuntimeError("Only readers can dequeue") return obj def broadcast_object(self, obj=None): @@ -272,24 
+455,36 @@ def broadcast_object(self, obj=None): def create_from_process_group(pg: ProcessGroup, max_chunk_bytes, max_chunks, - writer_rank=0) -> "ShmRingBufferIO": + writer_rank=0) -> "MessageQueue": group_rank = dist.get_rank(pg) group_world_size = dist.get_world_size(pg) - ranks_inside_group = list(range(group_world_size)) global_ranks = dist.get_process_group_ranks(pg) + + from vllm.distributed.parallel_state import in_the_same_node_as + status = in_the_same_node_as(pg, source_rank=writer_rank) + same_node_ranks = [i for i, s in enumerate(status) if s] n_reader = group_world_size - 1 - buffer: ShmRingBuffer + n_local_reader = len(same_node_ranks) - 1 + local_reader_ranks = [i for i in same_node_ranks if i != writer_rank] + buffer_io: MessageQueue if group_rank == writer_rank: - buffer = ShmRingBuffer(n_reader, max_chunk_bytes, max_chunks) - dist.broadcast_object_list([buffer], + buffer_io = MessageQueue( + n_reader=n_reader, + n_local_reader=n_local_reader, + local_reader_ranks=local_reader_ranks, + max_chunk_bytes=max_chunk_bytes, + max_chunks=max_chunks, + ) + handle = buffer_io.export_handle() + dist.broadcast_object_list([handle], src=global_ranks[writer_rank], group=pg) - return ShmRingBufferIO(buffer, -1) else: recv = [None] dist.broadcast_object_list(recv, src=global_ranks[writer_rank], group=pg) - buffer = recv[0] # type: ignore - rest_ranks = [r for r in ranks_inside_group if r != writer_rank] - return ShmRingBufferIO(buffer, rest_ranks.index(group_rank)) + handle = recv[0] # type: ignore + buffer_io = MessageQueue.create_from_handle(handle, group_rank) + buffer_io.wait_until_ready() + return buffer_io diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 66ffe6e8a9fa9..128096c88a8b1 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -124,7 +124,7 @@ class GroupCoordinator: # communicators are only created for world size > 1 pynccl_comm: Optional[Any] # PyNccl communicator ca_comm: Optional[Any] # Custom allreduce communicator - shm_broadcaster: Optional[Any] # shared memory broadcaster + mq_broadcaster: Optional[Any] # shared memory broadcaster def __init__( self, @@ -133,6 +133,7 @@ def __init__( torch_distributed_backend: Union[str, Backend], use_pynccl: bool, use_custom_allreduce: bool, + use_message_queue_broadcaster: bool = False, ): self.rank = torch.distributed.get_rank() @@ -190,10 +191,10 @@ def __init__( self.ca_comm = None from vllm.distributed.device_communicators.shm_broadcast import ( - ShmRingBufferIO) - self.shm_broadcaster: Optional[ShmRingBufferIO] = None - if self.world_size > 1 and is_in_the_same_node(self.cpu_group): - self.shm_broadcaster = ShmRingBufferIO.create_from_process_group( + MessageQueue) + self.mq_broadcaster: Optional[MessageQueue] = None + if use_message_queue_broadcaster and self.world_size > 1: + self.mq_broadcaster = MessageQueue.create_from_process_group( self.cpu_group, 1 << 22, 6) @property @@ -377,9 +378,9 @@ def broadcast_object(self, obj: Optional[Any] = None, src: int = 0): # Bypass the function if we are using only 1 GPU. 
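The enqueue/dequeue pair above reserves the first byte of each shared-memory chunk as an overflow flag: small objects are inlined after the flag, while anything at or over max_chunk_bytes is shipped over the ZeroMQ PUB socket instead. A schematic round-trip of that protocol (function names hypothetical; buf stands in for the chunk handed out by acquire_write()/acquire_read(), socket for the PUB/SUB pair):

    import pickle

    MAX_CHUNK_BYTES = 10 * 1024 * 1024  # constructor default above

    def write_chunk(buf, socket, obj) -> None:
        data = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
        if len(data) >= MAX_CHUNK_BYTES:
            buf[0] = 1                      # overflow: payload goes via zmq
            socket.send(data)
        else:
            buf[0] = 0                      # inline: payload fits in the chunk
            buf[1:len(data) + 1] = data

    def read_chunk(buf, socket):
        if buf[0] == 1:
            return pickle.loads(socket.recv())
        # pickle stops at its own STOP opcode, so trailing bytes are ignored.
        return pickle.loads(buf[1:])

The comparison is >= rather than > because the flag byte occupies index 0, so a payload of exactly max_chunk_bytes can no longer fit inline.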
if self.world_size == 1: return obj - if self.shm_broadcaster is not None: - assert src == 0, "Shared memory broadcaster only supports src=0" - return self.shm_broadcaster.broadcast_object(obj) + if self.mq_broadcaster is not None: + assert src == 0, "Message queue broadcaster only supports src=0" + return self.mq_broadcaster.broadcast_object(obj) if self.rank_in_group == src: torch.distributed.broadcast_object_list([obj], src=self.ranks[src], @@ -696,8 +697,8 @@ def destroy(self): self.pynccl_comm = None if self.ca_comm is not None: self.ca_comm = None - if self.shm_broadcaster is not None: - self.shm_broadcaster = None + if self.mq_broadcaster is not None: + self.mq_broadcaster = None _WORLD: Optional[GroupCoordinator] = None @@ -720,10 +721,12 @@ def init_world_group(ranks: List[int], local_rank: int, def init_model_parallel_group( - group_ranks: List[List[int]], - local_rank: int, - backend: str, - use_custom_allreduce: Optional[bool] = None) -> GroupCoordinator: + group_ranks: List[List[int]], + local_rank: int, + backend: str, + use_custom_allreduce: Optional[bool] = None, + use_message_queue_broadcaster: bool = False, +) -> GroupCoordinator: if use_custom_allreduce is None: use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE return GroupCoordinator( @@ -732,6 +735,7 @@ def init_model_parallel_group( torch_distributed_backend=backend, use_pynccl=True, use_custom_allreduce=use_custom_allreduce, + use_message_queue_broadcaster=use_message_queue_broadcaster, ) @@ -880,8 +884,12 @@ def initialize_model_parallel( range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)) group_ranks.append(ranks) + + # message queue broadcaster is only used in tensor model parallel group _TP = init_model_parallel_group(group_ranks, - get_world_group().local_rank, backend) + get_world_group().local_rank, + backend, + use_message_queue_broadcaster=True) # Build the pipeline model-parallel groups. num_pipeline_model_parallel_groups: int = (world_size // @@ -993,15 +1001,15 @@ def destroy_distributed_environment(): torch.distributed.destroy_process_group() -def is_in_the_same_node(pg: ProcessGroup): +def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]: """ - This is a collective operation that checks if all processes in the group - are in the same node. It tests if all processes are attached to the same + This is a collective operation that returns if each rank is in the same node + as the source rank. It tests if processes are attached to the same memory system (shared access to shared memory). 
""" assert torch.distributed.get_backend( pg) != torch.distributed.Backend.NCCL, ( - "is_in_the_same_node should be tested with a non-NCCL group.") + "in_the_same_node_as should be tested with a non-NCCL group.") # local rank inside the group rank = torch.distributed.get_rank(group=pg) world_size = torch.distributed.get_world_size(group=pg) @@ -1017,19 +1025,19 @@ def is_in_the_same_node(pg: ProcessGroup): try: with contextlib.suppress(OSError): - if rank == 0: + if rank == source_rank: # create a shared memory segment shm = shared_memory.SharedMemory(create=True, size=128) shm.buf[:len(magic_message)] = magic_message torch.distributed.broadcast_object_list([shm.name], - src=ranks[0], + src=ranks[source_rank], group=pg) - is_in_the_same_node[0] = 1 + is_in_the_same_node[rank] = 1 else: # try to open the shared memory segment recv = [None] torch.distributed.broadcast_object_list(recv, - src=ranks[0], + src=ranks[source_rank], group=pg) name = recv[0] # fix to https://stackoverflow.com/q/62748654/9191338 @@ -1050,8 +1058,8 @@ def is_in_the_same_node(pg: ProcessGroup): # clean up the shared memory segment with contextlib.suppress(OSError): - if rank == 0 and shm: + if rank == source_rank and shm: shm.unlink() torch.distributed.all_reduce(is_in_the_same_node, group=pg) - return is_in_the_same_node.sum().item() == world_size + return [x == 1 for x in is_in_the_same_node.tolist()] From 5ed3505d827658fe4f71f30fecf93a66baabfe26 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 9 Jul 2024 19:30:56 -0700 Subject: [PATCH 07/26] [Bugfix][TPU] Add prompt adapter methods to TPUExecutor (#6279) --- vllm/executor/tpu_executor.py | 31 +++++++++++++++++++++++++------ 1 file changed, 25 insertions(+), 6 deletions(-) diff --git a/vllm/executor/tpu_executor.py b/vllm/executor/tpu_executor.py index 6627ee6984ddb..d906a6cc39dd7 100644 --- a/vllm/executor/tpu_executor.py +++ b/vllm/executor/tpu_executor.py @@ -81,8 +81,7 @@ def initialize_cache( def determine_num_available_blocks(self) -> Tuple[int, int]: """Determine the number of available KV blocks by invoking the - underlying worker. 
- """ + underlying worker.""" return self.driver_worker.determine_num_available_blocks() def execute_model( @@ -93,16 +92,36 @@ def execute_model( return output def add_lora(self, lora_request: LoRARequest) -> bool: - raise NotImplementedError("LoRA is not implemented for TPU backend.") + raise NotImplementedError( + "LoRA is currently not supported by the TPU backend.") def remove_lora(self, lora_id: int) -> bool: - raise NotImplementedError("LoRA is not implemented for TPU backend.") + raise NotImplementedError( + "LoRA is currently not supported by the TPU backend.") def pin_lora(self, lora_id: int) -> bool: - raise NotImplementedError("LoRA is not implemented for TPU backend.") + raise NotImplementedError( + "LoRA is currently not supported by the TPU backend.") def list_loras(self) -> Set[int]: - raise NotImplementedError("LoRA is not implemented for TPU backend.") + raise NotImplementedError( + "LoRA is currently not supported by the TPU backend.") + + def add_prompt_adapter(self, prompt_adapter_request) -> bool: + raise NotImplementedError( + "Soft prompt is currently not supported by the TPU backend.") + + def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: + raise NotImplementedError( + "Soft prompt is currently not supported by the TPU backend.") + + def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: + raise NotImplementedError( + "Soft prompt is currently not supported by the TPU backend.") + + def list_prompt_adapters(self) -> Set[int]: + raise NotImplementedError( + "Soft prompt is currently not supported by the TPU backend.") def check_health(self) -> None: # TPUExecutor will always be healthy as long as it's running. From 8a924d2248dedb620eb9a32ca5c9f97ab525aaf5 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Wed, 10 Jul 2024 14:55:34 +0800 Subject: [PATCH 08/26] [Doc] Guide for adding multi-modal plugins (#6205) --- docs/source/_templates/sections/header.html | 1 + .../multimodal/adding_multimodal_plugin.rst | 17 +++++++++++++ .../dev/multimodal/multimodal_index.rst | 24 ++++++++++++------- vllm/multimodal/__init__.py | 5 ++-- vllm/multimodal/base.py | 21 +++++++++------- vllm/multimodal/image.py | 1 + vllm/multimodal/registry.py | 18 ++++++++++---- 7 files changed, 64 insertions(+), 23 deletions(-) create mode 100644 docs/source/dev/multimodal/adding_multimodal_plugin.rst diff --git a/docs/source/_templates/sections/header.html b/docs/source/_templates/sections/header.html index cd5c4053e225f..7174431b10272 100644 --- a/docs/source/_templates/sections/header.html +++ b/docs/source/_templates/sections/header.html @@ -5,6 +5,7 @@ justify-content: center; align-items: center; font-size: 16px; + padding: 0 6px 0 6px; } .notification-bar p { margin: 0; diff --git a/docs/source/dev/multimodal/adding_multimodal_plugin.rst b/docs/source/dev/multimodal/adding_multimodal_plugin.rst new file mode 100644 index 0000000000000..b726138f840a3 --- /dev/null +++ b/docs/source/dev/multimodal/adding_multimodal_plugin.rst @@ -0,0 +1,17 @@ +.. _adding_multimodal_plugin: + +Adding a Multimodal Plugin +========================== + +This document teaches you how to add a new modality to vLLM. + +Each modality in vLLM is represented by a :class:`~vllm.multimodal.MultiModalPlugin` and registered to :data:`~vllm.multimodal.MULTIMODAL_REGISTRY`. +For vLLM to recognize a new modality type, you have to create a new plugin and then pass it to :meth:`~vllm.multimodal.MultiModalRegistry.register_plugin`. 
+ +The remainder of this document details how to define custom :class:`~vllm.multimodal.MultiModalPlugin` s. + +.. note:: + This article is a work in progress. + +.. + TODO: Add more instructions on how to add new plugins once embeddings is in. diff --git a/docs/source/dev/multimodal/multimodal_index.rst b/docs/source/dev/multimodal/multimodal_index.rst index 39daf30a3338f..6713dcf08d9f0 100644 --- a/docs/source/dev/multimodal/multimodal_index.rst +++ b/docs/source/dev/multimodal/multimodal_index.rst @@ -7,17 +7,21 @@ Multi-Modality vLLM provides experimental support for multi-modal models through the :mod:`vllm.multimodal` package. -Multi-modal input can be passed alongside text and token prompts to :ref:`supported models ` +Multi-modal inputs can be passed alongside text and token prompts to :ref:`supported models ` via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptStrictInputs`. -.. note:: - ``multi_modal_data`` can accept keys and values beyond the builtin ones, as long as a customized plugin is registered through - the :class:`~vllm.multimodal.MULTIMODAL_REGISTRY`. +Currently, vLLM only has built-in support for image data. You can extend vLLM to process additional modalities +by following :ref:`this guide `. -To implement a new multi-modal model in vLLM, please follow :ref:`this guide `. +Looking to add your own multi-modal model? Please follow the instructions listed :ref:`here `. -.. - TODO: Add more instructions on how to add new plugins once embeddings is in. +Guides +++++++ + +.. toctree:: + :maxdepth: 1 + + adding_multimodal_plugin Module Contents +++++++++++++++ @@ -36,10 +40,14 @@ Registry Base Classes ------------ -.. autoclass:: vllm.multimodal.MultiModalDataDict +.. autodata:: vllm.multimodal.BatchedTensors + +.. autoclass:: vllm.multimodal.MultiModalDataBuiltins :members: :show-inheritance: +.. autodata:: vllm.multimodal.MultiModalDataDict + .. autoclass:: vllm.multimodal.MultiModalInputs :members: :show-inheritance: diff --git a/vllm/multimodal/__init__.py b/vllm/multimodal/__init__.py index b6d930659a8c1..503dceab5b168 100644 --- a/vllm/multimodal/__init__.py +++ b/vllm/multimodal/__init__.py @@ -1,5 +1,5 @@ -from .base import (BatchedTensors, MultiModalDataDict, MultiModalInputs, - MultiModalPlugin) +from .base import (BatchedTensors, MultiModalDataBuiltins, MultiModalDataDict, + MultiModalInputs, MultiModalPlugin) from .registry import MultiModalRegistry MULTIMODAL_REGISTRY = MultiModalRegistry() @@ -13,6 +13,7 @@ __all__ = [ "BatchedTensors", + "MultiModalDataBuiltins", "MultiModalDataDict", "MultiModalInputs", "MultiModalPlugin", diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index 0e31816a8e8ac..3ebc25c5930cf 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -43,9 +43,6 @@ def try_concat( *, device: torch.types.Device, ) -> BatchedTensors: - # Avoid initializing CUDA too early - import torch - unbatched_shape = tensors[0].shape[1:] for tensor in tensors: @@ -84,16 +81,21 @@ def batch( class MultiModalDataBuiltins(TypedDict, total=False): + """Modality types that are predefined by vLLM.""" + image: Image.Image + """The input image.""" MultiModalDataDict = Union[MultiModalDataBuiltins, Dict[str, Any]] """ A dictionary containing an item for each modality type to input. -The data belonging to each modality is converted into keyword arguments -to the model by the corresponding mapper. By default, the mapper of -the corresponding plugin with the same modality key is applied. 
+Note: + This dictionary also accepts modality keys defined outside + :class:`MultiModalDataBuiltins` as long as a customized plugin is registered + through the :class:`~vllm.multimodal.MULTIMODAL_REGISTRY`. + Read more on that :ref:`here `. """ MultiModalInputMapper = Callable[[InputContext, object], MultiModalInputs] @@ -123,6 +125,9 @@ class MultiModalPlugin(ABC): process the same data differently). This registry is in turn used by :class:`~MultiModalRegistry` which acts at a higher level (i.e., the modality of the data). + + See also: + :ref:`adding_multimodal_plugin` """ def __init__(self) -> None: @@ -183,8 +188,8 @@ def wrapper(model_cls: N) -> N: def map_input(self, model_config: ModelConfig, data: object) -> MultiModalInputs: """ - Apply an input mapper to a data passed - to the model, transforming the data into a dictionary of model inputs. + Transform the data into a dictionary of model inputs using the + input mapper registered for that model. The model is identified by ``model_config``. diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py index b6c73512350f3..3b37ce9149fb8 100644 --- a/vllm/multimodal/image.py +++ b/vllm/multimodal/image.py @@ -100,6 +100,7 @@ def repeat_and_pad_image_tokens( class ImagePlugin(MultiModalPlugin): + """Plugin for image data.""" def get_data_key(self) -> str: return "image" diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index e0716bbf15715..d8e1b68178acd 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -15,10 +15,8 @@ class MultiModalRegistry: """ - A registry to dispatch data processing - according to its modality and the target model. - - The registry handles both external and internal data input. + A registry that dispatches data processing to the + :class:`~vllm.multimodal.MultiModalPlugin` for each modality. """ DEFAULT_PLUGINS = (ImagePlugin(), ) @@ -30,6 +28,12 @@ def __init__( self._plugins = {p.get_data_key(): p for p in plugins} def register_plugin(self, plugin: MultiModalPlugin) -> None: + """ + Register a multi-modal plugin so it can be recognized by vLLM. + + See also: + :ref:`adding_multimodal_plugin` + """ data_type_key = plugin.get_data_key() if data_type_key in self._plugins: @@ -75,7 +79,11 @@ def map_input(self, model_config: ModelConfig, data: MultiModalDataDict) -> MultiModalInputs: """ Apply an input mapper to the data passed to the model. - + + The data belonging to each modality is passed to the corresponding + plugin which in turn converts the data into into keyword arguments + via the input mapper registered for that model. + See :meth:`MultiModalPlugin.map_input` for more details. """ merged_dict: Dict[str, torch.Tensor] = {} From e72ae80b06405ea92b703c8979f046d68e970c94 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 10 Jul 2024 06:03:16 -0700 Subject: [PATCH 09/26] [Bugfix] Support 2D input shape in MoE layer (#6287) --- vllm/model_executor/models/mixtral.py | 5 +++-- vllm/model_executor/models/qwen2_moe.py | 6 ++++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 7f5e3b9699c91..e5bd58a9e97b0 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -88,12 +88,13 @@ def __init__(self, tp_size=tp_size) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - num_tokens, hidden_size = hidden_states.shape + # NOTE: hidden_states can have either 1D or 2D shape. 
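Both MoE layers touched here apply the same reshape idiom: flatten whatever comes in to (num_tokens, hidden_size) for the expert computation, then hand back the caller's original shape. A toy illustration (the shapes and the expert stand-in are invented):

    import torch

    def moe_like_forward(hidden_states: torch.Tensor) -> torch.Tensor:
        orig_shape = hidden_states.shape            # 1D or 2D on entry
        hidden_size = hidden_states.shape[-1]
        flat = hidden_states.view(-1, hidden_size)  # always 2D for experts
        out = flat * 2.0                            # stand-in for expert mixing
        return out.view(orig_shape)                 # restore caller's shape

    # A single token vector and a token matrix both round-trip their shape:
    assert moe_like_forward(torch.randn(4096)).shape == (4096,)
    assert moe_like_forward(torch.randn(8, 4096)).shape == (8, 4096)

The previous code unpacked exactly two dimensions from hidden_states.shape, so only one of the two layouts worked; keeping orig_shape supports both.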
+ orig_shape = hidden_states.shape hidden_states = hidden_states.view(-1, self.hidden_size) # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) final_hidden_states = self.experts(hidden_states, router_logits) - return final_hidden_states.view(num_tokens, hidden_size) + return final_hidden_states.view(orig_shape) class MixtralAttention(nn.Module): diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index ccaa6f20893e0..7b18b5e04b275 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -126,7 +126,9 @@ def __init__( bias=False) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - num_tokens, hidden_dim = hidden_states.shape + # NOTE: hidden_states can have either 1D or 2D shape. + orig_shape = hidden_states.shape + hidden_dim = hidden_states.shape[-1] hidden_states = hidden_states.view(-1, hidden_dim) shared_output = None if self.shared_expert is not None: @@ -145,7 +147,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states) - return final_hidden_states.view(num_tokens, hidden_dim) + return final_hidden_states.view(orig_shape) class Qwen2MoeAttention(nn.Module): From c38eba304674fdf9da4d881e46f103440e22a153 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Wed, 10 Jul 2024 15:04:07 +0200 Subject: [PATCH 10/26] [Bugfix] MLPSpeculator: Use ParallelLMHead in tie_weights=False case. (#6303) Signed-off-by: Thomas Parnell --- vllm/model_executor/models/mlp_speculator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/mlp_speculator.py b/vllm/model_executor/models/mlp_speculator.py index 97f7ec74292bb..d3aec06a92fdb 100644 --- a/vllm/model_executor/models/mlp_speculator.py +++ b/vllm/model_executor/models/mlp_speculator.py @@ -110,7 +110,7 @@ def __init__(self, config: MLPSpeculatorConfig, **kwargs) -> None: ]) self.head = nn.ModuleList([ - nn.Linear(self.inner_dim, self.vocab_size, bias=False) + ParallelLMHead(self.vocab_size, self.inner_dim, bias=False) for _ in range(self.max_speculative_tokens) ]) self.ln = nn.ModuleList([ From b422d4961a3052c5b4bcfc3747a1ad55acfe7eb8 Mon Sep 17 00:00:00 2001 From: Benjamin Muskalla Date: Wed, 10 Jul 2024 16:15:55 +0200 Subject: [PATCH 11/26] [CI/Build] Enable mypy typing for remaining folders (#6268) --- .github/workflows/mypy.yaml | 18 ++++++++++-------- format.sh | 18 +++++++++--------- vllm/platforms/cuda.py | 5 ++--- 3 files changed, 21 insertions(+), 20 deletions(-) diff --git a/.github/workflows/mypy.yaml b/.github/workflows/mypy.yaml index 62f0dbcd93eff..5780f09a646cb 100644 --- a/.github/workflows/mypy.yaml +++ b/.github/workflows/mypy.yaml @@ -32,20 +32,22 @@ jobs: pip install types-setuptools - name: Mypy run: | + mypy tests --config-file pyproject.toml + mypy vllm/*.py --config-file pyproject.toml mypy vllm/attention --config-file pyproject.toml mypy vllm/core --config-file pyproject.toml mypy vllm/distributed --config-file pyproject.toml + mypy vllm/engine --config-file pyproject.toml mypy vllm/entrypoints --config-file pyproject.toml mypy vllm/executor --config-file pyproject.toml + mypy vllm/inputs --config-file pyproject.toml + mypy vllm/logging --config-file pyproject.toml + mypy vllm/lora --config-file pyproject.toml + mypy vllm/model_executor --config-file pyproject.toml mypy vllm/multimodal --config-file pyproject.toml - mypy vllm/usage --config-file pyproject.toml - 
mypy vllm/*.py --config-file pyproject.toml + mypy vllm/platforms --config-file pyproject.toml + mypy vllm/spec_decode --config-file pyproject.toml mypy vllm/transformers_utils --config-file pyproject.toml - mypy vllm/engine --config-file pyproject.toml + mypy vllm/usage --config-file pyproject.toml mypy vllm/worker --config-file pyproject.toml - mypy vllm/spec_decode --config-file pyproject.toml - mypy vllm/model_executor --config-file pyproject.toml - mypy vllm/lora --config-file pyproject.toml - mypy vllm/logging --config-file pyproject.toml - mypy tests --config-file pyproject.toml diff --git a/format.sh b/format.sh index 5edc868f9f70c..5ad6d6f2938bb 100755 --- a/format.sh +++ b/format.sh @@ -96,23 +96,23 @@ echo 'vLLM yapf: Done' # Run mypy echo 'vLLM mypy:' +mypy tests --config-file pyproject.toml +mypy vllm/*.py --config-file pyproject.toml mypy vllm/attention --config-file pyproject.toml mypy vllm/core --config-file pyproject.toml mypy vllm/distributed --config-file pyproject.toml +mypy vllm/engine --config-file pyproject.toml mypy vllm/entrypoints --config-file pyproject.toml mypy vllm/executor --config-file pyproject.toml +mypy vllm/logging --config-file pyproject.toml +mypy vllm/lora --config-file pyproject.toml +mypy vllm/model_executor --config-file pyproject.toml mypy vllm/multimodal --config-file pyproject.toml -mypy vllm/usage --config-file pyproject.toml -mypy vllm/*.py --config-file pyproject.toml +mypy vllm/prompt_adapter --config-file pyproject.toml +mypy vllm/spec_decode --config-file pyproject.toml mypy vllm/transformers_utils --config-file pyproject.toml -mypy vllm/engine --config-file pyproject.toml +mypy vllm/usage --config-file pyproject.toml mypy vllm/worker --config-file pyproject.toml -mypy vllm/spec_decode --config-file pyproject.toml -mypy vllm/model_executor --config-file pyproject.toml -mypy vllm/lora --config-file pyproject.toml -mypy vllm/logging --config-file pyproject.toml -mypy vllm/prompt_adapter --config-file pyproject.toml -mypy tests --config-file pyproject.toml # If git diff returns a file that is in the skip list, the file may be checked anyway: diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 2d482010cf760..02ba227460e3f 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -34,11 +34,10 @@ def get_physical_device_capability(device_id: int = 0) -> Tuple[int, int]: def device_id_to_physical_device_id(device_id: int) -> int: if "CUDA_VISIBLE_DEVICES" in os.environ: device_ids = os.environ["CUDA_VISIBLE_DEVICES"].split(",") - device_ids = [int(device_id) for device_id in device_ids] physical_device_id = device_ids[device_id] + return int(physical_device_id) else: - physical_device_id = device_id - return physical_device_id + return device_id class CudaPlatform(Platform): From 44cc76610d0b23ce5d609867f6dae7e033dee818 Mon Sep 17 00:00:00 2001 From: "sangjune.park" Date: Thu, 11 Jul 2024 02:03:32 +0900 Subject: [PATCH 12/26] [Bugfix] Fix OpenVINOExecutor abstractmethod error (#6296) Signed-off-by: sangjune.park --- vllm/executor/openvino_executor.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/vllm/executor/openvino_executor.py b/vllm/executor/openvino_executor.py index 697d698b4edf7..1ef37785b6d59 100644 --- a/vllm/executor/openvino_executor.py +++ b/vllm/executor/openvino_executor.py @@ -90,6 +90,22 @@ def pin_lora(self, lora_id: int) -> bool: def list_loras(self) -> Set[int]: return self.driver_worker.list_loras() + def add_prompt_adapter(self, prompt_adapter_request) -> bool: + raise 
NotImplementedError( + "Soft prompt is currently not supported by the OPENVINO backend.") + + def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: + raise NotImplementedError( + "Soft prompt is currently not supported by the OPENVINO backend.") + + def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: + raise NotImplementedError( + "Soft prompt is currently not supported by the OPENVINO backend.") + + def list_prompt_adapters(self) -> Set[int]: + raise NotImplementedError( + "Soft prompt is currently not supported by the OPENVINO backend.") + def check_health(self) -> None: # OpenVINOExecutor will always be healthy as long as # it's running. From ae151d73be479e9c0caa2fdfc30b17f073018ef3 Mon Sep 17 00:00:00 2001 From: sroy745 <142070531+sroy745@users.noreply.github.com> Date: Wed, 10 Jul 2024 16:02:47 -0700 Subject: [PATCH 13/26] [Speculative Decoding] Enabling bonus token in speculative decoding for KV cache based models (#5765) --- tests/spec_decode/test_dynamic_spec_decode.py | 11 +- tests/spec_decode/test_multi_step_worker.py | 212 +++++++++++++++++- tests/spec_decode/test_ngram_worker.py | 9 +- tests/spec_decode/test_spec_decode_worker.py | 147 +++++++++++- vllm/sequence.py | 18 +- vllm/spec_decode/interfaces.py | 5 +- vllm/spec_decode/medusa_worker.py | 8 +- vllm/spec_decode/mlp_speculator_worker.py | 5 +- vllm/spec_decode/multi_step_worker.py | 206 ++++++++++++++--- vllm/spec_decode/ngram_worker.py | 12 +- vllm/spec_decode/proposer_worker_base.py | 9 +- .../spec_decode/smaller_tp_proposer_worker.py | 11 +- vllm/spec_decode/spec_decode_worker.py | 67 ++++-- vllm/spec_decode/top1_proposer.py | 5 +- 14 files changed, 645 insertions(+), 80 deletions(-) diff --git a/tests/spec_decode/test_dynamic_spec_decode.py b/tests/spec_decode/test_dynamic_spec_decode.py index 29ed96999cb4c..1f3219593f96b 100644 --- a/tests/spec_decode/test_dynamic_spec_decode.py +++ b/tests/spec_decode/test_dynamic_spec_decode.py @@ -70,14 +70,17 @@ def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int, if queue_size < disable_by_batch_size: # Should raise exception when executing the mocked draft model. with pytest.raises(ValueError, match=exception_secret): - proposer.get_spec_proposals(execute_model_req=ExecuteModelRequest( - seq_group_metadata_list=seq_group_metadata_list, - num_lookahead_slots=k), ) + proposer.get_spec_proposals( + execute_model_req=ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + num_lookahead_slots=k), + seq_ids_with_bonus_token_in_last_step=set()) else: # Should not execute the draft model because spec decode is disabled # for all requests. Accordingly, the proposal length should be 0. 
proposals = proposer.get_spec_proposals( execute_model_req=ExecuteModelRequest( seq_group_metadata_list=seq_group_metadata_list, - num_lookahead_slots=k), ) + num_lookahead_slots=k), + seq_ids_with_bonus_token_in_last_step=set()) assert proposals.proposal_lens.tolist() == [0] * batch_size diff --git a/tests/spec_decode/test_multi_step_worker.py b/tests/spec_decode/test_multi_step_worker.py index 7744b2640fe94..9832d4f267e8a 100644 --- a/tests/spec_decode/test_multi_step_worker.py +++ b/tests/spec_decode/test_multi_step_worker.py @@ -118,7 +118,8 @@ def test_same_output_for_single_step(): actual_output, _ = multi_step_worker.sampler_output( execute_model_req=ExecuteModelRequest( seq_group_metadata_list=multi_step_seq_group), - sample_len=num_steps) + sample_len=num_steps, + seq_ids_with_bonus_token_in_last_step=set()) assert len(actual_output) == num_steps actual_output = actual_output[0] @@ -210,7 +211,8 @@ def test_same_output_for_multi_step(): multi_step_output, _ = multi_step_worker.sampler_output( execute_model_req=ExecuteModelRequest( seq_group_metadata_list=seq_group_metadata_list), - sample_len=num_steps) + sample_len=num_steps, + seq_ids_with_bonus_token_in_last_step=set()) # Run single-step repeatedly. zero_kv_cache(worker.cache_engine) @@ -277,6 +279,203 @@ def test_same_output_for_multi_step(): single_step_logprobs) +@torch.inference_mode() +def test_multi_step_with_batch_expansion_correct_output(): + """ + In this test we verify that the MultiStepWorker is able to handle bonus + tokens correctly. The test verifies that if a sequence has a + bonus token then the MultiStepWorker is able to expand the batch by adding + new sequences corresponding to the sequences with bonus tokens. The + expanded batch is then used for predicting the next tokens. + """ + seed = 100 + model_name = 'JackFram/llama-68m' + + block_size = 16 + num_gpu_blocks = 2048 // block_size + batch_size = 128 + multi_step_worker = create_worker( + MultiStepWorker, + model_name, + block_size, + num_gpu_blocks, + seed, + model_runner_cls=TP1DraftModelRunner, + ) + worker = create_worker( + Worker, + model_name, + block_size, + num_gpu_blocks, + seed, + ) + random.seed(seed) + prompts = [[0] for _ in range(batch_size)] + num_steps = 2 + final_prompt_lens = [(num_steps + 1) for prompt in prompts] + rand_seeds = list(random.randint(0, 100) for _ in range(num_steps)) + multi_step_worker.execute_model = patch_execute_model_with_seeds( + multi_step_worker, rand_seeds) + worker.execute_model = patch_execute_model_with_seeds(worker, rand_seeds) + # Create the test continuations + continuations = [[random.randint(0, 1000)] for _ in prompts] + seq_group_metadata_list = create_seq_group_metadata_from_prompts( + prompts, + num_gpu_blocks, + block_size, + continuations=continuations, + final_prompt_lens=final_prompt_lens) + + # Run single-step twice to generate 2 tokens. This + # will simulate the bonus token case with the second token + # being the bonus token. + zero_kv_cache(worker.cache_engine) + single_step_output: List[SamplerOutput] = [] + set_random_seed(seed) + for _ in range(num_steps): + seq_group_metadata_list = create_seq_group_metadata_from_prompts( + prompts, + num_gpu_blocks, + block_size, + continuations=continuations, + final_prompt_lens=final_prompt_lens) + single_step_output.extend( + worker.execute_model(execute_model_req=ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list))) + # Append output tokens to new sequence data. 
+ for i, seq_group_output in enumerate(single_step_output[-1]): + continuations[i].append(seq_group_output.samples[0].output_token) + + # Create continuations for the MultiStepWorker. The continuations have + # 2 tokens in order to simulate the bonus token case. + multi_step_continuations = [] + for continuation in continuations: + multi_step_continuations.append(continuation[:2]) + seq_group_metadata_list = create_seq_group_metadata_from_prompts( + prompts, + num_gpu_blocks, + block_size, + continuations=multi_step_continuations, + final_prompt_lens=final_prompt_lens) + + # Run multi-step and verify that the third token prediction is accurate + # for all sequences. + zero_kv_cache(multi_step_worker.cache_engine) + all_seq_ids = {i for i in range(batch_size)} + multi_step_output, _ = multi_step_worker.sampler_output( + execute_model_req=ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list), + sample_len=1, + seq_ids_with_bonus_token_in_last_step=all_seq_ids) + for index, output in enumerate(multi_step_output[-1].outputs): + assert (continuations[index][-1] == output.samples[0].output_token) + + +@torch.inference_mode() +def test_multi_step_with_batch_expansion_incorrect_output(): + """ + Tests the MultiStepWorker's ability to handle batch expansion with bonus + tokens in a negative case scenario. This test provides the MultiStepWorker + with a batch containing sequences with bonus tokens but specifies the + sequence IDs with bonus tokens incorrectly. The test verifies that the + MultiStepWorker generates correct tokens for the sequences where the + sequence ID is specified correctly and incorrect tokens for those where + the sequence ID is specified incorrectly. + """ + seed = 100 + model_name = 'JackFram/llama-68m' + + block_size = 16 + num_gpu_blocks = 2048 // block_size + batch_size = 128 + multi_step_worker = create_worker( + MultiStepWorker, + model_name, + block_size, + num_gpu_blocks, + seed, + model_runner_cls=TP1DraftModelRunner, + ) + worker = create_worker( + Worker, + model_name, + block_size, + num_gpu_blocks, + seed, + ) + random.seed(seed) + prompts = [[0] for _ in range(batch_size)] + num_steps = 2 + final_prompt_lens = [(num_steps + 1) for prompt in prompts] + rand_seeds = list(random.randint(0, 100) for _ in range(num_steps)) + multi_step_worker.execute_model = patch_execute_model_with_seeds( + multi_step_worker, rand_seeds) + worker.execute_model = patch_execute_model_with_seeds(worker, rand_seeds) + # Create the test continuations + continuations = [[random.randint(0, 1000)] for _ in prompts] + seq_group_metadata_list = create_seq_group_metadata_from_prompts( + prompts, + num_gpu_blocks, + block_size, + continuations=continuations, + final_prompt_lens=final_prompt_lens) + # Run single-step twice to generate 2 tokens. This + # will simulate the bonus token case with the second token + # being the bonus token. + zero_kv_cache(worker.cache_engine) + single_step_output: List[SamplerOutput] = [] + set_random_seed(seed) + for _ in range(num_steps): + seq_group_metadata_list = create_seq_group_metadata_from_prompts( + prompts, + num_gpu_blocks, + block_size, + continuations=continuations, + final_prompt_lens=final_prompt_lens) + single_step_output.extend( + worker.execute_model(execute_model_req=ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list))) + # Append output tokens to new sequence data. 
+ for i, seq_group_output in enumerate(single_step_output[-1]): + continuations[i].append(seq_group_output.samples[0].output_token) + + # Create continuations for the MultiStepWorker. The continuations have + # 2 tokens in order to simulate the bonus token case. + multi_step_continuations = [] + for continuation in continuations: + multi_step_continuations.append(continuation[:2]) + seq_group_metadata_list = create_seq_group_metadata_from_prompts( + prompts, + num_gpu_blocks, + block_size, + continuations=multi_step_continuations, + final_prompt_lens=final_prompt_lens) + + # Run multi-step. In this run INCORRECTLY specify that only the odd number + # sequences have bonus tokens. Verify that with this setting the third token + # prediction is accurate only for the odd numbered sequences. Also verify + # that the prediction might be wrong for some of the even numbered + # sequences. + zero_kv_cache(multi_step_worker.cache_engine) + set_random_seed(seed) + odd_seq_ids = {i for i in range(batch_size) if i % 2 != 0} + multi_step_output, _ = multi_step_worker.sampler_output( + execute_model_req=ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list), + sample_len=1, + seq_ids_with_bonus_token_in_last_step=odd_seq_ids) + num_mismatch = 0 + for index, output in enumerate(multi_step_output[-1].outputs): + if (index % 2) != 0: + assert (continuations[index][-1] == output.samples[0].output_token) + elif (continuations[index][-1] != output.samples[0].output_token): + num_mismatch += 1 + # The prediction is accurate for some of the sequences even without proper + # handling of the bonus tokens. Hence verify that the number of sequences + # for which there is a mismatch is > 0. + assert (num_mismatch > 0) + + @torch.inference_mode() def test_draft_proposals_full_speculation_len(): """Verify Top1Proposer correctly handles case where all sequences @@ -318,7 +517,8 @@ def test_draft_proposals_full_speculation_len(): proposals = proposer.get_spec_proposals( execute_model_req=ExecuteModelRequest( seq_group_metadata_list=seq_group_metadata_list, - num_lookahead_slots=k), ) + num_lookahead_slots=k), + seq_ids_with_bonus_token_in_last_step=set()) assert torch.is_tensor(proposals.proposal_token_ids) assert torch.is_tensor(proposals.proposal_probs) @@ -356,7 +556,8 @@ def test_draft_proposals_no_speculations(): proposals = proposer.get_spec_proposals( execute_model_req=ExecuteModelRequest( seq_group_metadata_list=seq_group_metadata_list, - num_lookahead_slots=k), ) + num_lookahead_slots=k), + seq_ids_with_bonus_token_in_last_step=set()) assert torch.is_tensor(proposals.proposal_token_ids) assert torch.is_tensor(proposals.proposal_probs) @@ -428,7 +629,8 @@ def test_draft_proposals_mixed_k(): proposals = proposer.get_spec_proposals( execute_model_req=ExecuteModelRequest( seq_group_metadata_list=seq_group_metadata_list, - num_lookahead_slots=k), ) + num_lookahead_slots=k), + seq_ids_with_bonus_token_in_last_step=set()) assert torch.is_tensor(proposals.proposal_token_ids) assert torch.is_tensor(proposals.proposal_probs) diff --git a/tests/spec_decode/test_ngram_worker.py b/tests/spec_decode/test_ngram_worker.py index b1537884f896e..3995f87898afb 100644 --- a/tests/spec_decode/test_ngram_worker.py +++ b/tests/spec_decode/test_ngram_worker.py @@ -53,7 +53,8 @@ def test_ngram_algo_correctness_for_single_no_match(): proposals = proposer.get_spec_proposals( execute_model_req=ExecuteModelRequest( seq_group_metadata_list=seq_group_metadata_list, - num_lookahead_slots=proposal_len), ) + 
num_lookahead_slots=proposal_len), + seq_ids_with_bonus_token_in_last_step=None) assert torch.is_tensor(proposals.proposal_token_ids) assert torch.is_tensor(proposals.proposal_probs) @@ -121,7 +122,8 @@ def test_ngram_algo_correctness_for_batches_not_match_all(): proposals = proposer.get_spec_proposals( execute_model_req=ExecuteModelRequest( seq_group_metadata_list=seq_group_metadata_list, - num_lookahead_slots=proposal_len), ) + num_lookahead_slots=proposal_len), + seq_ids_with_bonus_token_in_last_step=None) assert torch.is_tensor(proposals.proposal_token_ids) assert torch.is_tensor(proposals.proposal_probs) @@ -193,7 +195,8 @@ def test_ngram_algo_correctness_for_batches_match_all(): proposals = proposer.get_spec_proposals( execute_model_req=ExecuteModelRequest( seq_group_metadata_list=seq_group_metadata_list, - num_lookahead_slots=proposal_len), ) + num_lookahead_slots=proposal_len), + seq_ids_with_bonus_token_in_last_step=None) assert torch.is_tensor(proposals.proposal_token_ids) assert torch.is_tensor(proposals.proposal_probs) diff --git a/tests/spec_decode/test_spec_decode_worker.py b/tests/spec_decode/test_spec_decode_worker.py index 527e7eddd7e33..0baac32042ef9 100644 --- a/tests/spec_decode/test_spec_decode_worker.py +++ b/tests/spec_decode/test_spec_decode_worker.py @@ -1,6 +1,7 @@ import random +from collections import defaultdict from types import SimpleNamespace -from typing import Dict, List +from typing import Dict, List, Set from unittest.mock import MagicMock import pytest @@ -377,8 +378,10 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool, set_random_seed(1) - worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler, - metrics_collector) + worker = SpecDecodeWorker(draft_worker, + target_worker, + spec_decode_sampler, + metrics_collector=metrics_collector) worker.init_device() proposal_token_ids = torch.randint(low=0, @@ -554,7 +557,6 @@ def test_init_device(acceptance_sampler_method: str): worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler, metrics_collector) - worker.init_device() draft_worker.init_device.assert_called_once() @@ -645,3 +647,140 @@ def test_split_num_cache_blocks_evenly(available_gpu_blocks: int, assert (num_blocks * target_cache_block_size_bytes) + ( num_blocks * draft_kv_size_bytes) <= (available_gpu_blocks * target_cache_block_size_bytes) + + +@torch.inference_mode() +def test_populate_seq_ids_with_bonus_tokens(): + """ + Verify that a call to _create_output_sampler_list correctly updates + seq_with_bonus_token_in_last_step. + + seq_with_bonus_token_in_last_step is an internal data structure in + SpecDecodeWorker that tracks the sequence IDs which are assigned bonus + tokens by the target model in their last forward pass. This state is + maintained only for models relying on the KV cache, such as those using + the MultiStepWorker. + """ + batch_size = 10 + k = 5 + vocab_size = 10000 + num_sequences_with_bonus_tokens = 5 + target_worker = mock_worker(vocab_size=vocab_size, use_spec=False) + metrics_collector = MagicMock(spec=AsyncMetricsCollector) + target_worker.execute_model.return_value = [MagicMock(spec=SamplerOutput)] + target_worker.device = 'cuda' + + set_random_seed(1) + draft_worker = mock_worker(cls=MultiStepWorker) + draft_worker.device = 'cuda' + # The sequence_ids attached to each sequence in the batch. 
+ # The sequence at index i has seq_id assigned_seq_ids[i] + assigned_seq_ids = list(range(batch_size)) + seq_group_metadata_list, _, _ = create_batch(batch_size, + k, + seq_ids=assigned_seq_ids, + prev_output_token_len=10) + target_token_logprobs = torch.rand(batch_size, (k + 1), + vocab_size, + dtype=torch.float32, + device='cuda') + accepted_token_ids = torch.randint(low=0, + high=vocab_size, + size=(batch_size, (k + 1)), + dtype=torch.int64, + device='cuda') + expected_request_id_seq_ids_mapping: Dict[str, Set[int]] = defaultdict(set) + for seq_group_metadata in seq_group_metadata_list: + for seq_id in seq_group_metadata.seq_data: + expected_request_id_seq_ids_mapping[ + seq_group_metadata.request_id].add(seq_id) + # Generate a random sample of sequence indexes with bonus tokens + seq_indexes_with_bonus_tokens = random.sample( + range(batch_size), num_sequences_with_bonus_tokens) + # Create a mask that is True for indices in seq_indexes_with_bonus_tokens + mask = torch.ones(batch_size, dtype=torch.bool, device='cuda') + mask[seq_indexes_with_bonus_tokens] = False + # Set the last token ID to -1 for all indices not in + # seq_indexes_with_bonus_tokens to indicate the lack of bonus token in + # those indices. + accepted_token_ids[mask, -1:] = -1 + worker = SpecDecodeWorker(draft_worker, + target_worker, + mock_spec_decode_sampler("rejection_sampler"), + metrics_collector=metrics_collector) + # Initialize _seq_with_bonus_token_in_last_step with a set of sequence IDs. + # This set includes all sequence IDs in the batch as well as an additional + # `num_extra_sequence_ids` sequence IDs. Note that the sequence IDs are in + # the range [0, batch_size + num_extra_sequence_ids). + num_extra_sequence_ids = 10 + worker._seq_with_bonus_token_in_last_step = set( + range(batch_size + num_extra_sequence_ids)) + worker._create_output_sampler_list( + seq_group_metadata_list=seq_group_metadata_list, + accepted_token_ids=accepted_token_ids, + target_logprobs=target_token_logprobs, + k=k) + # Verify that _seq_with_bonus_token_in_last_step contains the following: + # 1. Sequence IDs that were already present in + # _seq_with_bonus_token_in_last_step but were not part of the current + # batch are retained. + # 2. Of the sequence IDs present in the current batch, only those with a + # bonus token are retained in _seq_with_bonus_token_in_last_step. + # Sequence IDs that are present in the current batch but do not have + # bonus tokens are removed from _seq_with_bonus_token_in_last_step. + expected_seq_ids_with_bonus_tokens = \ + set([assigned_seq_ids[i] for i in seq_indexes_with_bonus_tokens]) + additional_sequence_ids = \ + set(range(batch_size, batch_size + num_extra_sequence_ids)) + assert worker._seq_with_bonus_token_in_last_step == \ + expected_seq_ids_with_bonus_tokens.union(additional_sequence_ids) + assert worker._request_id_seq_id_mapping == \ + expected_request_id_seq_ids_mapping + + +@torch.inference_mode() +def test_handle_finished_requests(): + """ + Test to verify that finished request IDs are appropriately processed to + update the internal state of the SpecDecodeWorker. + + This test initializes the SpecDecodeWorker with mock data, marks certain + requests as finished, and ensures that the corresponding sequence IDs are + correctly removed from the internal mappings. 
+ """ + batch_size = 32 + k = 3 + draft_worker = mock_worker(cls=MultiStepWorker) + target_worker = mock_worker() + metrics_collector = MagicMock(spec=AsyncMetricsCollector) + worker = SpecDecodeWorker(draft_worker, target_worker, + mock_spec_decode_sampler("rejection_sampler"), + metrics_collector) + # Initialize the request_id_seq_id_mapping mapping dict with a few fake + # request ids and corresponding sequence ids. + worker._request_id_seq_id_mapping = \ + {'request-1': {1,2,3}, 'request-2': {4,5,6,7}, + 'request-3': {8,9}, 'request-4': {10,11}} + # Initialize seq_with_bonus_token_in_last_step with a few fake + # sequence ids. + worker._seq_with_bonus_token_in_last_step = {1, 4, 5, 8, 9, 10} + exception_secret = 'artificial stop' + draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret) + + seq_group_metadata_list, _, _ = create_batch(batch_size, k) + # Mark requests with ids request-1 and request-3 as finished. + execute_model_req = ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + num_lookahead_slots=k, + finished_requests_ids=['request-1', 'request-3']) + + with pytest.raises(ValueError, match=exception_secret): + worker.execute_model(execute_model_req=execute_model_req) + # Verify that request-1 and request-3 are removed from + # request_id_seq_id_mapping + assert worker._request_id_seq_id_mapping == \ + {'request-2': {4,5,6,7}, 'request-4': {10,11}} + # Verify that all sequence ids corresponding to 'request-1' + # and 'request-3' are removed from seq_with_bonus_token_in_last_step. + assert worker._seq_with_bonus_token_in_last_step == \ + {4,5,10} diff --git a/vllm/sequence.py b/vllm/sequence.py index a3f998b94d795..1cebf68d463db 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -3,8 +3,9 @@ import enum import math from abc import ABC, abstractmethod +from collections import defaultdict from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union import torch @@ -916,6 +917,21 @@ def get_all_seq_ids( return [seq_id for sg in seq_group_metadata_list for seq_id in sg.seq_data] +def get_all_seq_ids_and_request_ids( + seq_group_metadata_list: List[SequenceGroupMetadata] +) -> Tuple[List[int], Dict[str, Set[int]]]: + """Given a list of SequenceGroupMetadata, create a list of all + sequence ids. + """ + seq_ids: List[int] = [] + request_id_seq_ids_mapping: Dict[str, Set[int]] = defaultdict(set) + for sg in seq_group_metadata_list: + for seq_id in sg.seq_data: + seq_ids.append(seq_id) + request_id_seq_ids_mapping[sg.request_id].add(seq_id) + return seq_ids, request_id_seq_ids_mapping + + class HiddenStates: """Hidden states corresponding to in-progress sequences. Used in speculative decoding to pass hidden states from diff --git a/vllm/spec_decode/interfaces.py b/vllm/spec_decode/interfaces.py index d236fc0f2cb6b..d109d8edc1b0b 100644 --- a/vllm/spec_decode/interfaces.py +++ b/vllm/spec_decode/interfaces.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Optional +from typing import Optional, Set import torch @@ -62,6 +62,9 @@ class SpeculativeProposer(ABC): def get_spec_proposals( self, execute_model_req: ExecuteModelRequest, + # If set, this contains all sequence IDs that were assigned + # bonus tokens in their last forward pass. 
+ seq_ids_with_bonus_token_in_last_step: Set[int], ) -> SpeculativeProposals: raise NotImplementedError diff --git a/vllm/spec_decode/medusa_worker.py b/vllm/spec_decode/medusa_worker.py index b72740fc3961c..041ce41e91d05 100644 --- a/vllm/spec_decode/medusa_worker.py +++ b/vllm/spec_decode/medusa_worker.py @@ -1,5 +1,5 @@ import weakref -from typing import List, Optional, Tuple +from typing import List, Optional, Set, Tuple import torch @@ -40,6 +40,8 @@ def sampler_output( self, execute_model_req: ExecuteModelRequest, sample_len: int, + # Unused parameter. + seq_ids_with_bonus_token_in_last_step: Set[int], ) -> Tuple[List[SamplerOutput], bool]: """Run the model forward pass to generate sample_len future tokens. Returns the list of sampler output, one per layer, along with indicator @@ -97,12 +99,14 @@ def _prepare_input_tensors( def get_spec_proposals( self, execute_model_req: ExecuteModelRequest, + seq_ids_with_bonus_token_in_last_step: Set[int], ) -> SpeculativeProposals: """Produce speculations given an input batch of sequences. The number of speculative tokens per sequence is determined by max_proposal_len. """ - return self._proposer.get_spec_proposals(execute_model_req) + return self._proposer.get_spec_proposals( + execute_model_req, seq_ids_with_bonus_token_in_last_step) def _raise_if_unsupported( self, diff --git a/vllm/spec_decode/mlp_speculator_worker.py b/vllm/spec_decode/mlp_speculator_worker.py index 6c1c8da57d188..308573348d443 100644 --- a/vllm/spec_decode/mlp_speculator_worker.py +++ b/vllm/spec_decode/mlp_speculator_worker.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Tuple +from typing import List, Optional, Set, Tuple import torch @@ -20,6 +20,9 @@ def sampler_output( self, execute_model_req: ExecuteModelRequest, sample_len: int, + # Unused parameter. MLPSpeculatorWorker does not use the KV Cache and + # therefore does not need this parameter. + seq_ids_with_bonus_token_in_last_step: Set[int], ) -> Tuple[List[SamplerOutput], bool]: """Run the model forward pass to generate sample_len future tokens. Returns the list of sampler output, one per layer, along with indicator diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py index c1a02e1d32e85..09a77f9e870fb 100644 --- a/vllm/spec_decode/multi_step_worker.py +++ b/vllm/spec_decode/multi_step_worker.py @@ -1,6 +1,6 @@ import copy import weakref -from typing import Dict, List, Tuple +from typing import Dict, List, Set, Tuple import torch @@ -51,6 +51,7 @@ def sampler_output( self, execute_model_req: ExecuteModelRequest, sample_len: int, + seq_ids_with_bonus_token_in_last_step: Set[int], ) -> Tuple[List[SamplerOutput], bool]: """Run the model forward pass sample_len times. Returns the list of sampler output, one per model forward pass, along with indicator of @@ -60,44 +61,142 @@ def sampler_output( For multi step worker, this indicator shall be True. """ self._raise_if_unsupported(execute_model_req) - - # Shallow copy input data so modifications (such as appending tokens) - # do not cause side-effects. - copied_seq_group_metadata_list = self._shallow_copy_inputs( - execute_model_req.seq_group_metadata_list) - copied_execute_model_req = execute_model_req.clone( - copied_seq_group_metadata_list) - + # Expand the batch for sequences with a bonus token. + # Perform a forward pass on the expanded batch and filter the + # response to retain only the original sequences' responses. 
+ expanded_request, indices_of_seq_with_bonus_tokens =\ + self._expand_execute_model_request( + execute_model_req, seq_ids_with_bonus_token_in_last_step) # Run model sample_len times. model_outputs: List[SamplerOutput] = [] if isinstance(self.model_runner, TP1DraftModelRunner): - copied_execute_model_req.num_steps = sample_len + expanded_request.num_steps = sample_len model_outputs = self.execute_model( - execute_model_req=copied_execute_model_req) + execute_model_req=expanded_request) else: # TODO: Remove this branch once DraftModelRunner supports TP>1. for _ in range(sample_len): model_output: List[SamplerOutput] = super().execute_model( - execute_model_req=copied_execute_model_req) + execute_model_req=expanded_request) assert (len(model_output) == 1 ), "composing multistep workers not supported" model_output = model_output[0] - self._append_new_tokens(model_output, - copied_seq_group_metadata_list) + self._append_new_tokens( + model_output, expanded_request.seq_group_metadata_list) model_outputs.append(model_output) - return model_outputs, True + filtered_model_outputs = self._filter_model_output( + model_outputs, indices_of_seq_with_bonus_tokens) + return filtered_model_outputs, True + + @staticmethod + def _expand_execute_model_request( + execute_model_req: ExecuteModelRequest, + seq_with_bonus_token_in_last_step: set, + ) -> Tuple[ExecuteModelRequest, List[int]]: + """ + Expands the execute model request based on sequences with bonus + tokens. + + For each sequence with a bonus token, this method creates a new + sequence without the bonus token and adds it to the execute model + request. The original sequence groups are also retained. The indices + of the original sequence groups are returned for further processing. + + Args: + execute_model_req (ExecuteModelRequest): The original execute + model request. + seq_with_bonus_token_in_last_step (set): Set of sequence IDs that + contain bonus tokens. + + Returns: + Tuple[ExecuteModelRequest, List[int]]: The updated execute model + request with expanded sequences and a list of indices corresponding + to the original sequence groups. + """ + updated_seq_group_metadata_list: List[SequenceGroupMetadata] = [] + updated_execute_model_req = execute_model_req.clone( + updated_seq_group_metadata_list) + indices_of_original_sequence_groups = [] + for seq_group in execute_model_req.seq_group_metadata_list: + seq_group_has_bonus_tokens = False + for seq_id, _ in seq_group.seq_data.items(): + # Identify sequences with bonus tokens in the sequence group. + if seq_id in seq_with_bonus_token_in_last_step: + seq_group_has_bonus_tokens = True + break + if seq_group_has_bonus_tokens: + #Create new sequences without the last bonus token. These new + # sequence have the same sequence id as the original sequence. + # We create a new sequence group and add them there. + updated_seq_group_without_bonus_token = \ + MultiStepWorker._copy_seq_metadata_excluding_last_token( + seq_group, seq_with_bonus_token_in_last_step) + updated_seq_group_metadata_list.append( + updated_seq_group_without_bonus_token) + # Add the original sequence group. + updated_seq_group_metadata_list.append( + MultiStepWorker._shallow_copy_seq_group_metadata(seq_group)) + # Record the index of the original sequence group. 
+ indices_of_original_sequence_groups.append( + len(updated_seq_group_metadata_list) - 1) + + updated_execute_model_req.seq_group_metadata_list =\ + updated_seq_group_metadata_list + return updated_execute_model_req, indices_of_original_sequence_groups + + @staticmethod + def _filter_model_output( + expanded_batch_outputs: List[SamplerOutput], + output_indices_to_retain: List[int]) -> List[SamplerOutput]: + """ + Filters the model output to include only the specified sequence + outputs. This method contracts the expanded batch output from the + model to retain the outputs of only those sequences indicated by the + provided indices. + + Args: + expanded_batch_output (List[SamplerOutput]): The expanded output + batch from the model. + output_indices_to_retain (List[int]): Indices of the model outputs + to retain. + + Returns: + List[SamplerOutput]: A list containing the filtered model + outputs for the specified indices. + """ + return [ + SamplerOutput( + outputs=[ + expanded_batch_output.outputs[i] + for i in output_indices_to_retain + ], + sampled_token_probs=( + expanded_batch_output. + sampled_token_probs[output_indices_to_retain] + if expanded_batch_output.sampled_token_probs is not None + else None), + logprobs=( + expanded_batch_output.logprobs[output_indices_to_retain] + if expanded_batch_output.logprobs is not None else None), + sampled_token_ids=(expanded_batch_output. + sampled_token_ids[output_indices_to_retain] + if expanded_batch_output.sampled_token_ids + is not None else None)) + for expanded_batch_output in expanded_batch_outputs + ] def get_spec_proposals( self, execute_model_req: ExecuteModelRequest, + seq_ids_with_bonus_token_in_last_step: set, ) -> SpeculativeProposals: """Produce speculations given an input batch of sequences. The number of speculative tokens per sequence is determined by max_proposal_len. """ - - return self._proposer.get_spec_proposals(execute_model_req) + return self._proposer.get_spec_proposals( + execute_model_req, seq_ids_with_bonus_token_in_last_step) @staticmethod def _append_new_tokens( @@ -123,9 +222,8 @@ def _append_new_tokens( seq.update_num_computed_tokens(1) @staticmethod - def _shallow_copy_inputs( - seq_group_metadata_list: List[SequenceGroupMetadata] - ) -> List[SequenceGroupMetadata]: + def _shallow_copy_seq_group_metadata( + seq_group_metadata: SequenceGroupMetadata, ) -> SequenceGroupMetadata: """Copy input data structures to remove side-effects when input data structures are shared with other modules. @@ -133,26 +231,62 @@ def _shallow_copy_inputs( The alternative is deep-copying (or other form of deep copy); this has performance downsides. """ - - # Shallow-copy the list of SequenceGroupMetadata. This allows us to + # Shallow-copy the SequenceGroupMetadata. This allows us to # append tokens and change is_prompt without external side-effects. - new_seq_group_metadata_list: List[SequenceGroupMetadata] = [] + # We must shallow-copy seq_group_metadata as is_prompt could change. + new_seq_group_metadata = copy.copy(seq_group_metadata) - for old_seq_group_metadata in seq_group_metadata_list: - # We must shallow-copy seq_group_metadata as is_prompt could change. 
- seq_group_metadata = copy.copy(old_seq_group_metadata) - new_seq_group_metadata_list.append(seq_group_metadata) - - # We must shallow-copy seq_data as we will append token ids - new_seq_data: Dict[int, SequenceData] = {} - for seq_id, old_seq_data in seq_group_metadata.seq_data.items(): - new_seq_data[seq_id] = copy.copy(old_seq_data) - new_seq_data[ - seq_id].output_token_ids = old_seq_data.output_token_ids[:] + # We must shallow-copy seq_data as we will append token ids + new_seq_data: Dict[int, SequenceData] = {} + for seq_id, old_seq_data in seq_group_metadata.seq_data.items(): + new_seq_data[seq_id] = copy.copy(old_seq_data) + new_seq_data[seq_id].output_token_ids =\ + old_seq_data.output_token_ids[:] - seq_group_metadata.seq_data = new_seq_data + new_seq_group_metadata.seq_data = new_seq_data + return new_seq_group_metadata - return new_seq_group_metadata_list + @staticmethod + def _copy_seq_metadata_excluding_last_token( + seq_group_metadata: SequenceGroupMetadata, + seq_ids_to_copy: Set[int], + ) -> SequenceGroupMetadata: + """ + Creates a shallow copy of the given SequenceGroupMetadata, retaining + only the sequence IDs specified in seq_ids_to_copy. For each of these + sequence IDs, all output_token_ids except the last one are copied. + Sequence IDs not in seq_ids_to_copy are excluded from the copy. + + Parameters: + seq_group_metadata (SequenceGroupMetadata): The original sequence + group metadata. + seq_ids_to_copy (Set[int]): The set of sequence IDs to include in the + copy. + + Returns: + SequenceGroupMetadata: A shallow copy of the sequence group metadata + with the specified modifications. + """ + # Shallow-copy the SequenceGroupMetadata. + new_seq_group_metadata = copy.copy(seq_group_metadata) + # Shallow-copy seq_data and modify the output_token_ids. + new_seq_data: Dict[int, SequenceData] = {} + for seq_id, old_seq_data in seq_group_metadata.seq_data.items(): + if (seq_id in seq_ids_to_copy): + new_seq_data[seq_id] = copy.copy(old_seq_data) + # Copy all the output token ids except the last. + # Also reduce num_computed_tokens by 1 since we are not + # including the last output token. + # NOTE: num_computed_tokens is not directly used by the + # speculative decoding workers, as it is only relevant for + # chunked prefill, which is disabled for speculative decoding. + # However, to maintain consistency in num_computed_tokens, + # we update it here. + new_seq_data[seq_id].output_token_ids =\ + old_seq_data.output_token_ids[:-1] + new_seq_data[seq_id].update_num_computed_tokens(-1) + new_seq_group_metadata.seq_data = new_seq_data + return new_seq_group_metadata def _assert_enough_kv_space( self, seq_group_metadata_list: List[SequenceGroupMetadata], diff --git a/vllm/spec_decode/ngram_worker.py b/vllm/spec_decode/ngram_worker.py index 23a3e1649914b..07991df52e655 100644 --- a/vllm/spec_decode/ngram_worker.py +++ b/vllm/spec_decode/ngram_worker.py @@ -1,5 +1,5 @@ import weakref -from typing import List, Optional, Tuple +from typing import List, Optional, Set, Tuple import torch @@ -48,6 +48,9 @@ def sampler_output( self, execute_model_req: ExecuteModelRequest, sample_len: int, + # Unused parameter. NGramWorker does not use the KV Cache and + # therefore does not need this parameter. + seq_ids_with_bonus_token_in_last_step: Set[int], ) -> Tuple[Optional[List[Optional[SamplerOutput]]], bool]: """NGram match algo to pick proposal candidate. Returns the list of sampler output, one per SequenceGroupMetadata. 
@@ -133,12 +136,15 @@ def sampler_output( def get_spec_proposals( self, execute_model_req: ExecuteModelRequest, + # Unused parameter. NGramWorker does not use the KV Cache and + # therefore does not need this parameter. + seq_ids_with_bonus_token_in_last_step: Set[int], ) -> SpeculativeProposals: """Produce speculations given an input batch of sequences. The number of speculative tokens per sequence is determined by max_proposal_len. """ - - return self._proposer.get_spec_proposals(execute_model_req) + return self._proposer.get_spec_proposals( + execute_model_req, seq_ids_with_bonus_token_in_last_step) def _raise_if_unsupported( self, diff --git a/vllm/spec_decode/proposer_worker_base.py b/vllm/spec_decode/proposer_worker_base.py index b691659fb292b..fffa557121e17 100644 --- a/vllm/spec_decode/proposer_worker_base.py +++ b/vllm/spec_decode/proposer_worker_base.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import List, Optional, Tuple +from typing import List, Optional, Set, Tuple from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.spec_decode.interfaces import SpeculativeProposer @@ -14,6 +14,13 @@ def sampler_output( self, execute_model_req: ExecuteModelRequest, sample_len: int, + # A set containing all sequence IDs that were assigned bonus tokens + # in their last forward pass. This set is used to backfill the KV cache + # with the key-value pairs of the penultimate token in the sequences. + # This parameter is only used by the MultiStepWorker, which relies on + # the KV cache for token generation. It is not used by workers that + # do not utilize the KV cache. + seq_ids_with_bonus_token_in_last_step: Set[int] ) -> Tuple[Optional[List[SamplerOutput]], bool]: raise NotImplementedError diff --git a/vllm/spec_decode/smaller_tp_proposer_worker.py b/vllm/spec_decode/smaller_tp_proposer_worker.py index b78e4489513f7..0dbb924d25400 100644 --- a/vllm/spec_decode/smaller_tp_proposer_worker.py +++ b/vllm/spec_decode/smaller_tp_proposer_worker.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Tuple +from typing import List, Optional, Set, Tuple import torch @@ -110,13 +110,17 @@ def sampler_output( self, execute_model_req: ExecuteModelRequest, sample_len: int, + seq_ids_with_bonus_token_in_last_step: Set[int], ) -> Tuple[List[SamplerOutput], bool]: # Do not check _is_dummy, as it's always called by get_spec_proposals - return self._worker.sampler_output(execute_model_req, sample_len) + return self._worker.sampler_output( + execute_model_req, sample_len, + seq_ids_with_bonus_token_in_last_step) def get_spec_proposals( self, execute_model_req: ExecuteModelRequest, + seq_ids_with_bonus_token_in_last_step: Set[int], ) -> SpeculativeProposals: """Produce speculations given an input batch of sequences. The number of speculative tokens per sequence is determined by max_proposal_len. 
@@ -125,7 +129,8 @@ def get_spec_proposals( return SpeculativeProposals(None, None, None) with self._patch_tensor_parallel_group(): - return self._worker.get_spec_proposals(execute_model_req) + return self._worker.get_spec_proposals( + execute_model_req, seq_ids_with_bonus_token_in_last_step) def execute_model( self, diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 60a7dab68b7fd..3c8e3dee46831 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -1,5 +1,6 @@ +from collections import defaultdict from functools import cached_property -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Set, Tuple import torch @@ -13,7 +14,7 @@ TypicalAcceptanceSampler) from vllm.sequence import (CompletionSequenceGroupOutput, ExecuteModelRequest, HiddenStates, SamplerOutput, SequenceGroupMetadata, - get_all_seq_ids) + get_all_seq_ids_and_request_ids) from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner from vllm.spec_decode.interfaces import (SpeculativeProposals, @@ -112,11 +113,7 @@ def create_worker( draft_worker_kwargs.pop("ngram_prompt_lookup_max")) ngram_prompt_lookup_min = ( draft_worker_kwargs.pop("ngram_prompt_lookup_min")) - - disable_bonus_tokens = True - if ngram_prompt_lookup_max > 0: - disable_bonus_tokens = False proposer_worker = NGramWorker(**draft_worker_kwargs) proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min, ngram_prompt_lookup_max) @@ -128,11 +125,9 @@ def create_worker( if draft_worker_kwargs[ "model_config"].hf_config.model_type == "mlp_speculator": - disable_bonus_tokens = False proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs) elif draft_worker_kwargs[ "model_config"].hf_config.model_type == "medusa": - disable_bonus_tokens = False proposer_worker = MedusaWorker(**draft_worker_kwargs) else: if draft_tp == 1: @@ -149,10 +144,10 @@ def create_worker( spec_decode_sampler: SpecDecodeBaseSampler = None if draft_token_acceptance_method == "rejection_sampler": spec_decode_sampler = RejectionSampler( - disable_bonus_tokens=disable_bonus_tokens, ) + disable_bonus_tokens=False, ) elif draft_token_acceptance_method == "typical_acceptance_sampler": spec_decode_sampler = TypicalAcceptanceSampler( - disable_bonus_tokens=disable_bonus_tokens, + disable_bonus_tokens=False, posterior_threshold=\ typical_acceptance_sampler_posterior_threshold, posterior_alpha=typical_acceptance_sampler_posterior_alpha, @@ -200,6 +195,15 @@ def __init__( self._metrics = AsyncMetricsCollector( self.spec_decode_sampler ) if metrics_collector is None else metrics_collector + # Tracks the sequence IDs that received a bonus token ID in + # their last forward pass. Needed only if KV cache is being + # used for token generation such as in the case of MultiStepWorker. + self._seq_with_bonus_token_in_last_step: Set[int] = set() + # Tracks the currently active request ids and the sequence IDs + # corresponding to them + self._request_id_seq_id_mapping: Dict[str, Set[int]] = defaultdict(set) + # Tracks if the proposer worker uses the KV cache or not. + self.probs_dtype = self.spec_decode_sampler.probs_dtype self.token_id_dtype = self.spec_decode_sampler.token_id_dtype # Lazy initiazliation. 
@@ -307,6 +311,7 @@ def execute_model( broadcast_tensor_dict({}, src=0) return [] + self._track_finished_requests(execute_model_req) disable_all_speculation = self._should_disable_all_speculation( execute_model_req) num_lookahead_slots = execute_model_req.num_lookahead_slots @@ -453,7 +458,8 @@ def _run_speculative_decoding_step( self.previous_hidden_states = None # Generate proposals using draft worker. - proposals = self.proposer_worker.get_spec_proposals(execute_model_req) + proposals = self.proposer_worker.get_spec_proposals( + execute_model_req, self._seq_with_bonus_token_in_last_step) proposal_scores = self.scorer.score_proposals( execute_model_req, @@ -585,7 +591,9 @@ def _create_output_sampler_list( # Get the sequence ids and num_logprobs (sampling parameter) in the # batch. - seq_ids = get_all_seq_ids(seq_group_metadata_list) + seq_ids, request_ids_seq_ids_mapping = get_all_seq_ids_and_request_ids( + seq_group_metadata_list) + num_logprobs_per_seq = get_all_num_logprobs(seq_group_metadata_list) # Serialize all tensors to CPU Python lists. @@ -608,7 +616,6 @@ def _create_output_sampler_list( for sequence_index in range(batch_size): # Each sequence may have a different num_logprobs; retrieve it. num_logprobs = num_logprobs_per_seq[sequence_index] - step_output_token_ids.append( create_sequence_group_output( token_id=accepted_token_ids_by_step[step_index] @@ -623,18 +630,48 @@ def _create_output_sampler_list( topk_logprobs=topk_logprobs_by_step[step_index] [sequence_index][:num_logprobs], )) - sampler_output_list.append( SamplerOutput(outputs=step_output_token_ids)) + # Populate the data structures needed to keep track of sequences with + # bonus tokens. + self._track_sequences_with_bonus_tokens(seq_ids, + request_ids_seq_ids_mapping, + accepted_token_ids_by_step) maybe_rejsample_metrics = ( self._metrics.maybe_collect_rejsample_metrics(k)) if maybe_rejsample_metrics is not None: sampler_output_list[ 0].spec_decode_worker_metrics = maybe_rejsample_metrics - return sampler_output_list + def _track_finished_requests(self, execute_model_req: ExecuteModelRequest): + """ + Removes the finished requests and their associated sequence ids from + internal book keeping data structures. + """ + for finished_request in execute_model_req.finished_requests_ids: + for seq_id in self._request_id_seq_id_mapping[finished_request]: + self._seq_with_bonus_token_in_last_step.discard(seq_id) + del self._request_id_seq_id_mapping[finished_request] + + def _track_sequences_with_bonus_tokens( + self, seq_ids: List[int], + request_ids_seq_ids_mapping: Dict[str, Set[int]], + accepted_token_ids_by_step: List[List[int]]): + """ + Updates the internal data structures which keep track of sequences + which have been assigned bonus tokens in their last forward pass. 
+ """ + for seq_index, seq_id in enumerate(seq_ids): + last_token_id = accepted_token_ids_by_step[-1][seq_index] + if last_token_id == -1: + self._seq_with_bonus_token_in_last_step.discard(seq_id) + else: + self._seq_with_bonus_token_in_last_step.add(seq_id) + for request_id, sequences in request_ids_seq_ids_mapping.items(): + self._request_id_seq_id_mapping[request_id].update(sequences) + @cached_property def _vocab_size(self) -> int: """Get the vocab size of the model and make sure it's consistent between diff --git a/vllm/spec_decode/top1_proposer.py b/vllm/spec_decode/top1_proposer.py index d3e280e6843b8..7b34b5d34208b 100644 --- a/vllm/spec_decode/top1_proposer.py +++ b/vllm/spec_decode/top1_proposer.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Tuple +from typing import List, Optional, Set, Tuple import torch @@ -42,6 +42,7 @@ def __init__( def get_spec_proposals( self, execute_model_req: ExecuteModelRequest, + seq_ids_with_bonus_token_in_last_step: Set[int], ) -> SpeculativeProposals: """Get speculative proposals given the input batch. @@ -76,6 +77,8 @@ def get_spec_proposals( maybe_sampler_output, transposed = self._worker.sampler_output( execute_model_req=nonzero_execute_model_req, sample_len=proposal_len, + seq_ids_with_bonus_token_in_last_step=\ + seq_ids_with_bonus_token_in_last_step, ) ( proposal_lens, From 997df46a32f3b2c2debe3e17730895cef0d94d2a Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 10 Jul 2024 16:39:02 -0700 Subject: [PATCH 14/26] [Bugfix][Neuron] Fix soft prompt method error in NeuronExecutor (#6313) --- vllm/executor/neuron_executor.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/vllm/executor/neuron_executor.py b/vllm/executor/neuron_executor.py index 53107dada9962..6b2cb3e2403f2 100644 --- a/vllm/executor/neuron_executor.py +++ b/vllm/executor/neuron_executor.py @@ -70,6 +70,22 @@ def pin_lora(self, lora_id: int) -> bool: def list_loras(self) -> Set[int]: return self.driver_worker.list_loras() + def add_prompt_adapter(self, prompt_adapter_request) -> bool: + raise NotImplementedError( + "Soft prompt is currently not supported by the Neuron backend.") + + def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: + raise NotImplementedError( + "Soft prompt is currently not supported by the Neuron backend.") + + def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: + raise NotImplementedError( + "Soft prompt is currently not supported by the Neuron backend.") + + def list_prompt_adapters(self) -> Set[int]: + raise NotImplementedError( + "Soft prompt is currently not supported by the Neuron backend.") + def check_health(self) -> None: # NeuronExecutor will always be healthy as long as # it's running. From 99ded1e1c4dc00baa77beae74602ebafe4921176 Mon Sep 17 00:00:00 2001 From: daquexian Date: Thu, 11 Jul 2024 01:05:26 +0100 Subject: [PATCH 15/26] [Doc] Remove comments incorrectly copied from another project (#6286) --- vllm/model_executor/layers/linear.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 1dda5d3740a8b..7100fe1422ff4 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -743,7 +743,6 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): param_data.copy_(loaded_weight) def forward(self, input_): - # Set up backprop all-reduce. 
if self.input_is_parallel: input_parallel = input_ else: From 439c84581aaf45917c6f77805a3511f1efc052bb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jie=20Fu=20=28=E5=82=85=E6=9D=B0=29?= Date: Thu, 11 Jul 2024 12:15:29 +0800 Subject: [PATCH 16/26] [Doc] Update description of vLLM support for CPUs (#6003) --- README.md | 2 +- docs/source/getting_started/cpu-installation.rst | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 3e0da945d9be8..cced85f17e257 100644 --- a/README.md +++ b/README.md @@ -59,7 +59,7 @@ vLLM is flexible and easy to use with: - Tensor parallelism support for distributed inference - Streaming outputs - OpenAI-compatible API server -- Support NVIDIA GPUs, AMD GPUs, Intel CPUs and GPUs +- Support NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs - (Experimental) Prefix caching support - (Experimental) Multi-lora support diff --git a/docs/source/getting_started/cpu-installation.rst b/docs/source/getting_started/cpu-installation.rst index a9544e8a59a3d..1c97515dbecd9 100644 --- a/docs/source/getting_started/cpu-installation.rst +++ b/docs/source/getting_started/cpu-installation.rst @@ -20,7 +20,7 @@ Requirements * OS: Linux * Compiler: gcc/g++>=12.3.0 (optional, recommended) -* Instruction set architecture (ISA) requirement: AVX512 is required. +* Instruction set architecture (ISA) requirement: AVX512 (optional, recommended) .. _cpu_backend_quick_start_dockerfile: From fc17110bbef4e78703abffac51133a2fb71e9f79 Mon Sep 17 00:00:00 2001 From: Lim Xiang Yang Date: Thu, 11 Jul 2024 12:37:11 +0800 Subject: [PATCH 17/26] [BugFix]: set outlines pkg version (#6262) --- requirements-common.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-common.txt b/requirements-common.txt index e874c4af49d66..b750f9a1b0571 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -18,7 +18,7 @@ prometheus_client >= 0.18.0 prometheus-fastapi-instrumentator >= 7.0.0 tiktoken >= 0.6.0 # Required for DBRX tokenizer lm-format-enforcer == 0.10.1 -outlines >= 0.0.43 # Requires torch >= 2.1.0 +outlines >= 0.0.43, < 0.1 # Requires torch >= 2.1.0 typing_extensions filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4 pyzmq From c4774eb8418864390341d35103aa747fc411b59c Mon Sep 17 00:00:00 2001 From: Roger Wang <136131678+ywang96@users.noreply.github.com> Date: Thu, 11 Jul 2024 00:04:05 -0700 Subject: [PATCH 18/26] [Bugfix] Fix snapshot download in serving benchmark (#6318) --- benchmarks/backend_request_func.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py index fe29c67086158..fbab547d094fe 100644 --- a/benchmarks/backend_request_func.py +++ b/benchmarks/backend_request_func.py @@ -390,17 +390,17 @@ def remove_prefix(text: str, prefix: str) -> str: return text -def get_model(pretrained_model_name_or_path: str): +def get_model(pretrained_model_name_or_path: str) -> str: if os.getenv('VLLM_USE_MODELSCOPE', 'False').lower() == 'true': from modelscope import snapshot_download - else: - from huggingface_hub import snapshot_download - - model_path = snapshot_download( - model_id=pretrained_model_name_or_path, - local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, - ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"]) - return model_path + + model_path = snapshot_download( + model_id=pretrained_model_name_or_path, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, 
+ ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"]) + + return model_path + return pretrained_model_name_or_path def get_tokenizer( From 3963a5335bb4106f2ecd1139527e3568d2151933 Mon Sep 17 00:00:00 2001 From: aniaan Date: Thu, 11 Jul 2024 17:39:07 +0800 Subject: [PATCH 19/26] [Misc] refactor(config): clean up unused code (#6320) --- vllm/config.py | 6 ++---- vllm/worker/xpu_model_runner.py | 3 --- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 68ca81a2ec4fe..d333a042fe5af 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -138,12 +138,10 @@ def __init__( self.quantization = quantization self.quantization_param_path = quantization_param_path self.enforce_eager = enforce_eager - self.max_context_len_to_capture = max_context_len_to_capture - if self.max_context_len_to_capture is not None: + if max_context_len_to_capture is not None: raise ValueError("`max_context_len_to_capture` is deprecated. " "Use `max_seq_len_to_capture` instead.") - self.max_seq_len_to_capture = (max_seq_len_to_capture - or max_context_len_to_capture) + self.max_seq_len_to_capture = max_seq_len_to_capture self.max_logprobs = max_logprobs self.disable_sliding_window = disable_sliding_window self.skip_tokenizer_init = skip_tokenizer_init diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index e03f24fdfc41a..876abb3bf94d1 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -109,9 +109,6 @@ def __init__( self.kv_cache_dtype = kv_cache_dtype self.block_size = cache_config.block_size - self.max_context_len_to_capture = ( - self.model_config.max_context_len_to_capture - if self.model_config is not None else 0) self.attn_backend = get_attn_backend( self.model_config.get_num_attention_heads(self.parallel_config), From 546b101fa05043feb470513a778c31114ea3aa05 Mon Sep 17 00:00:00 2001 From: pushan <62173185+pushan01@users.noreply.github.com> Date: Thu, 11 Jul 2024 21:46:31 +0800 Subject: [PATCH 20/26] [BugFix]: fix engine timeout due to request abort (#6255) Signed-off-by: yatta zhang Signed-off-by: zhangyuntao.dev Co-authored-by: zhangyuntao.dev --- vllm/engine/async_llm_engine.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 9b4ef48b0e47e..f3c8d69e4efe9 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -553,11 +553,13 @@ async def engine_step(self, virtual_engine: int) -> bool: request_outputs = await self.engine.step_async(virtual_engine) # Put the outputs into the corresponding streams. + finished = True for request_output in request_outputs: self._request_tracker.process_request_output( request_output, verbose=self.log_requests) + finished = finished and request_output.finished - return len(request_outputs) > 0 + return not finished async def _engine_abort(self, request_ids: Iterable[str]): if self.engine_use_ray: From 8a1415cf776b2b902f6429ecfc325877b57cbefe Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Thu, 11 Jul 2024 16:05:59 +0200 Subject: [PATCH 21/26] [Bugfix] GPTBigCodeForCausalLM: Remove lm_head from supported_lora_modules. 
(#6326) Signed-off-by: Thomas Parnell Co-authored-by: Travis Johnson --- vllm/model_executor/models/gpt_bigcode.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index cc42413d53f4c..fc4e13bbb0e68 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -235,7 +235,7 @@ def forward( class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA): packed_modules_mapping = {"c_attn": ["c_attn"]} - supported_lora_modules = ["c_fc", "c_proj", "wte", "lm_head", "c_attn"] + supported_lora_modules = ["c_fc", "c_proj", "wte", "c_attn"] embedding_modules = { "wte": "input_embeddings", From 55f692b46ef35ed4a9e199dfe60a9eefe800e4b0 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Thu, 11 Jul 2024 17:40:20 +0300 Subject: [PATCH 22/26] [BugFix] get_and_reset only when scheduler outputs are not empty (#6266) --- vllm/engine/async_llm_engine.py | 4 ++-- vllm/engine/llm_engine.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index f3c8d69e4efe9..93bf8793dae33 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -225,11 +225,11 @@ async def step_async( """ seq_group_metadata_list, scheduler_outputs = self.scheduler[ virtual_engine].schedule() - finished_requests_ids = self.scheduler[ - virtual_engine].get_and_reset_finished_requests_ids() if not scheduler_outputs.is_empty(): # Execute the model. + finished_requests_ids = self.scheduler[ + virtual_engine].get_and_reset_finished_requests_ids() execute_model_req = ExecuteModelRequest( seq_group_metadata_list=seq_group_metadata_list, blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index b476594fc73f6..d354218cf16ea 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -871,10 +871,10 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: "as performance will be severely degraded otherwise.") seq_group_metadata_list, scheduler_outputs = self.scheduler[ 0].schedule() - finished_requests_ids = self.scheduler[ - 0].get_and_reset_finished_requests_ids() if not scheduler_outputs.is_empty(): + finished_requests_ids = self.scheduler[ + 0].get_and_reset_finished_requests_ids() execute_model_req = ExecuteModelRequest( seq_group_metadata_list=seq_group_metadata_list, blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in, From b675069d7486129dbed7847f420b7a927691f16b Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> Date: Thu, 11 Jul 2024 11:40:11 -0400 Subject: [PATCH 23/26] [ Misc ] Refactor Marlin Python Utilities (#6082) Co-authored-by: Robert Shaw --- benchmarks/kernels/benchmark_marlin.py | 10 +- tests/kernels/test_marlin_gemm.py | 46 +- tests/quantization/test_compressed_tensors.py | 23 +- .../schemes/compressed_tensors_wNa16.py | 151 +++--- .../model_executor/layers/quantization/fp8.py | 2 +- .../layers/quantization/gptq_marlin.py | 263 +++-------- .../quantization/utils/marlin_24_perms.py | 60 --- .../layers/quantization/utils/marlin_perms.py | 60 --- .../layers/quantization/utils/marlin_utils.py | 439 ++++++------------ .../quantization/utils/marlin_utils_fp8.py | 109 +++++ .../quantization/utils/marlin_utils_test.py | 120 +++++ .../{format_24.py => marlin_utils_test_24.py} | 163 ++++++- 12 files changed, 704 insertions(+), 742 
deletions(-) delete mode 100644 vllm/model_executor/layers/quantization/utils/marlin_24_perms.py delete mode 100644 vllm/model_executor/layers/quantization/utils/marlin_perms.py create mode 100644 vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py create mode 100644 vllm/model_executor/layers/quantization/utils/marlin_utils_test.py rename vllm/model_executor/layers/quantization/utils/{format_24.py => marlin_utils_test_24.py} (71%) diff --git a/benchmarks/kernels/benchmark_marlin.py b/benchmarks/kernels/benchmark_marlin.py index 261f5829631ee..3da4cecd7eeff 100644 --- a/benchmarks/kernels/benchmark_marlin.py +++ b/benchmarks/kernels/benchmark_marlin.py @@ -5,14 +5,16 @@ from benchmark_shapes import WEIGHT_SHAPES from vllm import _custom_ops as ops -from vllm.model_executor.layers.quantization.gptq_marlin import ( - GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, - GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS) from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_NUM_BITS) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - MarlinWorkspace, marlin_24_quantize, marlin_quantize) + GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, + GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS) +from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( + MarlinWorkspace, marlin_quantize) +from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import ( + marlin_24_quantize) from vllm.model_executor.layers.quantization.utils.quant_utils import ( gptq_pack, quantize_weights, sort_weights) from vllm.utils import FlexibleArgumentParser diff --git a/tests/kernels/test_marlin_gemm.py b/tests/kernels/test_marlin_gemm.py index 92ddcb209b690..3bd6680cf8134 100644 --- a/tests/kernels/test_marlin_gemm.py +++ b/tests/kernels/test_marlin_gemm.py @@ -5,19 +5,21 @@ import pytest import torch +from tests.quantization.utils import is_quant_method_supported from vllm import _custom_ops as ops -from vllm.model_executor.layers.quantization.gptq_marlin import ( - GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, - GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS, - marlin_permute_scales) from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_NUM_BITS) -from vllm.model_executor.layers.quantization.utils.marlin_perms import ( - marlin_perm) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - MarlinWorkspace, compute_max_diff, is_marlin_supported, marlin_24_quantize, - marlin_quantize, marlin_weights, pack_fp8_to_int32) + GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, + GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS, + marlin_permute_scales) +from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( + pack_fp8_to_int32) +from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( + MarlinWorkspace, get_weight_perm, marlin_quantize, marlin_weights) +from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import ( + marlin_24_quantize) from vllm.model_executor.layers.quantization.utils.quant_utils import ( gptq_pack, quantize_weights, sort_weights) @@ -42,11 +44,16 @@ DTYPES = [torch.float16, torch.bfloat16] +def 
compute_max_diff(output, output_ref): + return torch.mean(torch.abs(output - output_ref)) / torch.mean( + torch.abs(output_ref)) + + def rand_data(shape, dtype=torch.float16): return torch.randn(shape, dtype=dtype, device="cuda") -@pytest.mark.skipif(not is_marlin_supported(), +@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"), reason="Marlin is not supported on this GPU type.") @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) @pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) @@ -93,8 +100,8 @@ def test_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order, q_w, g_idx, sort_indices = sort_weights(q_w, g_idx) # Pack to Marlin format - marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, num_bits, - marlin_perm[num_bits]) + weight_perm = get_weight_perm(num_bits) + marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm) # Run Marlin repack GPU kernel marlin_q_w_2 = ops.gptq_marlin_repack( @@ -109,7 +116,7 @@ def test_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order, assert torch.allclose(marlin_q_w_1, marlin_q_w_2) -@pytest.mark.skipif(not is_marlin_supported(), +@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"), reason="Marlin is not supported on this GPU type.") @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) @pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) @@ -174,7 +181,7 @@ def test_marlin_gemm( assert max_diff < 0.04 -@pytest.mark.skipif(not is_marlin_supported(), +@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"), reason="Marlin is not supported on this GPU type.") @pytest.mark.parametrize("k_chunk", MARLIN_24_K_CHUNKS) @pytest.mark.parametrize("n_chunk", MARLIN_24_N_CHUNKS) @@ -222,7 +229,7 @@ def test_marlin_24_gemm(k_chunk, n_chunk, num_bits, group_size, mnk_factors): assert max_diff < 0.04 -@pytest.mark.skipif(not is_marlin_supported(), +@pytest.mark.skipif(not is_quant_method_supported("fp8"), reason="Marlin is not supported on this GPU type.") @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) @pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) @@ -268,13 +275,10 @@ def test_fp8_marlin_gemm( # expand it to channelwise scales = weight_scale.repeat(1, size_n).to(a_input.dtype).to("cuda") # Permute scales - marlin_scales = marlin_permute_scales( - s=scales, - size_k=size_k, - size_n=size_n, - group_size=-1, - num_bits=8, - ) + marlin_scales = marlin_permute_scales(s=scales, + size_k=size_k, + size_n=size_n, + group_size=-1) workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL) diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py index 96223a247657b..888e20e51a842 100644 --- a/tests/quantization/test_compressed_tensors.py +++ b/tests/quantization/test_compressed_tensors.py @@ -6,7 +6,6 @@ import pytest import torch -from vllm import SamplingParams from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 CompressedTensorsLinearMethod, CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8, @@ -57,12 +56,14 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args): assert qkv_proj.weight_scale.dtype is torch.float32 assert qkv_proj.input_scale.dtype is torch.float32 + output = llm.generate_greedy("Hello my name is", max_tokens=20) + assert output + def test_compressed_tensors_no_enforce_eager(vllm_runner): model_path = "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change" with 
vllm_runner(model_path) as llm: - sampling_params = SamplingParams() - output = llm.generate("Hello world!", sampling_params=sampling_params) + output = llm.generate_greedy("Hello my name is", max_tokens=20) assert output @@ -84,13 +85,16 @@ def test_compressed_tensors_w8a8_dynanmic_per_token(vllm_runner, model_args): assert qkv_proj.scheme.strategy == strategy assert qkv_proj.weight.dtype is torch.int8 + output = llm.generate_greedy("Hello my name is", max_tokens=20) + assert output + @pytest.mark.parametrize( "wNa16_args", [("nm-testing/tinyllama-oneshot-w4a16-channel-v2", "channel", None, 8), ("nm-testing/tinyllama-oneshot-w4a16-group128-v2", "group", 128, 8), ("nm-testing/tinyllama-oneshot-w8a16-per-channel", "channel", None, 4)]) -def test_compressed_tensors_w4a16(vllm_runner, wNa16_args): +def test_compressed_tensors_wNa16(vllm_runner, wNa16_args): model, strategy, group, pack_factor = wNa16_args with vllm_runner(model) as llm: model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501 @@ -101,12 +105,15 @@ def test_compressed_tensors_w4a16(vllm_runner, wNa16_args): assert isinstance(qkv_proj.scheme, CompressedTensorsWNA16) assert qkv_proj.scheme.strategy == strategy - assert qkv_proj.scheme.group_size == group + assert qkv_proj.scheme.group_size == (-1 if group is None else group) assert qkv_proj.weight_packed.dtype is torch.int32 assert qkv_proj.weight_scale.dtype is torch.float16 assert qkv_proj.weight_packed.pack_factor == pack_factor + output = llm.generate_greedy("Hello my name is", max_tokens=20) + assert output + def test_compressed_tensors_w4a16_marlin24(vllm_runner): model_path = "nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t" @@ -120,8 +127,7 @@ def test_compressed_tensors_w4a16_marlin24(vllm_runner): assert isinstance(qkv_proj.scheme, CompressedTensorsW4A16Sparse24) assert qkv_proj.weight_packed.dtype is torch.int32 - sampling_params = SamplingParams() - output = llm.generate("Hello world!", sampling_params=sampling_params) + output = llm.generate_greedy("Hello my name is", max_tokens=20) assert output @@ -142,6 +148,5 @@ def test_compressed_tensors_fp8(vllm_runner): assert len(qkv_proj.input_scale.shape) == 0 assert len(qkv_proj.weight_scale.shape) == 0 - sampling_params = SamplingParams() - output = llm.generate("Hello world!", sampling_params=sampling_params) + output = llm.generate_greedy("Hello my name is", max_tokens=20) assert output diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py index 2243260053ef5..ed9fa73c175a1 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py @@ -6,9 +6,10 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme) -from vllm.model_executor.layers.quantization.gptq_marlin import ( - GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, GPTQMarlinState, - marlin_permute_scales) +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + apply_marlin_linear, marlin_make_empty_g_idx, marlin_make_workspace, + marlin_permute_scales, replace_tensor, verify_marlin_supported, + verify_marlin_supports_shape) from vllm.model_executor.utils import set_weight_attrs __all__ = ["CompressedTensorsWNA16"] @@ -22,29 
+23,40 @@ def __init__(self, num_bits: int, group_size: Optional[int] = None): self.num_bits = num_bits + self.pack_factor = 32 // self.num_bits self.strategy = strategy - self.group_size = group_size - if self.strategy == "group" and self.group_size is None: - raise ValueError( - "group_size must be given when using strategy group") + self.group_size: int + if group_size is None: + if self.strategy != "channel": + raise ValueError( + "Marlin kernels require group quantization or " + "channelwise quantization, but found no group " + "size and strategy is not channelwise.") + self.group_size = -1 + else: + self.group_size = group_size - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - pass + # Verify supported on platform. + verify_marlin_supported(num_bits=self.num_bits, + group_size=self.group_size, + is_sym=True) def create_weights(self, layer: torch.nn.Module, input_size: int, output_partition_sizes: List[int], input_size_per_partition: int, params_dtype: torch.dtype, weight_loader: Callable, **kwargs): - - pack_factor = 32 // self.num_bits output_size_per_partition = sum(output_partition_sizes) - if self.group_size is not None: - group_size = self.group_size - else: - group_size = input_size + # If group_size is -1, we are in channelwise case. + group_size = input_size if self.group_size == -1 else self.group_size + + verify_marlin_supports_shape( + output_size_per_partition=output_size_per_partition, + input_size_per_partition=input_size_per_partition, + input_size=input_size, + group_size=group_size) weight_scale_dim = None scales_and_zp_size = input_size // group_size @@ -57,7 +69,7 @@ def create_weights(self, layer: torch.nn.Module, input_size: int, weight = Parameter( torch.empty( output_size_per_partition, - input_size_per_partition // pack_factor, + input_size_per_partition // self.pack_factor, dtype=torch.int32, ), requires_grad=False, @@ -68,7 +80,7 @@ def create_weights(self, layer: torch.nn.Module, input_size: int, "input_dim": 1, "output_dim": 0, "packed_dim": 1, - "pack_factor": pack_factor, + "pack_factor": self.pack_factor, "weight_loader": weight_loader }) layer.register_parameter("weight_packed", weight) @@ -103,73 +115,48 @@ def create_weights(self, layer: torch.nn.Module, input_size: int, layer.input_size_per_partition = input_size_per_partition layer.output_size_per_partition = output_size_per_partition - layer.input_size = input_size - layer.marlin_state = GPTQMarlinState.REPACK - layer.is_k_full = True layer.group_size = group_size - max_workspace_size = ( - output_size_per_partition // - GPTQ_MARLIN_MIN_THREAD_N) * GPTQ_MARLIN_MAX_PARALLEL - - workspace = torch.zeros(max_workspace_size, - dtype=torch.int, - requires_grad=False) - layer.workspace = workspace + # Checkpoints are serialized in compressed-tensors format, which is + # different from marlin format. Handle repacking here. + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + device = layer.weight_packed.device + + # Allocate marlin workspace. + layer.workspace = marlin_make_workspace( + layer.output_size_per_partition, device) + + # Act-order not supported in compressed-tensors yet, so set to empty. + layer.g_idx = marlin_make_empty_g_idx(device) + layer.g_idx_sort_indices = marlin_make_empty_g_idx(device) + + # Repack weights from compressed-tensors format to marlin format. 
+ marlin_qweight = ops.gptq_marlin_repack( + layer.weight_packed.t().contiguous(), + perm=layer.g_idx_sort_indices, + size_k=layer.input_size_per_partition, + size_n=layer.output_size_per_partition, + num_bits=self.num_bits) + replace_tensor(layer, "weight_packed", marlin_qweight) + + # Permute scales from compressed-tensors format to marlin format. + marlin_scales = marlin_permute_scales( + layer.weight_scale.squeeze().t().contiguous(), + size_k=layer.input_size_per_partition, + size_n=layer.output_size_per_partition, + group_size=layer.group_size) + replace_tensor(layer, "weight_scale", marlin_scales) def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor): - reshaped_x = x.reshape(-1, x.shape[-1]) - - size_m = reshaped_x.shape[0] - part_size_n = layer.output_size_per_partition - part_size_k = layer.input_size_per_partition - - out_shape = x.shape[:-1] + (part_size_n, ) - - if layer.marlin_state == GPTQMarlinState.REPACK: - layer.marlin_state = GPTQMarlinState.READY - - # Newly generated tensors need to replace existing tensors that are - # already registered as parameters by vLLM (and won't be freed) - def replace_tensor(name, new_t): - # It is important to use resize_() here since it ensures - # the same buffer is reused - getattr(layer, name).resize_(new_t.shape) - getattr(layer, name).copy_(new_t) - del new_t - - cur_device = layer.weight_packed.device - - # Reset g_idx related tensors - layer.g_idx = Parameter(torch.empty(0, - dtype=torch.int, - device=cur_device), - requires_grad=False) - layer.g_idx_sort_indices = Parameter(torch.empty( - 0, dtype=torch.int, device=cur_device), - requires_grad=False) - - # Repack weights - marlin_qweight = ops.gptq_marlin_repack( - layer.weight_packed.t().contiguous(), layer.g_idx_sort_indices, - part_size_k, part_size_n, self.num_bits) - - replace_tensor("weight_packed", marlin_qweight) - - # Permute scales - scales_size_k = part_size_k - scales_size_n = part_size_n - - marlin_scales = marlin_permute_scales( - layer.weight_scale.squeeze().t().contiguous(), scales_size_k, - scales_size_n, layer.group_size, self.num_bits) - replace_tensor("weight_scale", marlin_scales) - - output = ops.gptq_marlin_gemm(reshaped_x, layer.weight_packed, - layer.weight_scale, layer.g_idx, - layer.g_idx_sort_indices, - layer.workspace, self.num_bits, size_m, - part_size_n, part_size_k, - layer.is_k_full) - return output.reshape(out_shape) + return apply_marlin_linear( + input=x, + weight=layer.weight_packed, + weight_scale=layer.weight_scale, + g_idx=layer.g_idx, + g_idx_sort_indices=layer.g_idx_sort_indices, + workspace=layer.workspace, + num_bits=self.num_bits, + output_size_per_partition=layer.output_size_per_partition, + input_size_per_partition=layer.input_size_per_partition, + is_k_full=True) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 8dba9019f94cf..0c2d2bd3fabe5 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -11,7 +11,7 @@ from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) -from vllm.model_executor.layers.quantization.utils.marlin_utils import ( +from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( all_close_1d, apply_fp8_linear, 
create_per_tensor_scale_param, diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 6b971f73d45bf..7b808f5216d57 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -1,5 +1,3 @@ -import enum -from enum import Enum from typing import Any, Dict, List, Optional import torch @@ -12,46 +10,14 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_K, - GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, - GPTQ_MARLIN_SUPPORTED_NUM_BITS, GPTQ_MARLIN_SUPPORTED_SYM, - GPTQ_MARLIN_TILE) + check_marlin_supported, marlin_make_empty_g_idx, marlin_make_workspace, + marlin_permute_scales, marlin_sort_g_idx, replace_tensor, + verify_marlin_supported, verify_marlin_supports_shape) from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead -from vllm.platforms import current_platform logger = init_logger(__name__) -# Permutations for Marlin scale shuffling -def get_scale_perms(num_bits: int): - scale_perm: List[int] = [] - for i in range(8): - scale_perm.extend([i + 8 * j for j in range(8)]) - scale_perm_single: List[int] = [] - for i in range(4): - scale_perm_single.extend( - [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) - return scale_perm, scale_perm_single - - -def get_pack_factor(num_bits: int): - assert (num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS - ), f"Unsupported num_bits = {num_bits}" - return 32 // num_bits - - -def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int, - group_size: int, num_bits: int): - scale_perm, scale_perm_single = get_scale_perms(num_bits) - if group_size < size_k and group_size != -1: - s = s.reshape((-1, len(scale_perm)))[:, scale_perm] - else: - s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] - s = s.reshape((-1, size_n)).contiguous() - - return s - - class GPTQMarlinConfig(QuantizationConfig): """Config class for GPTQ Marlin""" @@ -63,33 +29,16 @@ def __init__(self, weight_bits: int, group_size: int, desc_act: bool, desc_act = False self.weight_bits = weight_bits + self.pack_factor = 32 // self.weight_bits # packed into int32 self.group_size = group_size self.desc_act = desc_act self.is_sym = is_sym self.lm_head_quantized = lm_head_quantized - # Verify - if self.weight_bits not in GPTQ_MARLIN_SUPPORTED_NUM_BITS: - raise ValueError( - f"Marlin does not support weight_bits = {self.weight_bits}. " - f"Only weight_bits = {GPTQ_MARLIN_SUPPORTED_NUM_BITS} " - "are supported.") - if self.group_size not in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES: - raise ValueError( - f"Marlin does not support group_size = {self.group_size}. " - f"Only group_sizes = {GPTQ_MARLIN_SUPPORTED_GROUP_SIZES} " - "are supported.") - if self.is_sym not in GPTQ_MARLIN_SUPPORTED_SYM: - raise ValueError( - f"Marlin does not support is_sym = {self.is_sym}. " - f"Only sym = {GPTQ_MARLIN_SUPPORTED_SYM} are supported.") - - # Init - self.pack_factor = get_pack_factor(weight_bits) - self.tile_size = GPTQ_MARLIN_TILE - self.min_thread_n = GPTQ_MARLIN_MIN_THREAD_N - self.min_thread_k = GPTQ_MARLIN_MIN_THREAD_K - self.max_parallel = GPTQ_MARLIN_MAX_PARALLEL + # Verify supported on platform. 
+ verify_marlin_supported(num_bits=self.weight_bits, + group_size=self.group_size, + is_sym=self.is_sym) def __repr__(self) -> str: return (f"GPTQMarlinConfig(weight_bits={self.weight_bits}, " @@ -168,21 +117,10 @@ def is_marlin_compatible(cls, quant_config: Dict[str, Any]): or desc_act is None): return False - # If the capability of the device is too low, cannot convert. - major, minor = current_platform.get_device_capability() - device_capability = major * 10 + minor - if device_capability < cls.get_min_capability(): - return False - - # Otherwise, can convert if model satisfies marlin constraints. - return (num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS - and group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES - and sym in GPTQ_MARLIN_SUPPORTED_SYM) - - -class GPTQMarlinState(Enum): - REPACK = enum.auto() - READY = enum.auto() + return check_marlin_supported(num_bits=num_bits, + group_size=group_size, + is_sym=sym, + min_capability=cls.get_min_capability()) class GPTQMarlinLinearMethod(LinearMethodBase): @@ -206,6 +144,7 @@ def create_weights( **extra_weight_attrs, ) -> None: del output_size + output_size_per_partition = sum(output_partition_sizes) # Normalize group_size if self.quant_config.group_size != -1: @@ -213,31 +152,11 @@ def create_weights( else: group_size = input_size - # Validate dtype - if params_dtype not in [torch.float16, torch.bfloat16]: - raise ValueError(f"The params dtype must be float16 " - f"or bfloat16, but got {params_dtype}") - - # Validate output_size_per_partition - output_size_per_partition = sum(output_partition_sizes) - if output_size_per_partition % self.quant_config.min_thread_n != 0: - raise ValueError( - f"Weight output_size_per_partition = " - f"{output_size_per_partition} is not divisible by " - f" min_thread_n = {self.quant_config.min_thread_n}.") - - # Validate input_size_per_partition - if input_size_per_partition % self.quant_config.min_thread_k != 0: - raise ValueError( - f"Weight input_size_per_partition = " - f"{input_size_per_partition} is not divisible " - f"by min_thread_k = {self.quant_config.min_thread_k}.") - - if (group_size < input_size - and input_size_per_partition % group_size != 0): - raise ValueError( - f"Weight input_size_per_partition = {input_size_per_partition}" - f" is not divisible by group_size = {group_size}.") + verify_marlin_supports_shape( + output_size_per_partition=output_size_per_partition, + input_size_per_partition=input_size_per_partition, + input_size=input_size, + group_size=group_size) # Detect sharding of scales/zp @@ -303,11 +222,6 @@ def create_weights( }, ) - g_idx_sort_indices = torch.empty( - g_idx.shape, - dtype=torch.int32, - ) - # Scales scales = Parameter( torch.empty( @@ -347,25 +261,50 @@ def create_weights( }, ) - # Allocate marlin workspace - max_workspace_size = ( - output_size_per_partition // - self.quant_config.min_thread_n) * self.quant_config.max_parallel - workspace = torch.zeros(max_workspace_size, - dtype=torch.int, - requires_grad=False) - layer.register_parameter("qweight", qweight) layer.register_parameter("g_idx", g_idx) layer.register_parameter("scales", scales) layer.register_parameter("qzeros", qzeros) - layer.g_idx_sort_indices = g_idx_sort_indices - layer.workspace = workspace layer.input_size_per_partition = input_size_per_partition layer.output_size_per_partition = output_size_per_partition layer.input_size = input_size layer.is_k_full = is_k_full - layer.marlin_state = GPTQMarlinState.REPACK + + # Checkpoints are serialized in AutoGPTQ format, which is different from the + # marlin 
format. This function is called after the weights are loaded. + # Here, we handle the repacking, including the activation reordering case. + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + device = layer.qweight.device + # Allocate marlin workspace + layer.workspace = marlin_make_workspace( + layer.output_size_per_partition, device) + + # Handle sorting for activation reordering if needed. + if self.quant_config.desc_act: + g_idx, g_idx_sort_indices = marlin_sort_g_idx(layer.g_idx) + layer.g_idx_sort_indices = g_idx_sort_indices + replace_tensor(layer, "g_idx", g_idx) + else: + layer.g_idx = marlin_make_empty_g_idx(device) + layer.g_idx_sort_indices = marlin_make_empty_g_idx(device) + + # Repack weights from autogptq format to marlin format. + marlin_qweight = ops.gptq_marlin_repack( + layer.qweight, + perm=layer.g_idx_sort_indices, + size_k=layer.input_size_per_partition, + size_n=layer.output_size_per_partition, + num_bits=self.quant_config.weight_bits) + replace_tensor(layer, "qweight", marlin_qweight) + + # Permute scales from autogptq format to marlin format. + marlin_scales = marlin_permute_scales( + layer.scales, + size_k=(layer.input_size if self.quant_config.desc_act else + layer.input_size_per_partition), + size_n=layer.output_size_per_partition, + group_size=self.quant_config.group_size) + replace_tensor(layer, "scales", marlin_scales) def apply( self, @@ -374,87 +313,19 @@ def apply( bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: reshaped_x = x.reshape(-1, x.shape[-1]) - - size_m = reshaped_x.shape[0] - part_size_n = layer.output_size_per_partition - part_size_k = layer.input_size_per_partition - full_size_k = layer.input_size - - out_shape = x.shape[:-1] + (part_size_n, ) - - if layer.marlin_state == GPTQMarlinState.REPACK: - layer.marlin_state = GPTQMarlinState.READY - - # Newly generated tensors need to replace existing tensors that are - # already registered as parameters by vLLM (and won't be freed) - def replace_tensor(name, new_t): - # It is important to use resize_() here since it ensures - # the same buffer is reused - getattr(layer, name).resize_(new_t.shape) - getattr(layer, name).copy_(new_t) - del new_t - - cur_device = layer.qweight.device - - # Process act_order - if self.quant_config.desc_act: - # Get sorting based on g_idx - g_idx_sort_indices = torch.argsort(layer.g_idx).to(torch.int) - - sorted_g_idx = layer.g_idx[g_idx_sort_indices] - - replace_tensor("g_idx", sorted_g_idx) - replace_tensor("g_idx_sort_indices", g_idx_sort_indices) - - else: - # Reset g_idx related tensors - layer.g_idx = Parameter( - torch.empty(0, dtype=torch.int, device=cur_device), - requires_grad=False, - ) - layer.g_idx_sort_indices = Parameter( - torch.empty(0, dtype=torch.int, device=cur_device), - requires_grad=False, - ) - - # Repack weights - marlin_qweight = ops.gptq_marlin_repack( - layer.qweight, - layer.g_idx_sort_indices, - part_size_k, - part_size_n, - self.quant_config.weight_bits, - ) - replace_tensor("qweight", marlin_qweight) - - # Permute scales - scales_size_k = part_size_k - scales_size_n = part_size_n - if self.quant_config.desc_act: - scales_size_k = full_size_k - - marlin_scales = marlin_permute_scales( - layer.scales, - scales_size_k, - scales_size_n, - self.quant_config.group_size, - self.quant_config.weight_bits, - ) - replace_tensor("scales", marlin_scales) - - output = ops.gptq_marlin_gemm( - reshaped_x, - layer.qweight, - layer.scales, - layer.g_idx, - layer.g_idx_sort_indices, - layer.workspace, - 
self.quant_config.weight_bits, - size_m, - part_size_n, - part_size_k, - layer.is_k_full, - ) + out_shape = x.shape[:-1] + (layer.output_size_per_partition, ) + + output = ops.gptq_marlin_gemm(reshaped_x, + layer.qweight, + layer.scales, + g_idx=layer.g_idx, + perm=layer.g_idx_sort_indices, + workspace=layer.workspace, + num_bits=self.quant_config.weight_bits, + size_m=reshaped_x.shape[0], + size_n=layer.output_size_per_partition, + size_k=layer.input_size_per_partition, + is_k_full=layer.is_k_full) if bias is not None: output.add_(bias) # In-place add diff --git a/vllm/model_executor/layers/quantization/utils/marlin_24_perms.py b/vllm/model_executor/layers/quantization/utils/marlin_24_perms.py deleted file mode 100644 index 93f65a20d4e4a..0000000000000 --- a/vllm/model_executor/layers/quantization/utils/marlin_24_perms.py +++ /dev/null @@ -1,60 +0,0 @@ -"""This file is used for /tests and /benchmarks""" -from typing import Dict, List - -import numpy -import torch - - -# Precompute permutations for Marlin24 weight and scale shuffling # noqa: E501 -# -# Marlin works on [16*2,64] tiles. The goal of the permutations is to reorder the weight data so that it is compatible noqa: # noqa: E501 -# with the tensor-core format that is described here: -# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501 -# -# As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core # noqa: E501 -# (without the need to use ldmatrix instructions) # noqa: E501 -def get_perms_24(num_bits: int): - perm_list: List[int] = [] - for i in range(32): - perm1: List[int] = [] - col = i // 4 - col_o = col // 2 - for block in [0, 1]: - for row in [ - 2 * (i % 4), - 2 * (i % 4) + 1, - 2 * (i % 4 + 4), - 2 * (i % 4 + 4) + 1, - ]: - perm1.append(16 * row + col_o * 256 + 8 * (col % 2) + - 4 * block) - for j in range(4): - perm_list.extend([p + 1 * j for p in perm1]) - perm = numpy.array(perm_list) - - if num_bits == 4: - interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) - elif num_bits == 8: - interleave = numpy.array([0, 2, 1, 3]) - else: - raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits)) - - perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() - perm = torch.from_numpy(perm) - scale_perm: List[int] = [] - for i in range(8): - scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]]) - scale_perm_single: List[int] = [] - for i in range(8): - scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]]) - return perm, scale_perm, scale_perm_single - - -marlin_24_perm: Dict[int, torch.Tensor] = {} -marlin_24_scale_perm: Dict[int, List[int]] = {} -marlin_24_scale_perm_single: Dict[int, List[int]] = {} -for num_bits in [4, 8]: - perm_24, scale_perm_24, scale_perm_single_24 = get_perms_24(num_bits) - marlin_24_perm[num_bits] = perm_24 - marlin_24_scale_perm[num_bits] = scale_perm_24 - marlin_24_scale_perm_single[num_bits] = scale_perm_single_24 diff --git a/vllm/model_executor/layers/quantization/utils/marlin_perms.py b/vllm/model_executor/layers/quantization/utils/marlin_perms.py deleted file mode 100644 index db5e6857a8846..0000000000000 --- a/vllm/model_executor/layers/quantization/utils/marlin_perms.py +++ /dev/null @@ -1,60 +0,0 @@ -"""This file is used for /tests and /benchmarks""" -from typing import Dict, List - -import numpy -import torch - - -# Precompute permutations for Marlin weight and scale shuffling # noqa: E501 -# -# Marlin 
works on [16,64] tiles. The goal of the permutations is to reorder the weight data so that it is compatible noqa: # noqa: E501 -# with the tensor-core format that is described here: -# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501 -# -# As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core # noqa: E501 -# (without the need to use ldmatrix instructions) # noqa: E501 -def get_perms(num_bits: int): - perm_list: List[int] = [] - for i in range(32): - perm1: List[int] = [] - col = i // 4 - for block in [0, 1]: - for row in [ - 2 * (i % 4), - 2 * (i % 4) + 1, - 2 * (i % 4 + 4), - 2 * (i % 4 + 4) + 1, - ]: - perm1.append(16 * row + col + 8 * block) - for j in range(4): - perm_list.extend([p + 256 * j for p in perm1]) - - perm = numpy.array(perm_list) - - if num_bits == 4: - interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) - elif num_bits == 8: - interleave = numpy.array([0, 2, 1, 3]) - else: - raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) - - perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() - perm = torch.from_numpy(perm) - scale_perm: List[int] = [] - for i in range(8): - scale_perm.extend([i + 8 * j for j in range(8)]) - scale_perm_single: List[int] = [] - for i in range(4): - scale_perm_single.extend( - [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) - return perm, scale_perm, scale_perm_single - - -marlin_perm: Dict[int, torch.Tensor] = {} -marlin_scale_perm: Dict[int, List[int]] = {} -marlin_scale_perm_single: Dict[int, List[int]] = {} -for num_bits in [4, 8]: - perm, scale_perm, scale_perm_single = get_perms(num_bits) - marlin_perm[num_bits] = perm - marlin_scale_perm[num_bits] = scale_perm - marlin_scale_perm_single[num_bits] = scale_perm_single diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index 9886245269ad3..612c5fd20093a 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -1,21 +1,9 @@ -"""This file is used for /tests and /benchmarks""" -import random -from typing import Optional +from typing import List, Optional, Tuple -import numpy import torch from vllm import _custom_ops as ops -from vllm.model_executor.layers.quantization.utils.format_24 import ( - mask_creator, sparse_semi_structured_from_dense_cutlass) -from vllm.model_executor.layers.quantization.utils.marlin_24_perms import ( - marlin_24_perm, marlin_24_scale_perm, marlin_24_scale_perm_single) -from vllm.model_executor.layers.quantization.utils.marlin_perms import ( - marlin_perm, marlin_scale_perm, marlin_scale_perm_single) -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - get_pack_factor, quantize_weights, sort_weights) from vllm.platforms import current_platform -from vllm.utils import print_warning_once GPTQ_MARLIN_TILE = 16 GPTQ_MARLIN_MIN_THREAD_N = 64 @@ -25,135 +13,110 @@ GPTQ_MARLIN_SUPPORTED_NUM_BITS = [4, 8] GPTQ_MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] GPTQ_MARLIN_SUPPORTED_SYM = [True] - - -def is_marlin_supported(): - capability = current_platform.get_device_capability() - return capability[0] >= 8 - - -def apply_fp8_marlin_linear( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - workspace: torch.Tensor, - size_n: int, - size_k: int, - bias: Optional[torch.Tensor], -) -> 
torch.Tensor: - # For GPUs that lack FP8 hardware support, we can leverage the - # Marlin kernel for fast weight-only FP8 quantization - - reshaped_x = input.reshape(-1, input.shape[-1]) - out_shape = input.shape[:-1] + (size_n, ) - - output = ops.fp8_marlin_gemm( - a=reshaped_x, - b_q_weight=weight, - b_scales=weight_scale, - workspace=workspace, - num_bits=8, - size_m=reshaped_x.shape[0], - size_n=size_n, - size_k=size_k, - ) - - if bias is not None: - output.add_(bias) # In-place add - - return output.reshape(out_shape) - - -def prepare_fp8_layer_for_marlin(layer: torch.nn.Module) -> None: - print_warning_once( - "Your GPU does not have native support for FP8 computation but " - "FP8 quantization is being used. Weight-only FP8 compression will " - "be used leveraging the Marlin kernel. This may degrade " - "performance for compute-heavy workloads.") - - part_size_n = layer.output_size_per_partition - part_size_k = layer.input_size_per_partition - - device = layer.weight.device - - # WEIGHTS - # Repack weights to gptq format (packed int32 elements) - packed_gptq_qweight = pack_fp8_to_int32(layer.weight) - - # Repack weights to marlin format - marlin_qweight = ops.gptq_marlin_repack( - b_q_weight=packed_gptq_qweight, - perm=torch.empty(0, dtype=torch.int, device=device), - size_k=part_size_k, - size_n=part_size_n, - num_bits=8, - ) - layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False) - - # WEIGHT SCALES - # Currently Marlin doesn't support per-tensor scales, so we - # expand it to channelwise - scales = layer.weight_scale.repeat(1, part_size_n).to( - layer.orig_dtype).to(device) - # Permute scales - num_bits = 8 - marlin_scales = marlin_permute_scales( - s=scales, - size_k=part_size_k, - size_n=part_size_n, - group_size=-1, - scale_perm=marlin_scale_perm[num_bits], - scale_perm_single=marlin_scale_perm_single[num_bits]) - layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False) - - # Allocate marlin workspace - max_workspace_size = (part_size_n // +GTPQ_MARLIN_UNSUPPORTED_GROUP_SIZE_ACT_ORDER = [-1] + + +def check_marlin_supported(num_bits: int, group_size: int, is_sym: bool, + min_capability: int) -> bool: + + # If the capability of the device is too low, cannot convert. + major, minor = current_platform.get_device_capability() + device_capability = major * 10 + minor + if device_capability < min_capability: + return False + + return (device_capability >= min_capability + and num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS + and group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES + and is_sym in GPTQ_MARLIN_SUPPORTED_SYM) + + +def verify_marlin_supported(num_bits: int, group_size: Optional[int], + is_sym: bool) -> None: + + if num_bits not in GPTQ_MARLIN_SUPPORTED_NUM_BITS: + raise ValueError( + f"Marlin does not support weight_bits = {num_bits}. " + f"Only weight_bits = {GPTQ_MARLIN_SUPPORTED_NUM_BITS} " + "are supported.") + if (group_size is None + or group_size not in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES): + raise ValueError( + f"Marlin does not support group_size = {group_size}. " + f"Only group_sizes = {GPTQ_MARLIN_SUPPORTED_GROUP_SIZES} " + "are supported.") + if is_sym not in GPTQ_MARLIN_SUPPORTED_SYM: + raise ValueError( + f"Marlin does not support is_sym = is_sym. 
" + f"Only sym = {GPTQ_MARLIN_SUPPORTED_SYM} are supported.") + + +def verify_marlin_supports_shape(output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, group_size: int) -> None: + + # Validate output_size_per_partition + if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0: + raise ValueError(f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq.") + + # Validate input_size_per_partition + if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0: + raise ValueError(f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible " + f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq.") + + if (group_size < input_size + and input_size_per_partition % group_size != 0): + raise ValueError( + f"Weight input_size_per_partition = {input_size_per_partition}" + f" is not divisible by group_size = {group_size}." + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq.") + + +def marlin_make_workspace(output_size_per_partition: int, + device: torch.device) -> torch.Tensor: + max_workspace_size = (output_size_per_partition // GPTQ_MARLIN_MIN_THREAD_N) * GPTQ_MARLIN_MAX_PARALLEL - workspace = torch.zeros(max_workspace_size, - dtype=torch.int, - device=device, - requires_grad=False) - - layer.workspace = workspace - - -def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE): - assert q_w.shape == (size_k, size_n) - assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}" - assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}" - - # Permute weights to 16x64 marlin tiles - q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile)) - q_w = q_w.permute((0, 2, 1, 3)) - q_w = q_w.reshape((size_k // tile, size_n * tile)) - q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape) + return torch.zeros(max_workspace_size, + dtype=torch.int, + device=device, + requires_grad=False) - return q_w +def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor: + return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device), + requires_grad=False) -def marlin_weights(q_w, size_k, size_n, num_bits, perm): - # Permute - q_w = marlin_permute_weights(q_w, size_k, size_n, perm) - # Pack - pack_factor = get_pack_factor(num_bits) - orig_device = q_w.device +def marlin_sort_g_idx( + g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + g_idx_sort_indices = torch.argsort(g_idx).to(torch.int) + return g_idx[g_idx_sort_indices], g_idx_sort_indices - q_w = q_w.cpu().numpy().astype(numpy.uint32) - q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), - dtype=numpy.uint32) - for i in range(pack_factor): - q_packed |= q_w[:, i::pack_factor] << num_bits * i +def get_scale_perms(): + scale_perm: List[int] = [] + for i in range(8): + scale_perm.extend([i + 8 * j for j in range(8)]) + scale_perm_single: List[int] = [] + for i in range(4): + scale_perm_single.extend( + [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + return scale_perm, scale_perm_single - q_packed = torch.from_numpy(q_packed.astype(numpy.int32)).to(orig_device) - return q_packed +def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int, + group_size: int) -> torch.Tensor: - -def marlin_permute_scales(s, size_k, size_n, group_size, 
scale_perm, - scale_perm_single): + scale_perm, scale_perm_single = get_scale_perms() if group_size < size_k and group_size != -1: s = s.reshape((-1, len(scale_perm)))[:, scale_perm] else: @@ -163,180 +126,44 @@ def marlin_permute_scales(s, size_k, size_n, group_size, scale_perm, return s -def marlin_quantize( - w: torch.Tensor, - num_bits: int, - group_size: int, - act_order: bool, -): - size_k, size_n = w.shape - - # Normalize group_size - if group_size == -1: - group_size = size_k - assert group_size <= size_k - - # Quantize (and apply act_order if provided) - w_ref, q_w, s, g_idx, rand_perm = quantize_weights(w, num_bits, group_size, - act_order) - - # For act_order, sort the "weights" and "g_idx" so that group ids are - # increasing - sort_indices = torch.empty(0, dtype=torch.int, device=w.device) - if act_order: - q_w, g_idx, sort_indices = sort_weights(q_w, g_idx) - - # Reformat to marlin - marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, - marlin_perm[num_bits]) - marlin_s = marlin_permute_scales(s, size_k, size_n, group_size, - marlin_scale_perm[num_bits], - marlin_scale_perm_single[num_bits]) - - # Create result - res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm] - for i in range(len(res_list)): - res_list[i] = res_list[i].to(w.device) - - return res_list - - -def inject_24(w, size_k, size_n): - assert w.shape == (size_k, size_n) - - mask = mask_creator(w.t()).t().cuda().bool() - - return (mask * w).contiguous(), mask.contiguous() - - -def check_24(w, num_rows_to_sample=50, _verbose=False): - BLOCK_SIZE = 4 - MAX_NON_ZEROS = 2 - - w = w.t().contiguous() - - print("check_24: w.shape = {}".format(w.shape)) - - num_rows, num_cols = w.shape - sampled_row_idxs = random.choices(range(num_rows), k=num_rows_to_sample) - if _verbose: - print(f"Sampled row idxs = {sampled_row_idxs}") - - total_segments = 0 - non_24_segments = 0 - for i in sampled_row_idxs: - for j in range(0, num_cols - BLOCK_SIZE, BLOCK_SIZE): - total_segments += 1 - block = w[i, j:j + BLOCK_SIZE] - num_nonzero = torch.count_nonzero(block) - if num_nonzero > MAX_NON_ZEROS: - print("i = {} j = {} block = {}".format(i, j, block)) - non_24_segments += 1 - - print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.") - - -def compress_quantized_24_weight(q_24, size_k, size_n, num_bits): - assert q_24.shape == (size_k, size_n) - - # Remove zp to normalize over 0 - max_q_val = (1 << num_bits) - 1 - zp = (max_q_val + 1) // 2 - q_24_no_zp = q_24 - zp - - # Compress - q_24_no_zp = q_24_no_zp.t().contiguous() - q_24_no_zp_comp, meta = sparse_semi_structured_from_dense_cutlass( - q_24_no_zp) - q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous() - - # Restore zp - q_24_comp = q_24_no_zp_comp + zp - - # Resize meta to its actual shape (without moving any data) - meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2) - - return q_24_comp, meta - - -def marlin_24_quantize( - w: torch.Tensor, - num_bits: int, - group_size: int, -): - size_k, size_n = w.shape - - # Normalize group_size - if group_size == -1: - group_size = size_k - assert group_size <= size_k - - # Inject 2:4 sparsity - w_24, mask_24 = inject_24(w, size_k, size_n) - - # Quantize - w_24_ref, q_w_24, s, g_idx, rand_perm = quantize_weights(w_24, - num_bits, - group_size, - act_order=False) - - # Compress quantized weight - q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n, - num_bits) - size_k_comp = size_k // 2 - - # Reformat to marlin - marlin_24_q_w_comp = marlin_weights(q_w_24_comp, size_k_comp, 
size_n, - num_bits, marlin_24_perm[num_bits]) - marlin_24_s = marlin_permute_scales(s, size_k, size_n, group_size, - marlin_24_scale_perm[num_bits], - marlin_24_scale_perm_single[num_bits]) - - # Create result - res_list = [w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s] - for i in range(len(res_list)): - res_list[i] = res_list[i].to(w.device) - - return res_list - - -def compute_max_diff(output, output_ref): - return torch.mean(torch.abs(output - output_ref)) / torch.mean( - torch.abs(output_ref)) - - -class MarlinWorkspace: - - def __init__(self, out_features, min_thread_n, max_parallel): - assert (out_features % min_thread_n == 0), ( - "out_features = {} is undivisible by min_thread_n = {}".format( - out_features, min_thread_n)) - - max_workspace_size = ((out_features // min_thread_n) * max_parallel) - - self.scratch = torch.zeros(max_workspace_size, - dtype=torch.int, - device="cuda") - - -def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor: - """ - Repack FP8 weights to gptq format (packed int32 elements) - """ - assert fp8_tensor.dtype == torch.float8_e4m3fn - assert fp8_tensor.shape[0] % 4 == 0 - - # Reshape to prepare for packing - reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:]) - - # Convert fp8 to uint8 (byte) representation - byte_tensor = reshaped.view(torch.uint8) +# Newly generated tensors need to replace existing tensors that are +# already registered as parameters by vLLM (and won't be freed) +def replace_tensor(layer: torch.nn.Module, name: str, + new_t: torch.Tensor) -> None: + # It is important to use resize_() here since it ensures + # the same buffer is reused + getattr(layer, name).resize_(new_t.shape) + getattr(layer, name).copy_(new_t) + del new_t + + +def apply_marlin_linear(input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + g_idx: torch.Tensor, + g_idx_sort_indices: torch.Tensor, + workspace: torch.Tensor, + num_bits: int, + output_size_per_partition: int, + input_size_per_partition: int, + is_k_full: bool, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (output_size_per_partition, ) + + output = ops.gptq_marlin_gemm(reshaped_x, + weight, + weight_scale, + g_idx, + g_idx_sort_indices, + workspace, + num_bits, + size_m=reshaped_x.shape[0], + size_n=output_size_per_partition, + size_k=input_size_per_partition, + is_k_full=is_k_full) - # Pack 4 uint8 values into one int32 - packed = (byte_tensor[:, 0].to(torch.int32) | - (byte_tensor[:, 1].to(torch.int32) << 8) | - (byte_tensor[:, 2].to(torch.int32) << 16) | - (byte_tensor[:, 3].to(torch.int32) << 24)) + if bias is not None: + output.add_(bias) # In-place add - return packed.view(fp8_tensor.shape[0] // 4, - *fp8_tensor.shape[1:]).contiguous() + return output.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py new file mode 100644 index 0000000000000..e93eb747ba2eb --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py @@ -0,0 +1,109 @@ +from typing import Optional + +import torch + +import vllm._custom_ops as ops +from vllm.platforms import current_platform +from vllm.utils import print_warning_once + +from .marlin_utils import marlin_make_workspace, marlin_permute_scales + + +def is_fp8_marlin_supported(): + capability = current_platform.get_device_capability() + return capability[0] >= 8 + + +def apply_fp8_marlin_linear( + input: 
torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + workspace: torch.Tensor, + size_n: int, + size_k: int, + bias: Optional[torch.Tensor], +) -> torch.Tensor: + # For GPUs that lack FP8 hardware support, we can leverage the + # Marlin kernel for fast weight-only FP8 quantization + + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (size_n, ) + + output = ops.fp8_marlin_gemm( + a=reshaped_x, + b_q_weight=weight, + b_scales=weight_scale, + workspace=workspace, + num_bits=8, + size_m=reshaped_x.shape[0], + size_n=size_n, + size_k=size_k, + ) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) + + +def prepare_fp8_layer_for_marlin(layer: torch.nn.Module) -> None: + print_warning_once( + "Your GPU does not have native support for FP8 computation but " + "FP8 quantization is being used. Weight-only FP8 compression will " + "be used leveraging the Marlin kernel. This may degrade " + "performance for compute-heavy workloads.") + + part_size_n = layer.output_size_per_partition + part_size_k = layer.input_size_per_partition + + device = layer.weight.device + + # WORKSPACE + layer.workspace = marlin_make_workspace(part_size_n, device) + + # WEIGHT + # Repack weights to marlin format + marlin_qweight = ops.gptq_marlin_repack(b_q_weight=pack_fp8_to_int32( + layer.weight), + perm=torch.empty(0, + dtype=torch.int, + device=device), + size_k=part_size_k, + size_n=part_size_n, + num_bits=8) + layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False) + + # WEIGHT SCALES + # Currently Marlin doesn't support per-tensor scales, so we + # expand it to channelwise + scales = layer.weight_scale.repeat(1, part_size_n).to( + layer.orig_dtype).to(device) + # Permute scales + marlin_scales = marlin_permute_scales(s=scales, + size_k=part_size_k, + size_n=part_size_n, + group_size=-1) + layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False) + + +def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor: + """ + Repack FP8 weights to gptq format (packed int32 elements) + """ + assert fp8_tensor.dtype == torch.float8_e4m3fn + assert fp8_tensor.shape[0] % 4 == 0 + + # Reshape to prepare for packing + reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:]) + + # Convert fp8 to uint8 (byte) representation + byte_tensor = reshaped.view(torch.uint8) + + # Pack 4 uint8 values into one int32 + packed = (byte_tensor[:, 0].to(torch.int32) | + (byte_tensor[:, 1].to(torch.int32) << 8) | + (byte_tensor[:, 2].to(torch.int32) << 16) | + (byte_tensor[:, 3].to(torch.int32) << 24)) + + return packed.view(fp8_tensor.shape[0] // 4, + *fp8_tensor.shape[1:]).contiguous() diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py new file mode 100644 index 0000000000000..1773748a0f228 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py @@ -0,0 +1,120 @@ +"""Utility functions used for tests and benchmarks""" + +from typing import List + +import numpy +import torch + +from .marlin_utils import GPTQ_MARLIN_TILE, marlin_permute_scales +from .quant_utils import get_pack_factor, quantize_weights, sort_weights + + +class MarlinWorkspace: + + def __init__(self, out_features, min_thread_n, max_parallel): + assert (out_features % min_thread_n == 0), ( + "out_features = {} is undivisible by min_thread_n = {}".format( + out_features, min_thread_n)) + + max_workspace_size = ((out_features // 
min_thread_n) * max_parallel) + + self.scratch = torch.zeros(max_workspace_size, + dtype=torch.int, + device="cuda") + + +def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE): + assert q_w.shape == (size_k, size_n) + assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}" + assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}" + + # Permute weights to 16x64 marlin tiles + q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile)) + q_w = q_w.permute((0, 2, 1, 3)) + q_w = q_w.reshape((size_k // tile, size_n * tile)) + + q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape) + + return q_w + + +def marlin_weights(q_w, size_k, size_n, num_bits, perm): + # Permute + q_w = marlin_permute_weights(q_w, size_k, size_n, perm) + + # Pack + pack_factor = get_pack_factor(num_bits) + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(numpy.uint32) + + q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), + dtype=numpy.uint32) + for i in range(pack_factor): + q_packed |= q_w[:, i::pack_factor] << num_bits * i + + q_packed = torch.from_numpy(q_packed.astype(numpy.int32)).to(orig_device) + + return q_packed + + +def get_weight_perm(num_bits: int): + perm_list: List[int] = [] + for i in range(32): + perm1: List[int] = [] + col = i // 4 + for block in [0, 1]: + for row in [ + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, + ]: + perm1.append(16 * row + col + 8 * block) + for j in range(4): + perm_list.extend([p + 256 * j for p in perm1]) + + perm = numpy.array(perm_list) + + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + return perm + + +def marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int, + act_order: bool): + size_k, size_n = w.shape + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Quantize (and apply act_order if provided) + w_ref, q_w, s, g_idx, rand_perm = quantize_weights(w, num_bits, group_size, + act_order) + + # For act_order, sort the "weights" and "g_idx" so that group ids are + # increasing + sort_indices = torch.empty(0, dtype=torch.int, device=w.device) + if act_order: + q_w, g_idx, sort_indices = sort_weights(q_w, g_idx) + + # Reformat to marlin + weight_perm = get_weight_perm(num_bits) + marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm) + marlin_s = marlin_permute_scales(s, size_k, size_n, group_size) + + # Create result + res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list diff --git a/vllm/model_executor/layers/quantization/utils/format_24.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_test_24.py similarity index 71% rename from vllm/model_executor/layers/quantization/utils/format_24.py rename to vllm/model_executor/layers/quantization/utils/marlin_utils_test_24.py index 01c8cf789204b..648c32249a571 100644 --- a/vllm/model_executor/layers/quantization/utils/format_24.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_test_24.py @@ -1,9 +1,14 @@ -# -# Modified by Roberto Lopez Castro (roberto.lopez.castro@udc.es). 
-# +"""Utility functions used for tests and benchmarks""" +import random +from typing import List + +import numpy import torch +from .marlin_utils_test import marlin_weights +from .quant_utils import quantize_weights + # This is PyTorch implementation of main part of reorder_meta() # function, from tools/util/include/cutlass/util/host_reorder.h file @@ -306,3 +311,155 @@ def mask_creator(tensor): mask = w_b.scatter_(dim=1, index=index, value=0).reshape(tensor.shape) return mask + + +def inject_24(w, size_k, size_n): + assert w.shape == (size_k, size_n) + + mask = mask_creator(w.t()).t().cuda().bool() + + return (mask * w).contiguous(), mask.contiguous() + + +def check_24(w, num_rows_to_sample=50, _verbose=False): + BLOCK_SIZE = 4 + MAX_NON_ZEROS = 2 + + w = w.t().contiguous() + + print("check_24: w.shape = {}".format(w.shape)) + + num_rows, num_cols = w.shape + sampled_row_idxs = random.choices(range(num_rows), k=num_rows_to_sample) + if _verbose: + print(f"Sampled row idxs = {sampled_row_idxs}") + + total_segments = 0 + non_24_segments = 0 + for i in sampled_row_idxs: + for j in range(0, num_cols - BLOCK_SIZE, BLOCK_SIZE): + total_segments += 1 + block = w[i, j:j + BLOCK_SIZE] + num_nonzero = torch.count_nonzero(block) + if num_nonzero > MAX_NON_ZEROS: + print("i = {} j = {} block = {}".format(i, j, block)) + non_24_segments += 1 + + print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.") + + +def compress_quantized_24_weight(q_24, size_k, size_n, num_bits): + assert q_24.shape == (size_k, size_n) + + # Remove zp to normalize over 0 + max_q_val = (1 << num_bits) - 1 + zp = (max_q_val + 1) // 2 + q_24_no_zp = q_24 - zp + + # Compress + q_24_no_zp = q_24_no_zp.t().contiguous() + q_24_no_zp_comp, meta = sparse_semi_structured_from_dense_cutlass( + q_24_no_zp) + q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous() + + # Restore zp + q_24_comp = q_24_no_zp_comp + zp + + # Resize meta to its actual shape (without moving any data) + meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2) + + return q_24_comp, meta + + +def get_scale_perms_24(): + scale_perm: List[int] = [] + for i in range(8): + scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]]) + scale_perm_single: List[int] = [] + for i in range(8): + scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]]) + return scale_perm, scale_perm_single + + +def get_weight_perm_24(num_bits: int): + perm_list: List[int] = [] + for i in range(32): + perm1: List[int] = [] + col = i // 4 + col_o = col // 2 + for block in [0, 1]: + for row in [ + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, + ]: + perm1.append(16 * row + col_o * 256 + 8 * (col % 2) + + 4 * block) + for j in range(4): + perm_list.extend([p + 1 * j for p in perm1]) + perm = numpy.array(perm_list) + + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + return perm + + +def marlin_permute_scales_24(s: torch.Tensor, size_k: int, size_n: int, + group_size: int) -> torch.Tensor: + + scale_perm, scale_perm_single = get_scale_perms_24() + if group_size < size_k and group_size != -1: + s = s.reshape((-1, len(scale_perm)))[:, scale_perm] + else: + s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] + s = s.reshape((-1, size_n)).contiguous() + + return 
s + + +def marlin_24_quantize( + w: torch.Tensor, + num_bits: int, + group_size: int, +): + size_k, size_n = w.shape + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Inject 2:4 sparsity + w_24, mask_24 = inject_24(w, size_k, size_n) + + # Quantize + w_24_ref, q_w_24, s, g_idx, rand_perm = quantize_weights(w_24, + num_bits, + group_size, + act_order=False) + + # Compress quantized weight + q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n, + num_bits) + size_k_comp = size_k // 2 + + # Reformat to marlin + weight_perm = get_weight_perm_24(num_bits) + marlin_24_q_w_comp = marlin_weights(q_w_24_comp, size_k_comp, size_n, + num_bits, weight_perm) + marlin_24_s = marlin_permute_scales_24(s, size_k, size_n, group_size) + + # Create result + res_list = [w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list From 52b7fcb35a6f8b57429431e929884c05d8266023 Mon Sep 17 00:00:00 2001 From: Simon Mo Date: Thu, 11 Jul 2024 09:17:07 -0700 Subject: [PATCH 24/26] Benchmark: add H100 suite (#6047) --- .../benchmark-pipeline.yaml | 35 +++++++++---------- .../run-benchmarks-suite.sh | 28 ++++++++++++--- 2 files changed, 40 insertions(+), 23 deletions(-) diff --git a/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml b/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml index 2b25c954b5c5c..02c0ee534d72c 100644 --- a/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml +++ b/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml @@ -11,7 +11,7 @@ steps: - sh - .buildkite/nightly-benchmarks/scripts/wait-for-image.sh - wait - - label: "A100 Benchmark" + - label: "A100" agents: queue: A100 plugins: @@ -42,21 +42,20 @@ steps: - name: devshm emptyDir: medium: Memory - # - label: "H100: NVIDIA SMI" - # agents: - # queue: H100 - # plugins: - # - docker#v5.11.0: - # image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT - # command: - # - bash - # - .buildkite/nightly-benchmarks/run-benchmarks-suite.sh - # mount-buildkite-agent: true - # propagate-environment: true - # propagate-uid-gid: false - # ipc: host - # gpus: all - # environment: - # - VLLM_USAGE_SOURCE - # - HF_TOKEN + - label: "H100" + agents: + queue: H100 + plugins: + - docker#v5.11.0: + image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT + command: + - bash + - .buildkite/nightly-benchmarks/run-benchmarks-suite.sh + mount-buildkite-agent: true + propagate-environment: true + ipc: host + gpus: all + environment: + - VLLM_USAGE_SOURCE + - HF_TOKEN diff --git a/.buildkite/nightly-benchmarks/run-benchmarks-suite.sh b/.buildkite/nightly-benchmarks/run-benchmarks-suite.sh index 021473f76d0e5..04b02adf3644c 100644 --- a/.buildkite/nightly-benchmarks/run-benchmarks-suite.sh +++ b/.buildkite/nightly-benchmarks/run-benchmarks-suite.sh @@ -54,7 +54,7 @@ wait_for_server() { # wait for vllm server to start # return 1 if vllm server crashes timeout 1200 bash -c ' - until curl localhost:8000/v1/completions; do + until curl -X POST localhost:8000/v1/completions; do sleep 1 done' && return 0 || return 1 } @@ -73,8 +73,17 @@ kill_gpu_processes() { echo "All GPU processes have been killed." 
fi + # Sometimes kill with pid doesn't work properly, we can also kill all process running python or python3 + # since we are in container anyway + pkill -9 -f python + pkill -9 -f python3 + # waiting for GPU processes to be fully killed - sleep 10 + # loop while nvidia-smi returns any processes + while [ -n "$(nvidia-smi --query-compute-apps=pid --format=csv,noheader)" ]; do + sleep 1 + echo "Waiting for GPU processes to be killed" + done # remove vllm config file rm -rf ~/.config/vllm @@ -90,12 +99,19 @@ upload_to_buildkite() { # upload the benchmarking results to buildkite # if the agent binary is not found, skip uploading the results, exit 0 - if [ ! -f /workspace/buildkite-agent ]; then + # Check if buildkite-agent is available in the PATH or at /workspace/buildkite-agent + if command -v buildkite-agent >/dev/null 2>&1; then + BUILDKITE_AGENT_COMMAND="buildkite-agent" + elif [ -f /workspace/buildkite-agent ]; then + BUILDKITE_AGENT_COMMAND="/workspace/buildkite-agent" + else echo "buildkite-agent binary not found. Skip uploading the results." return 0 fi - /workspace/buildkite-agent annotate --style "info" --context "benchmark-results" < $RESULTS_FOLDER/benchmark_results.md - /workspace/buildkite-agent artifact upload "$RESULTS_FOLDER/*" + + # Use the determined command to annotate and upload artifacts + $BUILDKITE_AGENT_COMMAND annotate --style "info" --context "$BUILDKITE_LABEL-benchmark-results" < $RESULTS_FOLDER/benchmark_results.md + $BUILDKITE_AGENT_COMMAND artifact upload "$RESULTS_FOLDER/*" } run_latency_tests() { @@ -269,6 +285,7 @@ run_serving_tests() { echo "Running test case $test_name" echo "Server command: $server_command" eval "$server_command" & + server_pid=$! # wait until the server is alive wait_for_server @@ -318,6 +335,7 @@ run_serving_tests() { done # clean up + kill -9 $server_pid kill_gpu_processes done } From 1df43de9bb2cceecdc0dc2dc5c650a327aeabe0f Mon Sep 17 00:00:00 2001 From: xwjiang2010 <87673679+xwjiang2010@users.noreply.github.com> Date: Thu, 11 Jul 2024 10:21:10 -0700 Subject: [PATCH 25/26] [bug fix] Fix llava next feature size calculation. 
(#6339) Signed-off-by: Xiaowei Jiang --- tests/models/test_llava_next.py | 14 +++++++++++++- vllm/model_executor/models/llava_next.py | 18 ++++++++++-------- 2 files changed, 23 insertions(+), 9 deletions(-) diff --git a/tests/models/test_llava_next.py b/tests/models/test_llava_next.py index 581cbcf9068fe..163741a5719c2 100644 --- a/tests/models/test_llava_next.py +++ b/tests/models/test_llava_next.py @@ -1,8 +1,10 @@ from typing import List, Optional, Tuple import pytest -from transformers import AutoTokenizer +from transformers import AutoConfig, AutoTokenizer +from vllm.model_executor.models.llava_next import ( + get_llava_next_image_feature_size) from vllm.multimodal.utils import rescale_image_size from vllm.sequence import SampleLogprobs @@ -120,3 +122,13 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, name_0="hf", name_1="vllm", ) + + +@pytest.mark.parametrize("height_and_width_and_result", [(1669, 2560, 2144), + (183, 488, 776)]) +def test_image_feature_size(height_and_width_and_result): + height, width, result = height_and_width_and_result + config = AutoConfig.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf") + assert get_llava_next_image_feature_size(config, + input_height=height, + input_width=width) == result diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index 7e06f1e95dab1..9369ec89fa9d5 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -74,19 +74,21 @@ def _get_llava_next_num_unpadded_features( ) -> Tuple[int, int]: current_height = npatches * num_patch_height current_width = npatches * num_patch_width + current_height = torch.tensor(current_height).to("cuda") + current_width = torch.tensor(current_width).to("cuda") aspect_ratio: float = width / height current_aspect_ratio: float = current_width / current_height if aspect_ratio > current_aspect_ratio: - new_height = (height * current_width) // width - if new_height % 2 == 1: - new_height += 1 - current_height = new_height + scale_factor = current_width / width + new_height = int(height * scale_factor) + padding = (current_height - new_height) // 2 + current_height -= padding * 2 else: - new_width = (width * current_height) // height - if new_width % 2 == 1: - new_width += 1 - current_width = new_width + scale_factor = current_height / height + new_width = int(width * scale_factor) + padding = (current_width - new_width) // 2 + current_width -= padding * 2 unpadded_features = current_height * current_width newline_features = current_height From 2d23b42d9255f724f955a1cf91ed78c983854737 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 11 Jul 2024 11:38:40 -0700 Subject: [PATCH 26/26] [doc] update pipeline parallel in readme (#6347) --- README.md | 2 +- docs/source/index.rst | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index cced85f17e257..dac4b513cd2a2 100644 --- a/README.md +++ b/README.md @@ -56,7 +56,7 @@ vLLM is flexible and easy to use with: - Seamless integration with popular Hugging Face models - High-throughput serving with various decoding algorithms, including *parallel sampling*, *beam search*, and more -- Tensor parallelism support for distributed inference +- Tensor parallelism and pipieline parallelism support for distributed inference - Streaming outputs - OpenAI-compatible API server - Support NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs diff --git a/docs/source/index.rst b/docs/source/index.rst index 
67c039f25e98d..174d91b8d6a01 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -38,7 +38,7 @@ vLLM is flexible and easy to use with:
 
 * Seamless integration with popular HuggingFace models
 * High-throughput serving with various decoding algorithms, including *parallel sampling*, *beam search*, and more
-* Tensor parallelism support for distributed inference
+* Tensor parallelism and pipeline parallelism support for distributed inference
 * Streaming outputs
 * OpenAI-compatible API server
 * Support NVIDIA GPUs and AMD GPUs