From 09a8d6307799834cf6e66a8b9346c041914a503e Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 14 Jul 2024 21:20:51 -0700 Subject: [PATCH] [core][distributed] simplify code to support pipeline parallel (#6406) --- .buildkite/test-pipeline.yaml | 4 +- .../test_basic_correctness.py | 11 +++- vllm/model_executor/models/gpt2.py | 47 +++++++--------- vllm/model_executor/models/llama.py | 50 ++++++++--------- vllm/model_executor/models/utils.py | 56 +++++++++++++++++++ 5 files changed, 107 insertions(+), 61 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 2d5bbbf07cac9..4019cc00fa2b9 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -46,9 +46,7 @@ steps: fast_check: true commands: - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.8/flashinfer-0.0.8+cu121torch2.3-cp310-cp310-linux_x86_64.whl - - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_basic_correctness.py - - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_basic_correctness.py - - VLLM_ATTENTION_BACKEND=FLASHINFER pytest -v -s basic_correctness/test_basic_correctness.py + - pytest -v -s basic_correctness/test_basic_correctness.py - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py - VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index d3e74a4f834a4..ec7c2ba3e3ce0 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -28,10 +28,8 @@ def test_vllm_gc_ed(): assert weak_llm() is None -@pytest.mark.skipif(is_hip() - and os.getenv("VLLM_ATTENTION_BACKEND") == "FLASHINFER", - reason="Flashinfer does not support ROCm/HIP.") @pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("backend", ["FLASH_ATTN", "XFORMERS", "FLASHINFER"]) @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [5]) @pytest.mark.parametrize("enforce_eager", [False, True]) @@ -40,10 +38,17 @@ def test_models( vllm_runner, example_prompts, model: str, + backend: str, dtype: str, max_tokens: int, enforce_eager: bool, ) -> None: + + if backend == "FLASHINFER" and is_hip(): + pytest.skip("Flashinfer does not support ROCm/HIP.") + + os.environ["VLLM_ATTENTION_BACKEND"] = backend + with hf_runner(model, dtype=dtype) as hf_model: hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index be19f4ba8c71e..d309a2b27f5dd 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -27,7 +27,6 @@ from vllm.config import CacheConfig from vllm.distributed.parallel_state import ( get_pp_group, get_tensor_model_parallel_world_size) -from vllm.distributed.utils import get_pp_indices from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, @@ -42,6 +41,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors, SamplerOutput +from .utils import is_pp_missing_parameter, make_layers + class GPT2Attention(nn.Module): @@ -183,18 +184,9 @@ def __init__( self.embed_dim = config.hidden_size self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim) self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) - self.start_layer, self.end_layer = get_pp_indices( + self.start_layer, self.end_layer, self.h = make_layers( config.num_hidden_layers, - get_pp_group().rank_in_group, - get_pp_group().world_size) - self.h = nn.ModuleList( - [nn.Identity() for _ in range(self.start_layer)] + [ - GPT2Block(config, cache_config, quant_config) - for _ in range(self.start_layer, self.end_layer) - ] + [ - nn.Identity() - for _ in range(self.end_layer, config.num_hidden_layers) - ]) + lambda: GPT2Block(config, cache_config, quant_config)) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) def forward( @@ -291,19 +283,20 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): continue if not name.startswith("transformer."): name = "transformer." + name - try: - param = params_dict[name] - # The HF's GPT-2 implementation uses Conv1D instead of Linear. - # Because of this, we need to transpose the weights. - # Note(zhuohan): the logic below might break quantized models. - for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]: - if conv1d_weight_name not in name: - continue - if not name.endswith(".weight"): - continue - loaded_weight = loaded_weight.t() - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - except KeyError: + + if is_pp_missing_parameter(name, self): continue + + param = params_dict[name] + # The HF's GPT-2 implementation uses Conv1D instead of Linear. + # Because of this, we need to transpose the weights. + # Note(zhuohan): the logic below might break quantized models. + for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]: + if conv1d_weight_name not in name: + continue + if not name.endswith(".weight"): + continue + loaded_weight = loaded_weight.t() + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 77edcd7402db1..a777d1fbfa802 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -29,8 +29,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, LoRAConfig -from vllm.distributed import (get_pp_group, get_pp_indices, - get_tensor_model_parallel_rank, +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm @@ -51,6 +50,7 @@ from vllm.utils import is_hip, print_warning_once from .interfaces import SupportsLoRA +from .utils import is_pp_missing_parameter, make_layers class LlamaMLP(nn.Module): @@ -262,20 +262,11 @@ def __init__( config.hidden_size, org_num_embeddings=config.vocab_size, ) - self.start_layer, self.end_layer = get_pp_indices( + self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - get_pp_group().rank_in_group, - get_pp_group().world_size) - self.layers = nn.ModuleList( - [nn.Identity() for _ in range(self.start_layer)] + [ - LlamaDecoderLayer(config=config, - cache_config=cache_config, - quant_config=quant_config) - for _ in range(self.start_layer, self.end_layer) - ] + [ - nn.Identity() - for _ in range(self.end_layer, config.num_hidden_layers) - ]) + lambda: LlamaDecoderLayer(config=config, + cache_config=cache_config, + quant_config=quant_config)) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: @@ -455,12 +446,14 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue - try: - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - except KeyError: - pass + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break else: # Skip loading extra bias for GPTQ models. @@ -479,13 +472,14 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): continue else: name = remapped_kv_scale_name - try: - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - except KeyError: - pass + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) # If this function is called, it should always initialize KV cache scale # factors (or else raise an exception). Thus, handled exceptions should diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index ef2562b073e6f..a0d2a0286ff67 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -1,3 +1,5 @@ +from typing import Callable, Dict, List, Tuple + import torch from vllm.multimodal import BatchedTensors @@ -39,3 +41,57 @@ def merge_vision_embeddings(input_ids: torch.Tensor, inputs_embeds[mask] = torch.cat(vision_embeddings) return inputs_embeds + + +class PPMissingLayer(torch.nn.Identity): + """ + A placeholder layer for missing layers in a pipeline parallel model. + """ + + def __init__(self, *args, **kwargs): + super().__init__() + + +def make_layers( + num_hidden_layers: int, layer_fn: Callable[[], torch.nn.Module] +) -> Tuple[int, int, torch.nn.ModuleList]: + """Make a list of layers with the given layer function, taking + pipeline parallelism into account. + """ + from vllm.distributed.parallel_state import get_pp_group + from vllm.distributed.utils import get_pp_indices + start_layer, end_layer = get_pp_indices(num_hidden_layers, + get_pp_group().rank_in_group, + get_pp_group().world_size) + modules = torch.nn.ModuleList( + [PPMissingLayer() for _ in range(start_layer)] + + [layer_fn() for _ in range(start_layer, end_layer)] + + [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)]) + return start_layer, end_layer, modules + + +# NOTE: don't use lru_cache here because it can prevent garbage collection +_model_to_pp_missing_layer_names: Dict[int, List[str]] = {} + + +def get_pp_missing_layer_names(model: torch.nn.Module) -> List[str]: + """Get the names of the missing layers in a pipeline parallel model.""" + model_id = id(model) + if model_id in _model_to_pp_missing_layer_names: + return _model_to_pp_missing_layer_names[model_id] + + missing_layer_names = [] + for name, module in model.named_modules(): + if isinstance(module, PPMissingLayer): + missing_layer_names.append(name) + _model_to_pp_missing_layer_names[model_id] = missing_layer_names + + return missing_layer_names + + +def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool: + """Check if a parameter is missing in a pipeline parallel model.""" + for missing_layer_name in get_pp_missing_layer_names(model): + if name.startswith(missing_layer_name): + return True + return False