[Feature][Spec Decode] Simplify the use of Eagle Spec Decode #12304

Status: Open. Wants to merge 2 commits into base: main.
9 changes: 9 additions & 0 deletions vllm/config.py
@@ -1733,6 +1733,15 @@ def maybe_create_spec_config(

draft_hf_config = draft_model_config.hf_config

# Detect an EAGLE draft model by its "eagle-" name prefix and wrap its
# hf_config in an EAGLEConfig so it can be used directly as a draft model.
if "eagle-" in draft_model_config.model.lower():
    from vllm.transformers_utils.configs.eagle import EAGLEConfig
    if not isinstance(draft_model_config.hf_config, EAGLEConfig):
        eagle_config = EAGLEConfig(draft_model_config.hf_config)
        draft_model_config.hf_config = eagle_config

if (num_speculative_tokens is not None
        and hasattr(draft_hf_config, "num_lookahead_tokens")):
    draft_hf_config.num_lookahead_tokens = num_speculative_tokens
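With this change, an EAGLE draft model whose repository name carries the `eagle-` prefix gets its `hf_config` wrapped in `EAGLEConfig` automatically. A minimal usage sketch, assuming the `speculative_model` / `num_speculative_tokens` engine arguments of this vLLM release; the model names are only illustrative:

```python
# Hypothetical usage sketch: launch speculative decoding with an EAGLE draft
# model passed directly by name. The "eagle-" prefix is detected in
# maybe_create_spec_config, so no manual config conversion is required.
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Meta-Llama-3-8B-Instruct",            # target (scorer) model
    speculative_model="yuhuili/EAGLE-LLaMA3-Instruct-8B",   # EAGLE draft model
    num_speculative_tokens=4,
)
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)
```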
15 changes: 11 additions & 4 deletions vllm/model_executor/models/eagle.py
@@ -198,12 +198,19 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
            else:
                model_weights[f"model.{name}"] = loaded_weight

-        lm_head_weight = model_weights.pop("lm_head.weight")
+        if "lm_head.weight" in model_weights:
+            lm_head_weight = model_weights.pop("lm_head.weight")

-        if self.token_map is not None and\
-            lm_head_weight.shape[0] > self.token_map.shape[0]:
+            if self.token_map is not None and\
+                lm_head_weight.shape[0] > self.token_map.shape[0]:

-            lm_head_weight = lm_head_weight[self.token_map]
+                lm_head_weight = lm_head_weight[self.token_map]
+
+        else:
+            lm_head_weight = torch.tensor(
+                0,
+                dtype=torch.float32,
+            ).expand(self.lm_head.org_vocab_size, self.lm_head.embedding_dim)

        weight_loader = getattr(self.lm_head.weight, "weight_loader",
                                default_weight_loader)
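When the draft checkpoint does not include an `lm_head.weight`, the loader now installs a zero placeholder and the real weight is filled in later via `maybe_load_lm_head_weight`. A small sketch of the placeholder trick, with illustrative sizes:

```python
# Sketch (illustrative sizes): expanding a scalar yields a zero "weight" view of
# the right shape without materializing vocab_size x hidden floats; the real
# lm_head weight is copied in afterwards from the target model.
import torch

org_vocab_size, embedding_dim = 32000, 4096  # assumed sizes for illustration
placeholder = torch.tensor(0, dtype=torch.float32).expand(
    org_vocab_size, embedding_dim)

print(placeholder.shape)    # torch.Size([32000, 4096])
print(placeholder.stride())  # (0, 0) -- no full-size allocation
```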
12 changes: 12 additions & 0 deletions vllm/spec_decode/multi_step_worker.py
@@ -5,6 +5,7 @@
import torch

from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.platforms import current_platform
from vllm.sequence import (ExecuteModelRequest, HiddenStates, SequenceData,
                           SequenceGroupMetadata)
@@ -384,3 +385,14 @@ def _raise_if_unsupported(
               execute_model_req.seq_group_metadata_list):
            raise NotImplementedError(
                "MultiStepWorker does not support beam search.")

    def maybe_load_lm_head_weight(
        self,
        lm_head_weight: torch.Tensor,
    ) -> None:
        weight_loader = getattr(
            self.worker.model_runner.model_runner.model.lm_head.weight,
            "weight_loader", default_weight_loader)
        weight_loader(
            self.worker.model_runner.model_runner.model.lm_head.weight,
            lm_head_weight)
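The new `maybe_load_lm_head_weight` reuses vLLM's weight-loader convention: if the parameter carries a `weight_loader` attribute (e.g. a sharding-aware loader) it is used, otherwise `default_weight_loader` copies the tensor in directly. A simplified sketch of that convention (a stand-in, not vLLM's actual implementation):

```python
# Simplified sketch of the weight_loader fallback pattern (assumption: this
# mirrors, but is not, vllm's default_weight_loader).
import torch

def default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
    # Default loader: shapes must already match; copy the weight in place.
    assert param.size() == loaded_weight.size()
    param.data.copy_(loaded_weight)

lm_head_weight = torch.nn.Parameter(torch.empty(8, 4))
# Use the parameter's own loader when present, otherwise fall back to the copy.
loader = getattr(lm_head_weight, "weight_loader", default_weight_loader)
loader(lm_head_weight, torch.randn(8, 4))
```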
19 changes: 19 additions & 0 deletions vllm/spec_decode/smaller_tp_proposer_worker.py
@@ -8,6 +8,7 @@
                                             patch_tensor_parallel_group)
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import ExecuteModelRequest
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.multi_step_worker import MultiStepWorker
@@ -171,3 +172,21 @@ def get_cache_block_size_bytes(self) -> int:
    @property
    def vocab_size(self) -> int:
        return self._worker.vocab_size

    def maybe_load_lm_head_weight(
        self,
        lm_head_weight: torch.Tensor,
    ) -> None:
        if self._is_dummy:
            return

        with self._patch_tensor_parallel_group():
            weight_loader = getattr(
                self._worker.worker.model_runner.model_runner.model.\
                    lm_head.weight,
                "weight_loader",
                default_weight_loader)
            weight_loader(
                self._worker.worker.model_runner.model_runner.model.\
                    lm_head.weight,
                lm_head_weight)
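Ranks that are not part of the draft model's smaller tensor-parallel group hold a dummy worker and skip the load; real ranks perform it inside the patched TP group, since a vocab-parallel `lm_head` loader keeps only the slice owned by the current rank. A rough sketch of that slicing, under simplified assumptions (even vocab split, not vLLM's `VocabParallelEmbedding` logic):

```python
# Rough sketch (assumption: even vocab split, no padding) of why the load must
# run under the draft model's TP group: a vocab-parallel loader keeps only the
# rows owned by the current rank, so rank and world size must come from the
# draft group, not the target group.
import torch

def shard_rows(full_weight: torch.Tensor, rank: int, tp_size: int) -> torch.Tensor:
    rows_per_rank = full_weight.shape[0] // tp_size
    start = rank * rows_per_rank
    return full_weight[start:start + rows_per_rank]

full = torch.randn(32000, 4096)                     # illustrative full lm_head weight
draft_shard = shard_rows(full, rank=0, tp_size=1)   # draft TP=1 keeps everything
print(draft_shard.shape)                            # torch.Size([32000, 4096])
```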
27 changes: 25 additions & 2 deletions vllm/spec_decode/spec_decode_worker.py
@@ -7,7 +7,8 @@
import torch.nn as nn

from vllm.config import ParallelConfig, SpeculativeConfig, VllmConfig
-from vllm.distributed.communication_op import broadcast_tensor_dict
+from vllm.distributed.communication_op import (broadcast_tensor_dict,
+                                               tensor_model_parallel_gather)
from vllm.logger import init_logger
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.model_executor.layers.sampler import SamplerOutput
@@ -153,6 +154,7 @@ def create_worker(
) -> "SpecDecodeWorker":

allow_zero_draft_token_step = True
enable_lm_head_weight_load = False
ngram_prompt_lookup_max = (
draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
ngram_prompt_lookup_min = (
@@ -185,6 +187,11 @@
"EAGLE does not support TP > 1 yet")

allow_zero_draft_token_step = False

# Load lm_head weight for eagle in init_device
if draft_model_config.hf_config.model_type == "eagle":
enable_lm_head_weight_load = True

proposer_worker = MultiStepWorker(**draft_worker_kwargs)

proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
@@ -237,7 +244,8 @@ def create_worker(
            disable_log_stats=disable_log_stats,
            disable_by_batch_size=disable_by_batch_size,
            spec_decode_sampler=spec_decode_sampler,
-            allow_zero_draft_token_step=allow_zero_draft_token_step)
+            allow_zero_draft_token_step=allow_zero_draft_token_step,
+            enable_lm_head_weight_load=enable_lm_head_weight_load)

    def __init__(
        self,
@@ -250,6 +258,7 @@ def __init__(
        metrics_collector: Optional[AsyncMetricsCollector] = None,
        disable_by_batch_size: Optional[int] = None,
        allow_zero_draft_token_step: Optional[bool] = True,
        enable_lm_head_weight_load: Optional[bool] = False,
    ):
        """
        Create a SpecDecodeWorker.
@@ -280,6 +289,8 @@ def __init__(
            allow_zero_draft_token_step: whether to allow a step where the draft
                model generates no draft token; should disallow when the tp of
                draft model is larger than 1 (TODO: #5814)
            enable_lm_head_weight_load: whether to load the target model's
                lm_head weight into the draft model (e.g. EAGLE).
        """
        self.proposer_worker = proposer_worker
        self.scorer_worker = scorer_worker
@@ -289,6 +300,7 @@ def __init__(
        self.disable_by_batch_size = disable_by_batch_size or float("inf")
        self.spec_decode_sampler = spec_decode_sampler
        self._allow_zero_draft_token_step = allow_zero_draft_token_step
        self._enable_lm_head_weight_load = enable_lm_head_weight_load
        self._metrics = AsyncMetricsCollector(
            self.spec_decode_sampler
        ) if metrics_collector is None else metrics_collector
@@ -325,6 +337,17 @@ def init_device(self) -> None:
        self.scorer_worker.load_model()
        self.proposer_worker.load_model()

        if self._enable_lm_head_weight_load:
            # NOTE(Shangming): gather the full lm_head weight across TP ranks
            # when tensor parallelism is enabled.
            target_lm_head_weight: torch.Tensor = tensor_model_parallel_gather(
                self.scorer_worker.model_runner.model_runner.model.lm_head.\
                    weight.data,
                dim=0,
            )

            self.proposer_worker.maybe_load_lm_head_weight(
                target_lm_head_weight)

        self._metrics.init_tensors(self.rank, device_type=self.device)
        self.spec_decode_sampler.init_tensors(self.rank,
                                              device_type=self.device)
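In `init_device`, the scorer's vocab-parallel `lm_head` is gathered along dim 0 so the draft worker receives the full `[vocab_size, hidden_size]` matrix. A single-process sketch of what the gather reassembles (simulated with `chunk`/`cat`, not the real distributed op):

```python
# Single-process simulation (assumption: this stands in for
# tensor_model_parallel_gather) showing why dim=0 reassembles the full weight:
# each TP rank owns a contiguous block of vocabulary rows.
import torch

tp_size, vocab_size, hidden_size = 2, 8, 4
full = torch.arange(vocab_size * hidden_size, dtype=torch.float32).reshape(
    vocab_size, hidden_size)

per_rank_shards = list(torch.chunk(full, tp_size, dim=0))  # what each rank holds
gathered = torch.cat(per_rank_shards, dim=0)               # result after the gather
assert torch.equal(gathered, full)
```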