fix
Signed-off-by: Roger Wang <[email protected]>
ywang96 committed Jan 21, 2025
1 parent 2fc6944 commit 7f2f795
Showing 2 changed files with 85 additions and 14 deletions.
46 changes: 45 additions & 1 deletion vllm/multimodal/utils.py
@@ -26,7 +26,7 @@

if TYPE_CHECKING:
    from .hasher import MultiModalHashDict
    from .inputs import MultiModalPlaceholderDict
    from .inputs import MultiModalKwargs, MultiModalPlaceholderDict


class MediaConnector:
@@ -477,3 +477,47 @@ def merge_and_sort_multimodal_metadata(
        merged_hashes = None

    return sorted_modalities, merged_placeholders, merged_hashes


def group_mm_inputs_by_modality(
        mm_inputs: list["MultiModalKwargs"]) -> list[list["MultiModalKwargs"]]:
    """Group consecutive MultiModalKwargs from mm_inputs with the same modality
    together into the same list for batching purposes. A MultiModalKwargs with
    multiple modalities is placed in its own list.

    Args:
        mm_inputs: List of MultiModalKwargs.

    Returns:
        list[list[MultiModalKwargs]]: List of lists of MultiModalKwargs, where
        each inner list contains consecutive MultiModalKwargs with the same
        single modality, or a single MultiModalKwargs with multiple modalities.
    """
    if not mm_inputs:
        return []

    grouped_mm_inputs = []
    current_group = [mm_inputs[0]]

    for mm_input in mm_inputs[1:]:
        # If the current input has multiple modalities, finalize the current
        # group and start a new standalone group for this input.
        if len(mm_input.modalities) > 1:
            grouped_mm_inputs.append(current_group)
            current_group = [mm_input]
        else:
            # If the current input has the same single modality as the previous
            # one, add it to the current group.
            if (len(current_group[-1].modalities) == 1
                    and mm_input.modalities == current_group[-1].modalities):
                current_group.append(mm_input)
            else:
                # Otherwise, finalize the current group and start a new one.
                grouped_mm_inputs.append(current_group)
                current_group = [mm_input]

    # Add the last group to the result.
    if current_group:
        grouped_mm_inputs.append(current_group)

    return grouped_mm_inputs
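
As a usage illustration (not part of this commit), the grouping behavior can be checked in isolation. FakeMMInput below is a hypothetical stand-in that mimics only the .modalities attribute the helper reads; real callers pass actual MultiModalKwargs objects.

from dataclasses import dataclass

from vllm.multimodal.utils import group_mm_inputs_by_modality


@dataclass
class FakeMMInput:
    # Hypothetical stand-in (not a vLLM class): only `.modalities` is needed.
    modalities: set[str]


inputs = [
    FakeMMInput({"image"}),
    FakeMMInput({"image"}),           # same modality as previous -> same group
    FakeMMInput({"audio"}),           # modality change -> new group
    FakeMMInput({"image", "audio"}),  # multiple modalities -> its own group
    FakeMMInput({"image"}),           # follows a multi-modality item -> new group
]

groups = group_mm_inputs_by_modality(inputs)
# Expected: 4 groups -> [[image, image], [audio], [image+audio], [image]]
print([len(group) for group in groups])  # [2, 1, 1, 1]
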
53 changes: 40 additions & 13 deletions vllm/v1/worker/gpu_model_runner.py
@@ -17,6 +17,7 @@
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.model_loader import get_model
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm.sampling_params import SamplingType
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                        LayerBlockType, cdiv, is_pin_memory_available)
@@ -629,19 +630,45 @@ def _execute_encoder(self, scheduler_output: "SchedulerOutput"):
            for input_id in encoder_input_ids:
                mm_inputs.append(req_state.mm_inputs[input_id])
                req_input_ids.append((req_id, input_id))
        batched_mm_inputs = MultiModalKwargs.batch(mm_inputs)
        batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs,
                                                       device=self.device)

        # Run the encoder.
        # `encoder_outputs` is either of the following:
        # 1. A tensor of shape [num_images, feature_size, hidden_size]
        # in case when feature_size is fixed across all images.
        # 2. A list (length: num_images) of tensors, each of shape
        # [feature_size, hidden_size] in case when the feature size is
        # dynamic depending on input images.
        encoder_outputs = self.model.get_multimodal_embeddings(
            **batched_mm_inputs)

        # Batch mm inputs as much as we can: if a request has multiple
        # modalities, or a different modality from the previous one, we
        # process it separately to preserve item order.
        # FIXME(ywang96): This is a hacky way to deal with multiple modalities
        # in the same batch while still being able to benefit from batching
        # multimodal inputs. The proper solution should be reordering the
        # encoder outputs.
        grouped_mm_inputs_list = group_mm_inputs_by_modality(mm_inputs)

        # If there is only one group (single modality), we can return the
        # result directly.
        if len(grouped_mm_inputs_list) == 1:
            batched_mm_inputs = MultiModalKwargs.batch(
                grouped_mm_inputs_list[0])
            batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs,
                                                           device=self.device)
            # Run the encoder.
            # `encoder_outputs` is either of the following:
            # 1. A tensor of shape [num_items, feature_size, hidden_size],
            # when feature_size is fixed across all multimodal items.
            # 2. A list or tuple (length: num_items) of tensors, each of shape
            # (feature_size, hidden_size), when the feature size is dynamic
            # depending on the input multimodal items.
            encoder_outputs = self.model.get_multimodal_embeddings(
                **batched_mm_inputs)

        # If there are multiple groups, we process them one by one
        # and concatenate the results.
        else:
            encoder_outputs = []
            for grouped_mm_inputs in grouped_mm_inputs_list:
                batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
                batched_mm_inputs = MultiModalKwargs.as_kwargs(
                    batched_mm_inputs, device=self.device)
                curr_group_outputs = self.model.get_multimodal_embeddings(
                    **batched_mm_inputs)
                for output in curr_group_outputs:
                    encoder_outputs.append(output)

        # Cache the encoder outputs.
        for (req_id, input_id), output in zip(req_input_ids, encoder_outputs):
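
As a hedged illustration (not part of this commit) of the caching step above: whether the encoder returns one batched tensor or a list of per-item tensors, iterating over it yields one embedding per scheduled encoder input in submission order, so the flattened encoder_outputs lines up with req_input_ids in the zip. The fake_embed helper and the shapes below are hypothetical.

import torch


def fake_embed(feature_sizes: list[int], hidden_size: int = 8):
    # Hypothetical stand-in for get_multimodal_embeddings: a single batched
    # tensor when every item shares a feature size, else one tensor per item.
    if len(set(feature_sizes)) == 1:
        return torch.zeros(len(feature_sizes), feature_sizes[0], hidden_size)
    return [torch.zeros(size, hidden_size) for size in feature_sizes]


req_input_ids = [("req-0", 0), ("req-0", 1), ("req-1", 0)]

# Case 1: fixed feature size   -> tensor of shape [num_items, feature, hidden].
# Case 2: dynamic feature size -> list of [feature_i, hidden] tensors.
for encoder_outputs in (fake_embed([4, 4, 4]), fake_embed([4, 6, 2])):
    for (req_id, input_id), output in zip(req_input_ids, encoder_outputs):
        print(req_id, input_id, tuple(output.shape))  # one embedding per item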
