Skip to content

Commit

Permalink
[core][distributed] simplify code to support pipeline parallel (vllm-…
Browse files Browse the repository at this point in the history
  • Loading branch information
youkaichao authored and dtrifiro committed Jul 17, 2024
1 parent 3ec9437 commit 09a8d63
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 61 deletions.
4 changes: 1 addition & 3 deletions .buildkite/test-pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,7 @@ steps:
fast_check: true
commands:
- pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.8/flashinfer-0.0.8+cu121torch2.3-cp310-cp310-linux_x86_64.whl
- VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_basic_correctness.py
- VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_basic_correctness.py
- VLLM_ATTENTION_BACKEND=FLASHINFER pytest -v -s basic_correctness/test_basic_correctness.py
- pytest -v -s basic_correctness/test_basic_correctness.py
- VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py
- VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py
- VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py
Expand Down
11 changes: 8 additions & 3 deletions tests/basic_correctness/test_basic_correctness.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,8 @@ def test_vllm_gc_ed():
assert weak_llm() is None


@pytest.mark.skipif(is_hip()
and os.getenv("VLLM_ATTENTION_BACKEND") == "FLASHINFER",
reason="Flashinfer does not support ROCm/HIP.")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("backend", ["FLASH_ATTN", "XFORMERS", "FLASHINFER"])
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("enforce_eager", [False, True])
Expand All @@ -40,10 +38,17 @@ def test_models(
vllm_runner,
example_prompts,
model: str,
backend: str,
dtype: str,
max_tokens: int,
enforce_eager: bool,
) -> None:

if backend == "FLASHINFER" and is_hip():
pytest.skip("Flashinfer does not support ROCm/HIP.")

os.environ["VLLM_ATTENTION_BACKEND"] = backend

with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)

Expand Down
47 changes: 20 additions & 27 deletions vllm/model_executor/models/gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
from vllm.config import CacheConfig
from vllm.distributed.parallel_state import (
get_pp_group, get_tensor_model_parallel_world_size)
from vllm.distributed.utils import get_pp_indices
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
Expand All @@ -42,6 +41,8 @@
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput

from .utils import is_pp_missing_parameter, make_layers


class GPT2Attention(nn.Module):

Expand Down Expand Up @@ -183,18 +184,9 @@ def __init__(
self.embed_dim = config.hidden_size
self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
self.start_layer, self.end_layer = get_pp_indices(
self.start_layer, self.end_layer, self.h = make_layers(
config.num_hidden_layers,
get_pp_group().rank_in_group,
get_pp_group().world_size)
self.h = nn.ModuleList(
[nn.Identity() for _ in range(self.start_layer)] + [
GPT2Block(config, cache_config, quant_config)
for _ in range(self.start_layer, self.end_layer)
] + [
nn.Identity()
for _ in range(self.end_layer, config.num_hidden_layers)
])
lambda: GPT2Block(config, cache_config, quant_config))
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

def forward(
Expand Down Expand Up @@ -291,19 +283,20 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
continue
if not name.startswith("transformer."):
name = "transformer." + name
try:
param = params_dict[name]
# The HF's GPT-2 implementation uses Conv1D instead of Linear.
# Because of this, we need to transpose the weights.
# Note(zhuohan): the logic below might break quantized models.
for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
if conv1d_weight_name not in name:
continue
if not name.endswith(".weight"):
continue
loaded_weight = loaded_weight.t()
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
except KeyError:

if is_pp_missing_parameter(name, self):
continue

param = params_dict[name]
# The HF's GPT-2 implementation uses Conv1D instead of Linear.
# Because of this, we need to transpose the weights.
# Note(zhuohan): the logic below might break quantized models.
for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
if conv1d_weight_name not in name:
continue
if not name.endswith(".weight"):
continue
loaded_weight = loaded_weight.t()
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
50 changes: 22 additions & 28 deletions vllm/model_executor/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@

from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import (get_pp_group, get_pp_indices,
get_tensor_model_parallel_rank,
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
Expand All @@ -51,6 +50,7 @@
from vllm.utils import is_hip, print_warning_once

from .interfaces import SupportsLoRA
from .utils import is_pp_missing_parameter, make_layers


class LlamaMLP(nn.Module):
Expand Down Expand Up @@ -262,20 +262,11 @@ def __init__(
config.hidden_size,
org_num_embeddings=config.vocab_size,
)
self.start_layer, self.end_layer = get_pp_indices(
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
get_pp_group().rank_in_group,
get_pp_group().world_size)
self.layers = nn.ModuleList(
[nn.Identity() for _ in range(self.start_layer)] + [
LlamaDecoderLayer(config=config,
cache_config=cache_config,
quant_config=quant_config)
for _ in range(self.start_layer, self.end_layer)
] + [
nn.Identity()
for _ in range(self.end_layer, config.num_hidden_layers)
])
lambda: LlamaDecoderLayer(config=config,
cache_config=cache_config,
quant_config=quant_config))
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
Expand Down Expand Up @@ -455,12 +446,14 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
try:
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
except KeyError:
pass

if is_pp_missing_parameter(name, self):
continue

param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)

break
else:
# Skip loading extra bias for GPTQ models.
Expand All @@ -479,13 +472,14 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
continue
else:
name = remapped_kv_scale_name
try:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
except KeyError:
pass

if is_pp_missing_parameter(name, self):
continue

param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

# If this function is called, it should always initialize KV cache scale
# factors (or else raise an exception). Thus, handled exceptions should
Expand Down
56 changes: 56 additions & 0 deletions vllm/model_executor/models/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Callable, Dict, List, Tuple

import torch

from vllm.multimodal import BatchedTensors
Expand Down Expand Up @@ -39,3 +41,57 @@ def merge_vision_embeddings(input_ids: torch.Tensor,
inputs_embeds[mask] = torch.cat(vision_embeddings)

return inputs_embeds


class PPMissingLayer(torch.nn.Identity):
"""
A placeholder layer for missing layers in a pipeline parallel model.
"""

def __init__(self, *args, **kwargs):
super().__init__()


def make_layers(
num_hidden_layers: int, layer_fn: Callable[[], torch.nn.Module]
) -> Tuple[int, int, torch.nn.ModuleList]:
"""Make a list of layers with the given layer function, taking
pipeline parallelism into account.
"""
from vllm.distributed.parallel_state import get_pp_group
from vllm.distributed.utils import get_pp_indices
start_layer, end_layer = get_pp_indices(num_hidden_layers,
get_pp_group().rank_in_group,
get_pp_group().world_size)
modules = torch.nn.ModuleList(
[PPMissingLayer() for _ in range(start_layer)] +
[layer_fn() for _ in range(start_layer, end_layer)] +
[PPMissingLayer() for _ in range(end_layer, num_hidden_layers)])
return start_layer, end_layer, modules


# NOTE: don't use lru_cache here because it can prevent garbage collection
_model_to_pp_missing_layer_names: Dict[int, List[str]] = {}


def get_pp_missing_layer_names(model: torch.nn.Module) -> List[str]:
"""Get the names of the missing layers in a pipeline parallel model."""
model_id = id(model)
if model_id in _model_to_pp_missing_layer_names:
return _model_to_pp_missing_layer_names[model_id]

missing_layer_names = []
for name, module in model.named_modules():
if isinstance(module, PPMissingLayer):
missing_layer_names.append(name)
_model_to_pp_missing_layer_names[model_id] = missing_layer_names

return missing_layer_names


def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool:
"""Check if a parameter is missing in a pipeline parallel model."""
for missing_layer_name in get_pp_missing_layer_names(model):
if name.startswith(missing_layer_name):
return True
return False

0 comments on commit 09a8d63

Please sign in to comment.