diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md
index 613343281464c..f74c201bdff6b 100644
--- a/docs/source/models/supported_models.md
+++ b/docs/source/models/supported_models.md
@@ -570,28 +570,28 @@ See [this page](#generative-models) for more information on how to use generativ
- `rhymes-ai/Aria`
-
- ✅︎
- -
+ - ✅︎
* - `Blip2ForConditionalGeneration`
- BLIP-2
- T + IE
- `Salesforce/blip2-opt-2.7b`, `Salesforce/blip2-opt-6.7b`, etc.
-
- ✅︎
- -
+ - ✅︎
* - `ChameleonForConditionalGeneration`
- Chameleon
- T + I
- `facebook/chameleon-7b` etc.
-
- ✅︎
- -
+ - ✅︎
* - `FuyuForCausalLM`
- Fuyu
- T + I
- `adept/fuyu-8b` etc.
-
- ✅︎
- -
+ - ✅︎
* - `ChatGLMModel`
- GLM-4V
- T + I
@@ -633,7 +633,7 @@ See [this page](#generative-models) for more information on how to use generativ
- `llava-hf/llava-v1.6-mistral-7b-hf`, `llava-hf/llava-v1.6-vicuna-7b-hf`, etc.
-
- ✅︎
- -
+ - ✅︎
* - `LlavaNextVideoForConditionalGeneration`
- LLaVA-NeXT-Video
- T + V
diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py
index 77af914a6ef02..b51bfae455267 100644
--- a/examples/offline_inference_vision_language.py
+++ b/examples/offline_inference_vision_language.py
@@ -24,10 +24,13 @@ def run_aria(question: str, modality: str):
assert modality == "image"
model_name = "rhymes-ai/Aria"
+ # NOTE: Need L40 (or equivalent) to avoid OOM
llm = LLM(model=model_name,
tokenizer_mode="slow",
- trust_remote_code=True,
dtype="bfloat16",
+ max_model_len=4096,
+ max_num_seqs=2,
+ trust_remote_code=True,
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
prompt = (f"<|im_start|>user\n<|img|>\n{question}"
@@ -57,6 +60,7 @@ def run_chameleon(question: str, modality: str):
prompt = f"{question}"
llm = LLM(model="facebook/chameleon-7b",
max_model_len=4096,
+ max_num_seqs=2,
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
stop_token_ids = None
return llm, prompt, stop_token_ids
@@ -257,7 +261,7 @@ def run_minicpmv(question: str, modality: str):
# 2.5
# model_name = "openbmb/MiniCPM-Llama3-V-2_5"
- #2.6
+ # 2.6
model_name = "openbmb/MiniCPM-V-2_6"
tokenizer = AutoTokenizer.from_pretrained(model_name,
trust_remote_code=True)
@@ -430,9 +434,11 @@ def run_pixtral_hf(question: str, modality: str):
model_name = "mistral-community/pixtral-12b"
+ # NOTE: Need L40 (or equivalent) to avoid OOM
llm = LLM(
model=model_name,
max_model_len=8192,
+ max_num_seqs=2,
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
)
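
The example updates above keep Aria and Pixtral-HF within memory limits by capping `max_model_len` and `max_num_seqs`. A minimal standalone sketch of how these knobs fit together in an offline multimodal run (illustration only, not part of the patch; the prompt format follows the Aria example above and the image is a blank stand-in):

```python
from PIL import Image
from vllm import LLM, SamplingParams

llm = LLM(
    model="rhymes-ai/Aria",
    tokenizer_mode="slow",
    dtype="bfloat16",
    max_model_len=4096,   # cap context length to shrink the KV cache
    max_num_seqs=2,       # cap concurrent sequences to bound peak memory
    trust_remote_code=True,
)

prompt = ("<|im_start|>user\n<|img|>\nWhat is in this image?<|im_end|>\n"
          "<|im_start|>assistant\n")
image = Image.new("RGB", (980, 980), color=0)  # stand-in for a real image

outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"image": image}},
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)
```
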
diff --git a/tests/models/decoder_only/vision_language/test_models.py b/tests/models/decoder_only/vision_language/test_models.py
index 1a9c1b4ef1be0..7db08166826eb 100644
--- a/tests/models/decoder_only/vision_language/test_models.py
+++ b/tests/models/decoder_only/vision_language/test_models.py
@@ -140,10 +140,7 @@
"aria": VLMTestInfo(
models=["rhymes-ai/Aria"],
tokenizer_mode="slow",
- test_type=(
- VLMTestType.IMAGE,
- VLMTestType.MULTI_IMAGE,
- ),
+ test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
dtype="bfloat16",
prompt_formatter=lambda img_prompt: f"<|im_start|>user\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n ", # noqa: E501
img_idx_to_prompt=lambda idx: "<|img|>\n",
@@ -179,6 +176,7 @@
test_type=VLMTestType.IMAGE,
prompt_formatter=lambda img_prompt: f"USER: {img_prompt}\nASSISTANT:",
max_model_len=4096,
+ max_num_seqs=2,
auto_cls=AutoModelForVision2Seq,
postprocess_inputs=model_utils.cast_dtype_post_processor(
"pixel_values"
@@ -201,7 +199,6 @@
vllm_output_post_proc=model_utils.fuyu_vllm_to_hf_output,
num_logprobs=10,
image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
- marks=[large_gpu_mark(min_gb=48)],
),
"glm4": VLMTestInfo(
models=["THUDM/glm-4v-9b"],
diff --git a/tests/multimodal/test_processing.py b/tests/multimodal/test_processing.py
index 1b2847ed0f534..81278cde264ff 100644
--- a/tests/multimodal/test_processing.py
+++ b/tests/multimodal/test_processing.py
@@ -528,7 +528,7 @@ def _rand_audio(
def _test_processing_cache_correctness(
model_id: str,
- modalities: set[str],
+ modalities: dict[str, bool],
hit_rate: float,
num_batches: int,
simplify_rate: float,
@@ -583,9 +583,8 @@ def _test_processing_cache_correctness(
partial(_rand_audio, rng, min_len=256, max_len=512, sr=16000),
}
input_max_count = {
- "image": 3,
- "video": 3,
- "audio": 3,
+ modality: 3 if supports_multi else 1
+ for modality, supports_multi in modalities.items()
}
for batch_idx in range(num_batches):
@@ -624,12 +623,16 @@ def _test_processing_cache_correctness(
# yapf: disable
@pytest.mark.parametrize(("model_id", "modalities"), [
- ("llava-hf/llava-1.5-7b-hf", {"image"}),
- ("TIGER-Lab/Mantis-8B-siglip-llama3", {"image"}),
- ("mistral-community/pixtral-12b", {"image"}),
- ("Qwen/Qwen2-VL-2B-Instruct", {"image", "video"}),
- ("Qwen/Qwen2-Audio-7B-Instruct", {"audio"}),
- ("fixie-ai/ultravox-v0_3", {"audio"}),
+ ("rhymes-ai/Aria", {"image": True}),
+ ("Salesforce/blip2-opt-2.7b", {"image": False}),
+ ("facebook/chameleon-7b", {"image": True}),
+ ("adept/fuyu-8b", {"image": False}),
+ ("llava-hf/llava-1.5-7b-hf", {"image": True}),
+ ("TIGER-Lab/Mantis-8B-siglip-llama3", {"image": True}),
+ ("mistral-community/pixtral-12b", {"image": True}),
+ ("Qwen/Qwen2-VL-2B-Instruct", {"image": True, "video": True}),
+ ("Qwen/Qwen2-Audio-7B-Instruct", {"audio": True}),
+ ("fixie-ai/ultravox-v0_3", {"audio": True}),
])
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
@pytest.mark.parametrize("num_batches", [32])
@@ -637,7 +640,7 @@ def _test_processing_cache_correctness(
# yapf: enable
def test_processing_cache_correctness(
model_id: str,
- modalities: set[str],
+ modalities: dict[str, bool],
hit_rate: float,
num_batches: int,
simplify_rate: float,
@@ -653,7 +656,7 @@ def test_processing_cache_correctness(
# yapf: disable
@pytest.mark.parametrize(("model_id", "modalities"), [
- ("microsoft/Phi-3-vision-128k-instruct", {"image"}),
+ ("microsoft/Phi-3-vision-128k-instruct", {"image": True}),
])
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
@pytest.mark.parametrize("num_batches", [32])
@@ -661,7 +664,7 @@ def test_processing_cache_correctness(
# yapf: enable
def test_processing_cache_correctness_phi3v(
model_id: str,
- modalities: set[str],
+ modalities: dict[str, bool],
hit_rate: float,
num_batches: int,
simplify_rate: float,
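
With this change, `modalities` records whether each model supports multiple items per modality, and the test derives `input_max_count` from it. A tiny self-contained sketch of that derivation (values are hypothetical):

```python
# Reproduces the input_max_count derivation used in the test.
modalities = {"image": True, "video": False}

input_max_count = {
    modality: 3 if supports_multi else 1
    for modality, supports_multi in modalities.items()
}

assert input_max_count == {"image": 3, "video": 1}
```
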
diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py
index 9437ad9688422..4ad6e859f4d93 100644
--- a/vllm/model_executor/models/aria.py
+++ b/vllm/model_executor/models/aria.py
@@ -1,15 +1,15 @@
-import math
-from typing import Iterable, List, Optional, Set, Tuple, TypedDict, Union
+from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
+ Union)
import torch
import torch.nn as nn
from torch.nn.init import trunc_normal_
-from transformers import LlamaConfig
+from transformers import BatchFeature, PretrainedConfig
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, QuantizationConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_rank
-from vllm.inputs import INPUT_REGISTRY, token_inputs
+from vllm.inputs import InputContext
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -17,30 +17,27 @@
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
get_compressed_tensors_cache_scale)
-from vllm.model_executor.layers.sampler import (Sampler, SamplerOutput,
- SamplingMetadata)
+from vllm.model_executor.layers.sampler import (SamplerOutput,
+ SamplingMetadata, get_sampler)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
-from vllm.model_executor.models.idefics2_vision_model import (
- Idefics2VisionTransformer)
-from vllm.model_executor.models.interfaces import SupportsMultiModal
-from vllm.model_executor.models.llama import (LlamaDecoderLayer, LlamaMLP,
- LlamaModel)
-from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper,
- is_pp_missing_parameter,
- maybe_prefix,
- merge_multimodal_embeddings)
from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.image import cached_get_image_processor
-from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
-from vllm.multimodal.utils import (cached_get_tokenizer,
- repeat_and_pad_placeholder_tokens)
+from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
+ NestedTensors)
+from vllm.multimodal.processing import (BaseMultiModalProcessor,
+ MultiModalDataItems, ProcessorInputs,
+ PromptReplacement)
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.aria import (AriaMoELMConfig,
AriaVisionConfig)
-from .utils import flatten_bn
+from .idefics2_vision_model import Idefics2VisionTransformer
+from .interfaces import SupportsMultiModal
+from .llama import LlamaDecoderLayer, LlamaMLP, LlamaModel
+from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
+ is_pp_missing_parameter, maybe_prefix,
+ merge_multimodal_embeddings)
class AriaImagePixelInputs(TypedDict):
@@ -251,7 +248,7 @@ def forward(self, x, attn_mask=None):
class AriaFusedMoE(FusedMoE):
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
- shard_id: str) -> Set[str]:
+ shard_id: str) -> None:
# Override the weight_loader to handle the expert weights in the Aria
# model, which are already packed with experts, and merge the gate and
# up weights for each expert.
@@ -346,7 +343,7 @@ class MoEDecoderLayer(LlamaDecoderLayer):
def __init__(
self,
- config: LlamaConfig,
+ config: AriaMoELMConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
@@ -434,7 +431,7 @@ def load_weights(self, weights: Iterable[Tuple[str,
return loaded_params
-def build_mm_projector(config):
+def build_mm_projector(config: PretrainedConfig):
return AriaProjector(
patch_to_query_dict=config.projector_patch_to_query_dict,
embed_dim=config.vision_config.hidden_size,
@@ -445,75 +442,70 @@ def build_mm_projector(config):
)
-def get_max_multimodal_tokens(ctx):
- return max(ctx.model_config.hf_config.image_size2tokens.values())
-
-
-def input_mapper_for_aria(ctx, data):
- return MultiModalKwargs(data)
+def get_max_aria_image_tokens(ctx: InputContext):
+ hf_config = ctx.get_hf_config()
+ return max(hf_config.projector_patch_to_query_dict.values())
-def input_processor(ctx, llm_inputs):
- multi_modal_data = llm_inputs.get("multi_modal_data")
- # if it is pure text input, use it as is
- if multi_modal_data is None or "image" not in multi_modal_data:
- return llm_inputs
+class AriaMultiModalProcessor(BaseMultiModalProcessor):
- model_config = ctx.model_config
-
- tokenizer = cached_get_tokenizer(model_config.tokenizer)
- image_processor = cached_get_image_processor(
- model_config.model, trust_remote_code=model_config.trust_remote_code)
- hf_config = model_config.hf_config
+ def _get_mm_fields_config(
+ self,
+ hf_inputs: BatchFeature,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ ) -> Mapping[str, MultiModalFieldConfig]:
+ return dict(
+ pixel_values=MultiModalFieldConfig.batched("image"),
+ pixel_mask=MultiModalFieldConfig.batched("image"),
+ )
- # prepare image tokens, the max_image_size is used to determine the number
- # of patch_size for every image
- max_image_size = multi_modal_data.pop("max_image_size", 980)
- _split_image = multi_modal_data.pop("split_image", False)
+ def _get_prompt_replacements(
+ self,
+ mm_items: MultiModalDataItems,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ out_mm_kwargs: MultiModalKwargs,
+ ) -> list[PromptReplacement]:
+ hf_config = self.ctx.get_hf_config()
+ image_token_id = hf_config.image_token_index
+
+ max_image_tokens = get_max_aria_image_tokens(self.ctx)
+
+ return [
+ PromptReplacement(
+ modality="image",
+ target=[image_token_id],
+ replacement=[image_token_id] * max_image_tokens,
+ )
+ ]
- assert isinstance(max_image_size,
- (int, float)), "max_image_size should be float or int"
- images = (multi_modal_data["image"] if isinstance(
- multi_modal_data["image"], list) else [multi_modal_data["image"]])
+ def _get_dummy_mm_inputs(
+ self,
+ mm_counts: Mapping[str, int],
+ ) -> ProcessorInputs:
+ hf_config = self.ctx.get_hf_config()
+ vision_config: AriaVisionConfig = hf_config.vision_config
+
+ max_image_size = vision_config.image_size
+ num_images = mm_counts.get("image", 0)
+
+ mm_data = {
+ "image":
+ self._get_dummy_images(width=max_image_size,
+ height=max_image_size,
+ num_images=num_images)
+ }
- image_inputs = image_processor.preprocess(images,
- max_image_size=max_image_size,
- split_image=_split_image,
- return_tensors="pt").data
- image_inputs['pixel_values'] = image_inputs['pixel_values'].to(
- ctx.model_config.dtype)
- num_crops = image_inputs.pop("num_crops")
+ hf_processor = self._get_hf_processor()
+ image_token: str = hf_processor.image_token # type: ignore
- prompt_token_ids = llm_inputs["prompt_token_ids"]
- if num_crops.sum().item() > 0:
- _, prompt_token_ids, _ = repeat_and_pad_placeholder_tokens(
- tokenizer,
- None,
- prompt_token_ids,
- placeholder_token_id=hf_config.image_token_index,
- repeat_count=num_crops,
+ return ProcessorInputs(
+ prompt_text=image_token * num_images,
+ mm_data=mm_data,
)
- repeat_count = [hf_config.image_size2tokens[max_image_size]
- ] * sum(num_crops).item()
- new_prompt, new_token_ids, _ = repeat_and_pad_placeholder_tokens(
- tokenizer,
- None,
- prompt_token_ids,
- placeholder_token_id=hf_config.image_token_index,
- repeat_count=repeat_count,
- )
-
- return token_inputs(
- prompt_token_ids=new_token_ids,
- prompt=new_prompt,
- multi_modal_data={"image": image_inputs},
- )
-
-@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_multimodal_tokens)
-@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_aria)
-@INPUT_REGISTRY.register_input_processor(input_processor)
+@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_aria_image_tokens)
+@MULTIMODAL_REGISTRY.register_processor(AriaMultiModalProcessor)
class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
"""
Aria model for conditional generation tasks.
@@ -540,12 +532,6 @@ def __init__(
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
- # prepare the image_size to tokens mapping for the image preprocess, see
- # input_processor
- config.image_size2tokens = {
- int(math.sqrt(k) * config.vision_config.patch_size): v
- for k, v in config.projector_patch_to_query_dict.items()
- }
self.config = config
self.vision_tower = AriaVisionModel(config.vision_config)
self.multi_modal_projector = build_mm_projector(config)
@@ -566,7 +552,7 @@ def __init__(
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
self.vocab_size, logit_scale)
- self.sampler = Sampler()
+ self.sampler = get_sampler()
def _validate_image_sizes(
self, images: List[torch.Tensor]) -> List[torch.Tensor]:
@@ -588,7 +574,12 @@ def _parse_and_validate_image_input(
pixel_values = self._validate_image_sizes(pixel_values)
pixel_values = flatten_bn(pixel_values, concat=True)
+
if pixel_mask is not None:
+ if not isinstance(pixel_mask, (torch.Tensor, list)):
+ raise ValueError("Incorrect type of pixel mask. "
+ f"Got type: {type(pixel_mask)}")
+
pixel_mask = flatten_bn(pixel_mask, concat=True)
return AriaImagePixelInputs(
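
The new `AriaMultiModalProcessor` replaces the old ad-hoc input processor: each image placeholder token is expanded into a fixed number of tokens taken from `projector_patch_to_query_dict`. A standalone sketch of what such a `PromptReplacement` does to the token sequence (the token ids below are made up):

```python
def expand_image_tokens(prompt_ids, image_token_id, max_image_tokens):
    """Expand each image placeholder token into max_image_tokens copies."""
    expanded = []
    for tok in prompt_ids:
        if tok == image_token_id:
            expanded.extend([image_token_id] * max_image_tokens)
        else:
            expanded.append(tok)
    return expanded

# With a hypothetical image_token_id of 9 and max_image_tokens of 4:
assert expand_image_tokens([1, 9, 2], 9, 4) == [1, 9, 9, 9, 9, 2]
```
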
diff --git a/vllm/model_executor/models/blip.py b/vllm/model_executor/models/blip.py
index 42a239cadac46..987dfaf44f228 100644
--- a/vllm/model_executor/models/blip.py
+++ b/vllm/model_executor/models/blip.py
@@ -4,22 +4,16 @@
import torch
import torch.nn as nn
-from PIL import Image
from transformers import Blip2VisionConfig, BlipVisionConfig
from vllm.attention.layer import MultiHeadAttention
-from vllm.config import ModelConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
-from vllm.inputs import DecoderOnlyInputs, token_inputs
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.multimodal.utils import (cached_get_tokenizer,
- repeat_and_pad_placeholder_tokens)
-from vllm.sequence import SequenceData
def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
@@ -33,92 +27,6 @@ def get_blip_num_patches(*, image_size: int, patch_size: int) -> int:
return grid_length * grid_length
-def get_blip_image_feature_size(
- hf_config: Union[BlipVisionConfig, Blip2VisionConfig]) -> int:
- return get_blip_num_patches(image_size=hf_config.image_size,
- patch_size=hf_config.patch_size)
-
-
-def get_max_blip_image_tokens(
- hf_config: Union[BlipVisionConfig, Blip2VisionConfig]) -> int:
- return get_blip_image_feature_size(hf_config)
-
-
-def dummy_seq_data_for_blip(
- hf_config: Union[BlipVisionConfig, Blip2VisionConfig],
- seq_len: int,
- num_images: int,
- *,
- image_token_id: int,
- image_feature_size_override: Optional[int] = None,
-):
- if image_feature_size_override is None:
- image_feature_size = get_blip_image_feature_size(hf_config)
- else:
- image_feature_size = image_feature_size_override
-
- return SequenceData.from_prompt_token_counts(
- (image_token_id, image_feature_size * num_images),
- (0, seq_len - image_feature_size * num_images),
- )
-
-
-def dummy_image_for_blip(
- hf_config: Union[BlipVisionConfig, Blip2VisionConfig],
- num_images: int,
- *,
- image_width_override: Optional[int] = None,
- image_height_override: Optional[int] = None,
-):
- width = height = hf_config.image_size
- if image_width_override is not None:
- width = image_width_override
- if image_height_override is not None:
- height = image_height_override
-
- image = Image.new("RGB", (width, height), color=0)
- return {"image": image if num_images == 1 else [image] * num_images}
-
-
-def input_processor_for_blip(
- model_config: ModelConfig,
- hf_config: Union[BlipVisionConfig, Blip2VisionConfig],
- inputs: DecoderOnlyInputs,
- *,
- image_token_id: int,
- image_feature_size_override: Optional[int] = None,
-):
- multi_modal_data = inputs.get("multi_modal_data")
- if multi_modal_data is None or "image" not in multi_modal_data:
- return inputs
-
- if "multi_modal_placeholders" in inputs and "image" in inputs[
- "multi_modal_placeholders"]:
- # The inputs already have placeholders.
- return inputs
-
- tokenizer = cached_get_tokenizer(model_config.tokenizer)
-
- if image_feature_size_override is None:
- image_feature_size = get_blip_image_feature_size(hf_config)
- else:
- image_feature_size = image_feature_size_override
-
- new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
- tokenizer,
- inputs.get("prompt"),
- inputs["prompt_token_ids"],
- placeholder_token_id=image_token_id,
- repeat_count=image_feature_size,
- )
-
- # NOTE: Create a defensive copy of the original inputs
- return token_inputs(prompt_token_ids=new_token_ids,
- prompt=new_prompt,
- multi_modal_data=multi_modal_data,
- multi_modal_placeholders={"image": ranges})
-
-
# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/blip/modeling_blip.py#L164 # noqa
class BlipVisionEmbeddings(nn.Module):
diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py
index 76b8505ee1c2a..bf70f5d904f5b 100644
--- a/vllm/model_executor/models/blip2.py
+++ b/vllm/model_executor/models/blip2.py
@@ -4,32 +4,33 @@
import torch
import torch.nn as nn
-from transformers import (Blip2Config, Blip2QFormerConfig, Blip2VisionConfig,
- apply_chunking_to_forward)
+from transformers import (BatchFeature, Blip2Config, Blip2Processor,
+ Blip2QFormerConfig, apply_chunking_to_forward)
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, VllmConfig
-from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
- InputContext, token_inputs)
+from vllm.inputs import InputContext
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import NestedTensors
-from vllm.multimodal.utils import consecutive_placeholder_ranges
-from vllm.sequence import IntermediateTensors, SequenceData
-
-from .blip import (BlipVisionModel, dummy_image_for_blip,
- get_max_blip_image_tokens)
+from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
+ MultiModalInputsV2, MultiModalKwargs,
+ NestedTensors, PlaceholderRange)
+from vllm.multimodal.processing import (BaseMultiModalProcessor,
+ MultiModalDataItems, ProcessorInputs,
+ PromptReplacement)
+from vllm.sequence import IntermediateTensors
+
+from .blip import BlipVisionModel
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
# We use this internally as placeholders since there is no image token
# defined on the HuggingFace repo
-BLIP2_IMAGE_TOKEN = "<image>"
-BLIP2_IMAGE_TOKEN_ID = 50265
+_IMAGE_TOKEN_ID = 50265
class Blip2ImagePixelInputs(TypedDict):
@@ -396,92 +397,87 @@ def forward(
return sequence_output
-def get_blip2_image_feature_size(hf_config: Blip2Config) -> int:
- return hf_config.num_query_tokens
-
-
def get_max_blip2_image_tokens(ctx: InputContext):
hf_config = ctx.get_hf_config(Blip2Config)
- vision_config = hf_config.vision_config
-
- if isinstance(vision_config, Blip2VisionConfig):
- return get_max_blip_image_tokens(vision_config)
-
- msg = f"Unsupported vision config: {type(vision_config)}"
- raise NotImplementedError(msg)
-
-
-def dummy_seq_data_for_blip2(
- hf_config: Blip2Config,
- seq_len: int,
- num_images: int,
- *,
- image_token_id: int,
- image_feature_size_override: Optional[int] = None,
-):
- if image_feature_size_override is None:
- image_feature_size = get_blip2_image_feature_size(hf_config)
- else:
- image_feature_size = image_feature_size_override
-
- return SequenceData.from_prompt_token_counts(
- (image_token_id, image_feature_size * num_images),
- (0, seq_len - image_feature_size * num_images),
- ), {
- "image":
- consecutive_placeholder_ranges(num_items=num_images,
- item_size=image_feature_size)
- }
-
-
-def dummy_data_for_blip2(ctx: InputContext, seq_len: int,
- mm_counts: Mapping[str, int]):
- hf_config = ctx.get_hf_config(Blip2Config)
- vision_config = hf_config.vision_config
- num_images = mm_counts["image"]
-
- seq_data, ranges = dummy_seq_data_for_blip2(
- hf_config,
- seq_len,
- num_images,
- image_token_id=BLIP2_IMAGE_TOKEN_ID,
- )
-
- if isinstance(vision_config, Blip2VisionConfig):
- mm_data = dummy_image_for_blip(vision_config, num_images)
-
- return DummyData(seq_data, mm_data, ranges)
-
- msg = f"Unsupported vision config: {type(vision_config)}"
- raise NotImplementedError(msg)
+ return hf_config.num_query_tokens
-def input_processor_for_blip2(ctx: InputContext, inputs: DecoderOnlyInputs):
- multi_modal_data = inputs.get("multi_modal_data")
- if multi_modal_data is None or "image" not in multi_modal_data:
- return inputs
+class Blip2MultiModalProcessor(BaseMultiModalProcessor):
- hf_config = ctx.get_hf_config(Blip2Config)
- image_feature_size = get_blip2_image_feature_size(hf_config)
+ def _get_hf_processor(self) -> Blip2Processor:
+ return self.ctx.get_hf_processor(Blip2Processor)
- # The original model places image tokens at the front
- # https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/blip_2/modeling_blip_2.py#L1514
- new_token_ids = [BLIP2_IMAGE_TOKEN_ID] * image_feature_size
- new_token_ids += inputs["prompt_token_ids"]
+ def _get_mm_fields_config(
+ self,
+ hf_inputs: BatchFeature,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ ) -> Mapping[str, MultiModalFieldConfig]:
+ return dict(
+ pixel_values=MultiModalFieldConfig.batched("image"),
+ image_embeds=MultiModalFieldConfig.batched("image"),
+ )
- new_prompt = inputs.get("prompt")
- if new_prompt is not None:
- new_prompt = BLIP2_IMAGE_TOKEN * image_feature_size + new_prompt
+ def _get_prompt_replacements(
+ self,
+ mm_items: MultiModalDataItems,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ out_mm_kwargs: MultiModalKwargs,
+ ) -> list[PromptReplacement]:
+ max_image_tokens = get_max_blip2_image_tokens(self.ctx)
+
+ return [
+ PromptReplacement(
+ modality="image",
+                target="</s>",
+                replacement="<image>" * max_image_tokens + "</s>",
+ )
+ ]
- return token_inputs(prompt_token_ids=new_token_ids,
- prompt=new_prompt,
- multi_modal_data=multi_modal_data)
+ def apply(
+ self,
+ prompt_text: str,
+ mm_data: MultiModalDataDict,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ ) -> MultiModalInputsV2:
+ result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs)
+
+        # Only <image> tokens should be considered as placeholders,
+ # so we ignore the trailing bos_token
+ result["mm_placeholders"] = {
+ modality: [
+ PlaceholderRange(offset=p["offset"], length=p["length"] - 1)
+ for p in ps
+ ]
+ for modality, ps in result["mm_placeholders"].items()
+ }
+
+ return result
+
+ def _get_dummy_mm_inputs(
+ self,
+ mm_counts: Mapping[str, int],
+ ) -> ProcessorInputs:
+ hf_config = self.ctx.get_hf_config(Blip2Config)
+ vision_config = hf_config.vision_config
+
+ max_image_size = vision_config.image_size
+ num_images = mm_counts.get("image", 0)
+
+ mm_data = {
+ "image":
+ self._get_dummy_images(width=max_image_size,
+ height=max_image_size,
+ num_images=num_images)
+ }
+
+ return ProcessorInputs(
+ prompt_text="",
+ mm_data=mm_data,
+ )
-@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_blip2_image_tokens)
-@INPUT_REGISTRY.register_dummy_data(dummy_data_for_blip2)
-@INPUT_REGISTRY.register_input_processor(input_processor_for_blip2)
+@MULTIMODAL_REGISTRY.register_processor(Blip2MultiModalProcessor)
class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@@ -627,7 +623,7 @@ def get_input_embeddings(
if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
- BLIP2_IMAGE_TOKEN_ID)
+ _IMAGE_TOKEN_ID)
return inputs_embeds
def forward(
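
The `apply` override above trims one token from each placeholder range because the replacement string appends a bos token after the image tokens. A self-contained sketch of that adjustment, modeling `PlaceholderRange` as a plain dict and assuming, say, 32 query tokens per image:

```python
def trim_trailing_token(mm_placeholders):
    """Drop the trailing non-image token from each placeholder range."""
    return {
        modality: [
            {"offset": p["offset"], "length": p["length"] - 1}
            for p in ps
        ]
        for modality, ps in mm_placeholders.items()
    }

# e.g. 32 query tokens plus the trailing bos token -> a 32-token placeholder
assert (trim_trailing_token({"image": [{"offset": 0, "length": 33}]})
        == {"image": [{"offset": 0, "length": 32}]})
```
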
diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py
index a40c321ce0a58..85fca23b05746 100644
--- a/vllm/model_executor/models/chameleon.py
+++ b/vllm/model_executor/models/chameleon.py
@@ -3,16 +3,15 @@
Tuple, TypedDict, Union)
import torch
+import torch.nn as nn
import torch.nn.functional as F
-from PIL import Image
-from torch import nn
-from transformers import ChameleonConfig, ChameleonVQVAEConfig
+from transformers import (BatchFeature, ChameleonConfig, ChameleonProcessor,
+ ChameleonVQVAEConfig)
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
-from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
- InputContext, token_inputs)
+from vllm.inputs import InputContext
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
@@ -29,11 +28,13 @@
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import NestedTensors
-from vllm.multimodal.utils import (cached_get_tokenizer,
- consecutive_placeholder_ranges,
- repeat_and_pad_placeholder_tokens)
-from vllm.sequence import IntermediateTensors, SequenceData
+from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
+ MultiModalInputsV2, MultiModalKwargs,
+ NestedTensors, PlaceholderRange)
+from vllm.multimodal.processing import (BaseMultiModalProcessor,
+ MultiModalDataItems, ProcessorInputs,
+ PromptReplacement)
+from vllm.sequence import IntermediateTensors
from vllm.utils import print_warning_once
from .interfaces import SupportsMultiModal, SupportsPP
@@ -45,10 +46,6 @@
# and processor files, so we hardcode them in the model file for now.
CHAMELEON_CROP_SIZE_HEIGHT = CHAMELEON_CROP_SIZE_WIDTH = 512
CHAMELEON_IMAGE_SEQ_LENGTH = 1024
-CHAMELEON_IMAGE_TOKEN_ID = 8711
-CHAMELEON_IMAGE_START_TOKEN_ID = 8197
-CHAMELEON_IMAGE_END_TOKEN_ID = 8196
-CHAMELEON_SEP_TOKEN_ID = 8710
class ChameleonImagePixelInputs(TypedDict):
@@ -61,99 +58,75 @@ def get_max_chameleon_image_tokens(ctx: InputContext):
return CHAMELEON_IMAGE_SEQ_LENGTH
-def dummy_seq_data_for_chameleon(
- seq_len: int,
- num_images: int,
- *,
- image_token_id: int,
- image_feature_size_override: Optional[int] = None,
-):
- if image_feature_size_override is None:
- image_feature_size = CHAMELEON_IMAGE_SEQ_LENGTH
- else:
- image_feature_size = image_feature_size_override
-
- return SequenceData.from_prompt_token_counts(
- (image_token_id, image_feature_size * num_images),
- (0, seq_len - image_feature_size * num_images),
- ), {
- "image":
- consecutive_placeholder_ranges(num_items=num_images,
- item_size=image_feature_size)
- }
-
-
-def dummy_image_for_chameleon(
- num_images: int,
- *,
- image_width_override: Optional[int] = None,
- image_height_override: Optional[int] = None,
-):
- width = CHAMELEON_CROP_SIZE_WIDTH
- height = CHAMELEON_CROP_SIZE_HEIGHT
- if image_width_override is not None:
- width = image_width_override
- if image_height_override is not None:
- height = image_height_override
-
- image = Image.new("RGB", (width, height), color=0)
- return {"image": image if num_images == 1 else [image] * num_images}
-
-
-def dummy_data_for_chameleon(ctx: InputContext, seq_len: int,
- mm_counts: Mapping[str, int]):
- num_images = mm_counts["image"]
-
- seq_data, ranges = dummy_seq_data_for_chameleon(
- seq_len,
- num_images,
- image_token_id=CHAMELEON_IMAGE_TOKEN_ID,
- )
-
- mm_data = dummy_image_for_chameleon(num_images)
- return DummyData(seq_data, mm_data, ranges)
-
-
-def input_processor_for_chameleon(ctx: InputContext,
- inputs: DecoderOnlyInputs):
+class ChameleonMultiModalProcessor(BaseMultiModalProcessor):
- """
- Processing input prompt to insert required tokens for image placeholder.
-
- See https://github.com/huggingface/transformers/blob/0fdea8607d7e01eb0e38a1ebeb7feee30a22f0cf/src/transformers/models/chameleon/processing_chameleon.py#L58
- """ # noqa
-
- multi_modal_data = inputs.get("multi_modal_data")
- if multi_modal_data is None or "image" not in multi_modal_data:
- return inputs
-
- if "multi_modal_placeholders" in inputs and "image" in inputs[
- "multi_modal_placeholders"]:
- # The inputs already have placeholders.
- return inputs
-
- model_config = ctx.model_config
- tokenizer = cached_get_tokenizer(model_config.tokenizer)
- new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
- tokenizer,
- inputs.get("prompt"),
- inputs["prompt_token_ids"],
- placeholder_token_id=CHAMELEON_IMAGE_TOKEN_ID,
- repeat_count=CHAMELEON_IMAGE_SEQ_LENGTH,
- pad_token_left=CHAMELEON_IMAGE_START_TOKEN_ID,
- pad_token_right=CHAMELEON_IMAGE_END_TOKEN_ID,
- )
-
- # Appending sep token for chat mode to follow default processor
- # behavior
- if new_prompt is not None:
- new_prompt += tokenizer.sep_token
- new_token_ids += [CHAMELEON_SEP_TOKEN_ID]
-
- # NOTE: Create a defensive copy of the original inputs
- return token_inputs(prompt_token_ids=new_token_ids,
- prompt=new_prompt,
- multi_modal_data=multi_modal_data)
+ def _get_hf_processor(self) -> ChameleonProcessor:
+ return self.ctx.get_hf_processor(ChameleonProcessor)
+
+ def _get_mm_fields_config(
+ self,
+ hf_inputs: BatchFeature,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ ) -> Mapping[str, MultiModalFieldConfig]:
+ return dict(pixel_values=MultiModalFieldConfig.batched("image"))
+
+ def _get_prompt_replacements(
+ self,
+ mm_items: MultiModalDataItems,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ out_mm_kwargs: MultiModalKwargs,
+ ) -> list[PromptReplacement]:
+ processor = self._get_hf_processor()
+
+ return [
+ PromptReplacement(
+ modality="image",
+                target="<image>",
+ replacement="".join([
+ processor.image_start_token,
+ processor.image_token * CHAMELEON_IMAGE_SEQ_LENGTH,
+ processor.image_end_token,
+ ]),
+ )
+ ]
+
+ def _get_dummy_mm_inputs(
+ self,
+ mm_counts: Mapping[str, int],
+ ) -> ProcessorInputs:
+ num_images = mm_counts.get("image", 0)
+
+ mm_data = {
+ "image":
+ self._get_dummy_images(width=CHAMELEON_CROP_SIZE_WIDTH,
+ height=CHAMELEON_CROP_SIZE_HEIGHT,
+ num_images=num_images)
+ }
+
+ return ProcessorInputs(
+            prompt_text="<image>" * num_images,
+ mm_data=mm_data,
+ )
+
+ def apply(
+ self,
+ prompt_text: str,
+ mm_data: MultiModalDataDict,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ ) -> MultiModalInputsV2:
+ result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs)
+
+        # Only <image> tokens should be considered as placeholders,
+ # so we ignore the image_start_token and image_end_token
+ result["mm_placeholders"] = {
+ modality: [
+ PlaceholderRange(offset=p["offset"] + 1,
+ length=p["length"] - 2) for p in ps
+ ]
+ for modality, ps in result["mm_placeholders"].items()
+ }
+
+ return result
class ChameleonLayerNorm(nn.LayerNorm):
@@ -736,7 +709,7 @@ def forward(self, pixel_values: torch.Tensor):
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
hidden_state = self.down[i_level].block[i_block](
- hidden_states[-1], )
+ hidden_states[-1])
if len(self.down[i_level].attn) > 0:
hidden_state = self.down[i_level].attn[i_block](
hidden_state)
@@ -925,10 +898,8 @@ def forward(
return hidden_states
-@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_chameleon_image_tokens)
-@INPUT_REGISTRY.register_dummy_data(dummy_data_for_chameleon)
-@INPUT_REGISTRY.register_input_processor(input_processor_for_chameleon)
+@MULTIMODAL_REGISTRY.register_processor(ChameleonMultiModalProcessor)
class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):
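
Chameleon's `apply` override does the analogous trim in both directions, since the replacement wraps the image tokens in `image_start_token`/`image_end_token`. A standalone sketch with `PlaceholderRange` again modeled as a dict:

```python
def strip_wrapper_tokens(mm_placeholders):
    """Shift past image_start_token and drop both wrapper tokens."""
    return {
        modality: [
            {"offset": p["offset"] + 1, "length": p["length"] - 2}
            for p in ps
        ]
        for modality, ps in mm_placeholders.items()
    }

# 1024 image tokens wrapped by start/end tokens -> a 1024-token placeholder
assert (strip_wrapper_tokens({"image": [{"offset": 5, "length": 1026}]})
        == {"image": [{"offset": 6, "length": 1024}]})
```
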
diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py
index 6e86900326c4b..8c14866f20b92 100644
--- a/vllm/model_executor/models/fuyu.py
+++ b/vllm/model_executor/models/fuyu.py
@@ -15,32 +15,30 @@
# limitations under the License.
""" PyTorch Fuyu model."""
import math
-from array import array
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
TypedDict)
import torch
import torch.nn as nn
-import torch.utils.checkpoint
-from PIL import Image
-from transformers import FuyuImageProcessor
+from transformers import (BatchFeature, FuyuConfig, FuyuImageProcessor,
+ FuyuProcessor)
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
-from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
- InputContext, token_inputs)
+from vllm.inputs import InputContext
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.models.persimmon import PersimmonForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
-from vllm.multimodal.image import cached_get_image_processor
-from vllm.multimodal.inputs import NestedTensors
-from vllm.multimodal.utils import (cached_get_tokenizer,
- consecutive_placeholder_ranges)
-from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
- SequenceData)
-from vllm.utils import is_list_of
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
+ MultiModalInputsV2, MultiModalKwargs,
+ NestedTensors, PlaceholderRange)
+from vllm.multimodal.parse import ImageProcessorItems
+from vllm.multimodal.processing import (BaseMultiModalProcessor,
+ MultiModalDataItems, ProcessorInputs,
+ PromptReplacement)
+from vllm.sequence import IntermediateTensors
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
@@ -54,178 +52,193 @@
MAX_IMAGE_FEATURE_SIZE_WIDTH = 1920
-class FuyuImagePixelInputs(TypedDict):
- type: Literal["pixel_values"]
+class FuyuImagePatchInputs(TypedDict):
+ type: Literal["image_patches"]
data: torch.Tensor
"""
Shape:
- (batch_size, num_patches, patch_size_x * patch_size_y * num_channels)
+ `(batch_size * num_patches, patch_size_x * patch_size_y * num_channels)`
+ """
+
+ patches_per_image: List[int]
+ """
+ List of number of total patches for each image in the batch.
+ This is used to restore the first two dimensions of `data`.
"""
-def _calculate_num_image_tokens(
- height: int,
- width: int,
+def _get_fuyu_num_image_tokens(
+ image_height: int,
+ image_width: int,
) -> Tuple[int, int]:
"""
- calculate number of image tokens needed for a given image size
- The expected Fuyu image prompts is in format:
- (image_token * ncols + newline_token) * nrows
- args:
- image_size: Tuple[int, int] - (width, height) of the image
- returns:
- ncols: int - number of image tokens in x direction
- nrows: int - number of image tokens in y direction
- """
- ncol = math.ceil(width / 30)
- nrow = math.ceil(height / 30)
- return ncol, nrow
+ Calculate the number of image tokens needed for a given image size.
+ The expected Fuyu image prompts can be expressed as:
-def get_max_fuyu_image_feature_size():
+ .. code-block::
+ (image_token * ncols + newline_token) * nrows
- return _calculate_num_image_tokens(
- height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
- width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
- )
+ Args:
+        image_height: int - height of the image in pixels
+        image_width: int - width of the image in pixels
+
+ Returns:
+ ncols: int - number of image tokens in `x` direction
+ nrows: int - number of image tokens in `y` direction
+ """
+ ncols = math.ceil(image_width / 30)
+ nrows = math.ceil(image_height / 30)
+ return ncols, nrows
def get_max_fuyu_image_tokens(ctx: InputContext):
- ncol, nrow = get_max_fuyu_image_feature_size()
- return (ncol + 1) * nrow
-
-
-def dummy_seq_data_for_fuyu(ctx: InputContext, seq_len: int, num_images: int):
- ncol, nrow = get_max_fuyu_image_feature_size()
- image_feature_size = get_max_fuyu_image_tokens(ctx)
-
- image_token_ids = (
- array(VLLM_TOKEN_ID_ARRAY_TYPE, [_IMAGE_TOKEN_ID]) * ncol +
- array(VLLM_TOKEN_ID_ARRAY_TYPE, [_NEWLINE_TOKEN_ID])) * nrow
- token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, image_token_ids) * num_images
- token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
- [0]) * (seq_len - image_feature_size * num_images)
- return SequenceData(token_ids), {
- "image":
- consecutive_placeholder_ranges(num_items=num_images,
- item_size=image_feature_size)
- }
-
-
-def dummy_image_for_fuyu(
- num_images: int,
- *,
- image_width: int,
- image_height: int,
-):
- image = Image.new("RGB", (image_width, image_height), color=0)
- return {"image": image if num_images == 1 else [image] * num_images}
-
-
-def dummy_data_for_fuyu(ctx: InputContext, seq_len: int,
- mm_counts: Mapping[str, int]):
- num_images = mm_counts["image"]
- seq_data, ranges = dummy_seq_data_for_fuyu(ctx, seq_len, num_images)
- mm_data = dummy_image_for_fuyu(num_images,
- image_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
- image_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT)
- return DummyData(seq_data, mm_data, ranges)
-
-
-def _fuyu_image_preprocess(image_processor: FuyuImageProcessor,
- data: List[Image.Image]):
- image_encoding = image_processor.preprocess(data, return_tensors="pt")
- batch_images = torch.stack([img[0] for img in image_encoding["images"]
- ]).unsqueeze(1)
- image_unpadded_heights = torch.tensor(
- image_encoding["image_unpadded_heights"])
- image_unpadded_widths = torch.tensor(
- image_encoding["image_unpadded_widths"])
-
- batch_size = len(image_encoding["images"])
- image_present = torch.ones(batch_size, 1, 1)
- model_image_input = image_processor.preprocess_with_tokenizer_info(
- image_input=batch_images,
- image_present=image_present,
- image_unpadded_h=image_unpadded_heights,
- image_unpadded_w=image_unpadded_widths,
- image_placeholder_id=_IMAGE_TOKEN_ID,
- image_newline_id=_NEWLINE_TOKEN_ID,
- variable_sized=True,
+ ncols, nrows = _get_fuyu_num_image_tokens(
+ image_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
+ image_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
)
- return model_image_input
-
-
-def input_processor_for_fuyu(ctx: InputContext, inputs: DecoderOnlyInputs):
- multi_modal_data = inputs.get("multi_modal_data")
- if multi_modal_data is None or "image" not in multi_modal_data:
- return inputs
-
- model_config = ctx.model_config
- image_data = multi_modal_data["image"]
- new_multi_modal_data = {}
- image_list = image_data if isinstance(image_data, list) else [image_data]
-
- # process image data
- if is_list_of(image_list, Image.Image):
- # Fuyu's image_processor can also finish token padding
- image_processor: FuyuImageProcessor = cached_get_image_processor(
- model_config.model)
-
- model_image_input = _fuyu_image_preprocess(image_processor, image_data)
- image_patches = torch.cat([
- image_patch[0]
- for image_patch in model_image_input["image_patches"]
- ])
- new_multi_modal_data["image"] = image_patches
-
- elif is_list_of(image_list, torch.Tensor):
- raise NotImplementedError("Embeddings input is not supported yet")
- else:
- raise TypeError(f"Invalid image type: {type(image_data)}")
-
- # process prompts
- prompt = inputs.get("prompt")
- prompt_token_ids = inputs["prompt_token_ids"]
- tokenizer = cached_get_tokenizer(model_config.model)
- # dim0 is batch_size, dim1 is subseq_size which will always be 1
- image_input_ids: List[List[
- torch.Tensor]] = model_image_input["image_input_ids"]
- image_input_ids = image_input_ids[0][0].tolist()
-    bos_token = tokenizer.encode("<s>", add_special_tokens=False)[1:]
- boa_token = tokenizer.encode("\x04", add_special_tokens=False)[1:]
-
- new_prompt = prompt + "\x04"
- new_prompt_token_ids = image_input_ids + bos_token + prompt_token_ids[
- 1:] + boa_token
-
- return token_inputs(prompt=new_prompt,
- prompt_token_ids=new_prompt_token_ids,
- multi_modal_data=new_multi_modal_data)
-
-
-def input_mapper_for_fuyu(ctx: InputContext, data: object):
- model_config = ctx.model_config
- data_list = data if isinstance(data, list) else [data]
- if is_list_of(data_list, Image.Image):
- # Fuyu's image_processor can also finish token padding
- image_processor: FuyuImageProcessor = cached_get_image_processor(
- model_config.model)
-
- model_image_input = _fuyu_image_preprocess(image_processor, data_list)
- data = torch.stack([
- image_patch[0]
- for image_patch in model_image_input["image_patches"]
- ])
-
- # image has been processed with prompt in input processor
- return MultiModalKwargs({"pixel_values": data})
-
-
-@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_fuyu)
+
+ return (ncols + 1) * nrows
+
+
+class FuyuMultiModalProcessor(BaseMultiModalProcessor):
+
+ def _get_hf_processor(self) -> FuyuProcessor:
+ return self.ctx.get_hf_processor(FuyuProcessor)
+
+ def _call_hf_processor(
+ self,
+ prompt: str,
+ mm_data: Mapping[str, object],
+ mm_kwargs: Mapping[str, object],
+ ) -> BatchFeature:
+
+ if not mm_data:
+ # Avoid warning from HF logger for text-only input
+ # Input_ids format: bos_token_id + prompt_token_ids + boa_token_id
+ # Tokenizer won't add boa_token_id by default, we add it manually.
+ tokenizer = self._get_tokenizer()
+ boa_token_id: int = tokenizer.vocab["<0x04>"] # type: ignore
+ prompt_ids = tokenizer.encode(prompt) + [boa_token_id]
+ return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
+
+ processed_outputs = super()._call_hf_processor(
+ prompt=prompt,
+ mm_data=mm_data,
+ mm_kwargs=mm_kwargs,
+ )
+
+ image_patches = processed_outputs.get("image_patches")
+ if image_patches is not None:
+ images = mm_data["images"]
+ assert isinstance(images, list)
+
+ # Original output: (1, num_images, Pn, Px * Py * C)
+ # New output: (num_images, Pn, Px * Py * C)
+ assert (isinstance(image_patches, list)
+ and len(image_patches) == 1)
+ assert (isinstance(image_patches[0], torch.Tensor)
+ and len(image_patches[0]) == len(images))
+
+ processed_outputs["image_patches"] = image_patches[0]
+
+ return processed_outputs
+
+ def _get_mm_fields_config(
+ self,
+ hf_inputs: BatchFeature,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ ) -> Mapping[str, MultiModalFieldConfig]:
+ return dict(image_patches=MultiModalFieldConfig.batched("image"))
+
+ def _get_prompt_replacements(
+ self,
+ mm_items: MultiModalDataItems,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ out_mm_kwargs: MultiModalKwargs,
+ ) -> list[PromptReplacement]:
+ hf_config = self.ctx.get_hf_config(FuyuConfig)
+ bos_token_id = hf_config.bos_token_id
+
+ tokenizer = self._get_tokenizer()
+ eot_token_id = tokenizer.bos_token_id
+ assert isinstance(eot_token_id, int)
+
+ hf_processor = self._get_hf_processor()
+ image_processor: FuyuImageProcessor = hf_processor.image_processor
+ target_size = image_processor.size
+ target_height, target_width = (target_size["height"],
+ target_size["width"])
+
+ def get_replacement_fuyu(item_idx: int):
+ images = mm_items.get_items("image", ImageProcessorItems)
+ image_size = images.get_image_size(item_idx)
+ width, height = image_size.width, image_size.height
+ if not (width <= target_width and height <= target_height):
+ height_scale_factor = target_height / height
+ width_scale_factor = target_width / width
+ optimal_scale_factor = min(height_scale_factor,
+ width_scale_factor)
+
+ height = int(height * optimal_scale_factor)
+ width = int(width * optimal_scale_factor)
+
+ ncols, nrows = _get_fuyu_num_image_tokens(
+ image_width=width,
+ image_height=height,
+ )
+
+ return (([_IMAGE_TOKEN_ID] * ncols + [_NEWLINE_TOKEN_ID]) * nrows +
+ [bos_token_id])
+
+ return [
+ PromptReplacement(
+ modality="image",
+ target=[eot_token_id],
+ replacement=get_replacement_fuyu,
+ )
+ ]
+
+ def apply(
+ self,
+ prompt_text: str,
+ mm_data: MultiModalDataDict,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ ) -> MultiModalInputsV2:
+ result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs)
+
+ # Only |SPEAKER| (image) tokens should be considered as placeholders,
+ # so we ignore the trailing bos_token_id
+ result["mm_placeholders"] = {
+ modality: [
+ PlaceholderRange(offset=p["offset"], length=p["length"] - 1)
+ for p in ps
+ ]
+ for modality, ps in result["mm_placeholders"].items()
+ }
+
+ return result
+
+ def _get_dummy_mm_inputs(
+ self,
+ mm_counts: Mapping[str, int],
+ ) -> ProcessorInputs:
+ num_images = mm_counts.get("image", 0)
+
+ mm_data = {
+ "image":
+ self._get_dummy_images(width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
+ height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
+ num_images=num_images)
+ }
+
+ return ProcessorInputs(
+ prompt_text="",
+ mm_data=mm_data,
+ )
+
+
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_fuyu_image_tokens)
-@INPUT_REGISTRY.register_dummy_data(dummy_data_for_fuyu)
-@INPUT_REGISTRY.register_input_processor(input_processor_for_fuyu)
+@MULTIMODAL_REGISTRY.register_processor(FuyuMultiModalProcessor)
class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@@ -280,28 +293,32 @@ def _validate_shape(d: torch.Tensor):
return data.to(self.vision_embed_tokens.weight.dtype)
def _parse_and_validate_image_input(
- self, **kwargs: object) -> Optional[FuyuImagePixelInputs]:
- pixel_values = kwargs.pop("pixel_values", None)
-
- if pixel_values is not None:
- if not isinstance(pixel_values, (torch.Tensor, list)):
+ self, **kwargs: object) -> Optional[FuyuImagePatchInputs]:
+ image_patches = kwargs.pop("image_patches", None)
+ if image_patches is not None:
+ if not isinstance(image_patches, (torch.Tensor, list)):
raise ValueError("Incorrect type of image patches. "
- f"Got type: {type(pixel_values)}")
+ f"Got type: {type(image_patches)}")
- return FuyuImagePixelInputs(
- type="pixel_values",
+ image_patches_flat = flatten_bn(image_patches)
+
+ return FuyuImagePatchInputs(
+ type="image_patches",
data=self._validate_pixel_values(
- flatten_bn(pixel_values, concat=True)),
+ flatten_bn(image_patches_flat, concat=True)),
+ patches_per_image=[x.size(0) for x in image_patches_flat],
)
return None
def _process_image_input(
- self, image_input: FuyuImagePixelInputs) -> torch.Tensor:
+ self, image_input: FuyuImagePatchInputs) -> NestedTensors:
+ image_patches = image_input["data"]
+ patches_per_image = image_input["patches_per_image"]
assert self.vision_embed_tokens is not None
- vision_embeddings, _ = self.vision_embed_tokens(image_input["data"])
- return vision_embeddings
+ vision_embeddings, _ = self.vision_embed_tokens(image_patches)
+ return vision_embeddings.split(patches_per_image, dim=0)
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
image_input = self._parse_and_validate_image_input(**kwargs)
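
For reference, the Fuyu token count above works out as follows; the sketch reuses the patch's formula, with the 1920 maximum width shown above and an assumed maximum height of 1080:

```python
import math

def fuyu_image_tokens(image_height: int, image_width: int) -> int:
    ncols = math.ceil(image_width / 30)   # image tokens per row
    nrows = math.ceil(image_height / 30)  # rows of image tokens
    return (ncols + 1) * nrows            # one newline token per row

# ceil(1920 / 30) = 64 columns, ceil(1080 / 30) = 36 rows
assert fuyu_image_tokens(1080, 1920) == (64 + 1) * 36 == 2340
```
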
diff --git a/vllm/model_executor/models/idefics2_vision_model.py b/vllm/model_executor/models/idefics2_vision_model.py
index e430a158d869a..4e42a4b6f9e64 100644
--- a/vllm/model_executor/models/idefics2_vision_model.py
+++ b/vllm/model_executor/models/idefics2_vision_model.py
@@ -69,7 +69,8 @@ def forward(self,
patch_attention_mask: torch.BoolTensor,
tgt_sizes: Optional[torch.IntTensor] = None) -> torch.Tensor:
batch_size, _, max_im_h, max_im_w = pixel_values.shape
- patch_embeds = self.patch_embedding(pixel_values)
+ target_dtype = self.patch_embedding.weight.dtype
+ patch_embeds = self.patch_embedding(pixel_values.to(target_dtype))
embeddings = patch_embeds.flatten(2).transpose(1, 2)
max_nb_patches_h, max_nb_patches_w = (
max_im_h // self.patch_size,
@@ -309,7 +310,8 @@ def forward(
hidden_states = self.embeddings(
pixel_values=pixel_values,
patch_attention_mask=patch_attention_mask,
- tgt_sizes=tgt_sizes)
+ tgt_sizes=tgt_sizes,
+ )
encoder_outputs = self.encoder(hidden_states)
last_hidden_state = self.post_layernorm(encoder_outputs)
return last_hidden_state
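
The idefics2 change casts the input to the patch embedding's weight dtype so a float32 image tensor does not trip a dtype mismatch when the model runs in bfloat16. A toy reproduction with a plain `nn.Conv2d` standing in for the real patch embedding:

```python
import torch
import torch.nn as nn

patch_embedding = nn.Conv2d(3, 8, kernel_size=14, stride=14).to(torch.bfloat16)
pixel_values = torch.rand(1, 3, 28, 28)              # float32 by default
target_dtype = patch_embedding.weight.dtype
patch_embeds = patch_embedding(pixel_values.to(target_dtype))
assert patch_embeds.dtype == torch.bfloat16
```
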
diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py
index 1d6ee2a0be72e..34dc7fa31ce6f 100644
--- a/vllm/model_executor/models/llava.py
+++ b/vllm/model_executor/models/llava.py
@@ -144,8 +144,8 @@ def _call_hf_processor(
# Original output: (1, num_images, C, H, W)
# New output: (num_images, C, H, W)
assert (isinstance(pixel_values, list)
- and len(pixel_values) == 1
- and isinstance(pixel_values[0], list)
+ and len(pixel_values) == 1)
+ assert (isinstance(pixel_values[0], list)
and len(pixel_values[0]) == len(images))
processed_outputs["pixel_values"] = pixel_values[0]
diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py
index a39f2f4124d05..5e70c11363c83 100644
--- a/vllm/model_executor/models/llava_next.py
+++ b/vllm/model_executor/models/llava_next.py
@@ -528,10 +528,8 @@ def _process_image_pixels(
stacked_image_features = self._image_pixels_to_features(
self.vision_tower, stacked_pixel_values)
- return [
- self.multi_modal_projector(image_features) for image_features in
- torch.split(stacked_image_features, num_patches_per_batch)
- ]
+ return torch.split(self.multi_modal_projector(stacked_image_features),
+ num_patches_per_batch)
def _process_image_input(
self,
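
The llava_next change projects the stacked features once and splits afterwards instead of projecting each split separately; for a per-row projector the two orderings agree. A quick check with an `nn.Linear` as a stand-in projector (toy sizes):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
proj = nn.Linear(8, 4)                 # stand-in for multi_modal_projector
stacked = torch.randn(6, 8)            # e.g. 2 + 4 patches from two images
num_patches_per_batch = [2, 4]

old = [proj(chunk) for chunk in torch.split(stacked, num_patches_per_batch)]
new = torch.split(proj(stacked), num_patches_per_batch)

assert all(torch.allclose(a, b, atol=1e-6) for a, b in zip(old, new))
```
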
diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py
index 22d29f5bbc50c..2bce13792a88d 100644
--- a/vllm/model_executor/models/pixtral.py
+++ b/vllm/model_executor/models/pixtral.py
@@ -1,8 +1,8 @@
+import math
from dataclasses import dataclass, fields
from functools import cached_property
from typing import Iterable, List, Mapping, Optional, Set, Tuple, Union
-import numpy
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -306,7 +306,7 @@ def _parse_and_validate_image_input(
images: Optional[Union[List[List[torch.Tensor]], List[torch.Tensor],
torch.Tensor]] = None,
image_tokens: Optional[torch.Tensor] = None,
- ) -> Optional[List[torch.Tensor]]:
+ ) -> Tuple[Optional[List[torch.Tensor]], Optional[torch.Tensor]]:
if images is None:
return None, None
@@ -604,11 +604,11 @@ def max_patches_per_side(self) -> int:
return self.args.image_size // self.args.patch_size
@property
- def device(self) -> torch.device:
+ def device(self) -> torch.types.Device:
return next(self.parameters()).device
@property
- def dtype(self) -> torch.device:
+ def dtype(self) -> torch.dtype:
return next(self.parameters()).dtype
@property
@@ -741,8 +741,8 @@ def get_pixtral_hf_image_feature_size(hf_config: PixtralVisionConfig,
ratio = max(image_width / max_width, image_height / max_height)
if ratio > 1:
- image_width = int(numpy.ceil(image_width / ratio))
- image_height = int(numpy.ceil(image_height / ratio))
+ image_width = int(math.ceil(image_width / ratio))
+ image_height = int(math.ceil(image_height / ratio))
num_height_tokens, num_width_tokens = _get_pixtral_hf_num_image_tokens(
(image_height, image_width),
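
The pixtral change swaps `numpy.ceil` for `math.ceil` in the downscaling step; the behavior is unchanged for positive ratios. A small sketch of that step with hypothetical sizes:

```python
import math

def clamp_to_max(image_width, image_height, max_width, max_height):
    ratio = max(image_width / max_width, image_height / max_height)
    if ratio > 1:
        image_width = int(math.ceil(image_width / ratio))
        image_height = int(math.ceil(image_height / ratio))
    return image_width, image_height

# a 2048x1024 image with a hypothetical 1024x1024 maximum is halved
assert clamp_to_max(2048, 1024, 1024, 1024) == (1024, 512)
```
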
diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py
index e3d43b017f894..de55bc6bcc123 100644
--- a/vllm/model_executor/models/qwen2_audio.py
+++ b/vllm/model_executor/models/qwen2_audio.py
@@ -23,7 +23,6 @@
from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
Union)
-import numpy as np
import torch
import torch.nn as nn
from transformers import BatchFeature
@@ -177,16 +176,19 @@ def _get_dummy_mm_inputs(
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
feature_extractor = self._get_feature_extractor()
+
sampling_rate = feature_extractor.sampling_rate
audio_len = feature_extractor.chunk_length * sampling_rate
+ num_audios = mm_counts.get("audio", 0)
- audio_count = mm_counts.get("audio", 0)
- audio = np.zeros(audio_len)
- data = {"audio": [audio] * audio_count}
+ mm_data = {
+ "audio":
+ self._get_dummy_audios(length=audio_len, num_audios=num_audios)
+ }
return ProcessorInputs(
- prompt_text="<|AUDIO|>" * audio_count,
- mm_data=data,
+ prompt_text="<|AUDIO|>" * num_audios,
+ mm_data=mm_data,
)
diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index 6181fe3dd13d8..1e485f87bb7a4 100644
--- a/vllm/model_executor/models/qwen2_vl.py
+++ b/vllm/model_executor/models/qwen2_vl.py
@@ -29,7 +29,6 @@
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
-from PIL import Image
from transformers import BatchFeature
from transformers.models.qwen2_vl import (Qwen2VLImageProcessor,
Qwen2VLProcessor)
@@ -882,12 +881,10 @@ def _get_dummy_mm_inputs(
self,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
- num_images = mm_counts.get("image", 0)
hf_processor = self._get_hf_processor()
- image_token: str = hf_processor.image_token
image_processor = _get_image_processor(hf_processor)
- data = {}
+ image_token: str = hf_processor.image_token
resized_height, resized_width = smart_resize(
height=9999999,
width=9999999,
@@ -895,14 +892,18 @@ def _get_dummy_mm_inputs(
min_pixels=image_processor.min_pixels,
max_pixels=image_processor.max_pixels,
)
+ num_images = mm_counts.get("image", 0)
- dummy_image = Image.new("RGB", (resized_width, resized_height),
- color=0)
- data["image"] = [dummy_image] * num_images
+ mm_data = {
+ "image":
+ self._get_dummy_images(width=resized_width,
+ height=resized_height,
+ num_images=num_images)
+ }
return ProcessorInputs(
prompt_text=image_token * num_images,
- mm_data=data,
+ mm_data=mm_data,
)
diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py
index 7e853e5b90096..54be7fed3f2be 100644
--- a/vllm/model_executor/models/ultravox.py
+++ b/vllm/model_executor/models/ultravox.py
@@ -188,16 +188,19 @@ def _get_dummy_mm_inputs(
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
feature_extractor = self._get_feature_extractor()
+
sampling_rate = feature_extractor.sampling_rate
audio_len = feature_extractor.chunk_length * sampling_rate
+ num_audios = mm_counts.get("audio", 0)
- audio_count = mm_counts.get("audio", 0)
- audio = np.zeros(audio_len)
- data = {"audio": [audio] * audio_count}
+ mm_data = {
+ "audio":
+ self._get_dummy_audios(length=audio_len, num_audios=num_audios)
+ }
return ProcessorInputs(
- prompt_text="<|audio|>" * audio_count,
- mm_data=data,
+ prompt_text="<|audio|>" * num_audios,
+ mm_data=mm_data,
)
diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py
index 180489166b407..7712c3bcebe20 100644
--- a/vllm/multimodal/processing.py
+++ b/vllm/multimodal/processing.py
@@ -1,15 +1,17 @@
import pickle
import re
from abc import ABC, abstractmethod
+from collections import defaultdict
from collections.abc import Callable, ItemsView, Iterable, Mapping, Sequence
from dataclasses import dataclass, field
from functools import lru_cache
from typing import Any, NamedTuple, Optional, Protocol, TypeVar, Union
import numpy as np
+import numpy.typing as npt
import torch
from blake3 import blake3
-from PIL.Image import Image
+from PIL import Image
from transformers import BatchFeature, ProcessorMixin
from vllm.inputs import DummyData, InputProcessingContext
@@ -353,13 +355,13 @@ def _replace_matches(
) -> list[_S]:
out_seqs = list[_S]()
prev_end_idx = 0
- next_idx_by_modality = {modality: 0 for modality in mm_item_counts}
+ next_idx_by_modality = defaultdict[str, int](lambda: 0)
for match in _resolve_matches(prompt, matches):
modality = match.modality
item_idx = next_idx_by_modality[modality]
- if item_idx >= mm_item_counts[modality]:
+ if item_idx >= mm_item_counts.get(modality, 0):
continue
start_idx = match.start_idx
@@ -513,7 +515,7 @@ def _serialize_item(self, obj: object) -> bytes:
return obj.encode("utf-8")
if isinstance(obj, bytes):
return obj
- if isinstance(obj, Image):
+ if isinstance(obj, Image.Image):
return obj.tobytes()
# Convertible to NumPy arrays
@@ -673,10 +675,14 @@ def _get_prompt_replacements(
Given the original multi-modal items for this modality
and HF-processed data, output the replacements to perform.
- Note:
- Even when the HF processor already performs replacement for us,
- we still use this replacement information to determine
- the placeholder token positions for each multi-modal item.
+ Notes:
+ - You should not assume that HF processor always performs prompt
+ replacement: in :meth:`_apply_hf_processor_missing`, this method
+ is called on text-only and multimodal-only inputs separately,
+ instead of passing them in the same call.
+ - The replacement information returned by this method is also used
+ to determine the placeholder token positions for each multi-modal
+ item.
"""
raise NotImplementedError
@@ -710,6 +716,10 @@ def _call_hf_processor(
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
) -> BatchFeature:
+ """
+ Call the HF processor on the prompt text and
+ associated multi-modal data.
+ """
return self.ctx.call_hf_processor(
self._get_hf_processor(**mm_kwargs),
dict(text=prompt, **mm_data),
@@ -723,7 +733,8 @@ def _apply_hf_processor(
hf_processor_mm_kwargs: Mapping[str, object],
) -> tuple[list[int], MultiModalKwargs]:
"""
- Apply the HF processor on the full prompt text and multi-modal data.
+ Wrapper of :meth:`_call_hf_processor` that applies
+ additional pre-processing and post-processing.
"""
processor_data, passthrough_data = self._get_hf_mm_data(mm_items)
@@ -754,10 +765,11 @@ def _apply_hf_processor_missing(
Apply the HF processor on the full prompt text, but only on the
multi-modal data that are missing from the cache.
- Note: We pass prompt text and multi-modal data into the HF processor
- in separate calls to avoid HF prompt replacement being done for
- cached items; instead, we rely on our own prompt replacement logic
- for the full text.
+ Note:
+ We pass prompt text and multi-modal data into the HF processor
+ in separate calls to avoid HF prompt replacement being done for
+ cached items; instead, we rely on our own prompt replacement logic
+ (:meth:`_get_prompt_replacements`) for the full text.
"""
mm_missing_counts = mm_missing_data_items.get_all_counts()
@@ -1010,6 +1022,36 @@ def apply(
mm_placeholders=mm_placeholders,
)
+ def _get_dummy_audios(
+ self,
+ *,
+ length: int,
+ num_audios: int,
+ ) -> list[npt.NDArray]:
+ audio = np.zeros((length, ))
+ return [audio] * num_audios
+
+ def _get_dummy_images(
+ self,
+ *,
+ width: int,
+ height: int,
+ num_images: int,
+ ) -> list[Image.Image]:
+ image = Image.new("RGB", (width, height), color=0)
+ return [image] * num_images
+
+ def _get_dummy_videos(
+ self,
+ *,
+ width: int,
+ height: int,
+ num_frames: int,
+ num_videos: int,
+ ) -> list[npt.NDArray]:
+ video = np.zeros((num_frames, width, height, 3))
+ return [video] * num_videos
+
@abstractmethod
def _get_dummy_mm_inputs(
self,
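
The new `_get_dummy_audios`/`_get_dummy_images`/`_get_dummy_videos` helpers centralize dummy-data construction for processors. A standalone sketch of the objects they produce (sizes are arbitrary examples):

```python
import numpy as np
from PIL import Image

dummy_audios = [np.zeros((16000, ))] * 2                    # 1 s of silence at 16 kHz
dummy_images = [Image.new("RGB", (336, 336), color=0)] * 2  # black RGB images
dummy_videos = [np.zeros((8, 336, 336, 3))] * 1             # 8 black frames

assert dummy_images[0].size == (336, 336)
assert dummy_videos[0].shape == (8, 336, 336, 3)
```
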
diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py
index 87b12a6fb33c1..7b6ded6a27084 100644
--- a/vllm/multimodal/utils.py
+++ b/vllm/multimodal/utils.py
@@ -400,15 +400,19 @@ def repeat_and_pad_placeholder_tokens(
placeholder_token_idx = 0
for i, token in enumerate(prompt_token_ids):
if token == placeholder_token_id:
+ curr_repeat_count = repeat_count[placeholder_token_idx]
replacement_ids = repeat_and_pad_token(
placeholder_token_id,
- repeat_count=repeat_count[placeholder_token_idx],
+ repeat_count=curr_repeat_count,
pad_token_left=pad_token_left,
pad_token_right=pad_token_right,
)
+ offset = len(new_token_ids)
+ if pad_token_left is not None:
+ offset += 1
placeholder_ranges.append({
- "offset": len(new_token_ids),
- "length": len(replacement_ids)
+ "offset": offset,
+ "length": curr_repeat_count,
})
new_token_ids.extend(replacement_ids)
placeholder_token_idx += 1
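
The fix above makes the recorded placeholder start after an optional left pad token and sets its length to the repeat count rather than the full replacement length. A standalone sketch of that bookkeeping, using the Chameleon image-start token id (8197) from the constants removed earlier as the left pad:

```python
def placeholder_range(prefix_len, repeat_count, pad_token_left=None):
    offset = prefix_len
    if pad_token_left is not None:
        offset += 1  # the left pad token is not part of the placeholder
    return {"offset": offset, "length": repeat_count}

# 3 prompt tokens, then [start_pad] + 1024 image tokens + [end_pad]
assert placeholder_range(3, 1024, pad_token_left=8197) == {
    "offset": 4,
    "length": 1024,
}
```
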
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 509771b7e2e5a..a08a86d4007dc 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -647,10 +647,23 @@ def profile_run(self) -> None:
self.mm_registry.get_max_tokens_per_item_by_modality(
self.model_config).values())
- max_num_mm_items = min(
+ max_num_mm_items_encoder_budget = min(
self.max_num_encoder_input_tokens,
self.encoder_cache_size) // max_tokens_per_mm_item
+ max_mm_items_per_req = max(
+ self.mm_registry.get_mm_limits_per_prompt(
+ self.model_config).values())
+
+        # NOTE: We do not consider max_num_batched_tokens on purpose
+        # because the multimodal embeddings can be generated in advance
+        # and consumed during chunked prefill.
+ max_num_mm_items_decoder_budget = self.max_num_reqs * \
+ max_mm_items_per_req
+
+ max_num_mm_items = min(max_num_mm_items_encoder_budget,
+ max_num_mm_items_decoder_budget)
+
# Dummy data definition in V0 may contain multiple multimodal items
# (e.g, multiple images) for a single request, therefore here we
# always replicate first item by max_num_mm_items times since in V1
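
The profiling budget above is the minimum of an encoder-side budget (cache and encoder-token limits divided by the per-item token cost) and a decoder-side budget (requests times items per request). A worked example with hypothetical numbers:

```python
# All values below are made up; only the formula comes from the patch.
max_tokens_per_mm_item = 2340
max_num_encoder_input_tokens = 16384
encoder_cache_size = 16384
max_num_reqs = 4
max_mm_items_per_req = 1

max_num_mm_items_encoder_budget = min(
    max_num_encoder_input_tokens, encoder_cache_size) // max_tokens_per_mm_item
max_num_mm_items_decoder_budget = max_num_reqs * max_mm_items_per_req

max_num_mm_items = min(max_num_mm_items_encoder_budget,
                       max_num_mm_items_decoder_budget)
assert max_num_mm_items == min(7, 4) == 4
```
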