diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index 613343281464c..f74c201bdff6b 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -570,28 +570,28 @@ See [this page](#generative-models) for more information on how to use generativ - `rhymes-ai/Aria` - - ✅︎ - - + - ✅︎ * - `Blip2ForConditionalGeneration` - BLIP-2 - T + IE - `Salesforce/blip2-opt-2.7b`, `Salesforce/blip2-opt-6.7b`, etc. - - ✅︎ - - + - ✅︎ * - `ChameleonForConditionalGeneration` - Chameleon - T + I - `facebook/chameleon-7b` etc. - - ✅︎ - - + - ✅︎ * - `FuyuForCausalLM` - Fuyu - T + I - `adept/fuyu-8b` etc. - - ✅︎ - - + - ✅︎ * - `ChatGLMModel` - GLM-4V - T + I @@ -633,7 +633,7 @@ See [this page](#generative-models) for more information on how to use generativ - `llava-hf/llava-v1.6-mistral-7b-hf`, `llava-hf/llava-v1.6-vicuna-7b-hf`, etc. - - ✅︎ - - + - ✅︎ * - `LlavaNextVideoForConditionalGeneration` - LLaVA-NeXT-Video - T + V diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index 77af914a6ef02..b51bfae455267 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -24,10 +24,13 @@ def run_aria(question: str, modality: str): assert modality == "image" model_name = "rhymes-ai/Aria" + # NOTE: Need L40 (or equivalent) to avoid OOM llm = LLM(model=model_name, tokenizer_mode="slow", - trust_remote_code=True, dtype="bfloat16", + max_model_len=4096, + max_num_seqs=2, + trust_remote_code=True, disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache) prompt = (f"<|im_start|>user\n<|img|>\n{question}" @@ -57,6 +60,7 @@ def run_chameleon(question: str, modality: str): prompt = f"{question}" llm = LLM(model="facebook/chameleon-7b", max_model_len=4096, + max_num_seqs=2, disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache) stop_token_ids = None return llm, prompt, stop_token_ids @@ -257,7 +261,7 @@ def run_minicpmv(question: str, modality: str): # 2.5 # model_name = "openbmb/MiniCPM-Llama3-V-2_5" - #2.6 + # 2.6 model_name = "openbmb/MiniCPM-V-2_6" tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) @@ -430,9 +434,11 @@ def run_pixtral_hf(question: str, modality: str): model_name = "mistral-community/pixtral-12b" + # NOTE: Need L40 (or equivalent) to avoid OOM llm = LLM( model=model_name, max_model_len=8192, + max_num_seqs=2, disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache, ) diff --git a/tests/models/decoder_only/vision_language/test_models.py b/tests/models/decoder_only/vision_language/test_models.py index 1a9c1b4ef1be0..7db08166826eb 100644 --- a/tests/models/decoder_only/vision_language/test_models.py +++ b/tests/models/decoder_only/vision_language/test_models.py @@ -140,10 +140,7 @@ "aria": VLMTestInfo( models=["rhymes-ai/Aria"], tokenizer_mode="slow", - test_type=( - VLMTestType.IMAGE, - VLMTestType.MULTI_IMAGE, - ), + test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), dtype="bfloat16", prompt_formatter=lambda img_prompt: f"<|im_start|>user\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n ", # noqa: E501 img_idx_to_prompt=lambda idx: "<|img|>\n", @@ -179,6 +176,7 @@ test_type=VLMTestType.IMAGE, prompt_formatter=lambda img_prompt: f"USER: {img_prompt}\nASSISTANT:", max_model_len=4096, + max_num_seqs=2, auto_cls=AutoModelForVision2Seq, postprocess_inputs=model_utils.cast_dtype_post_processor( "pixel_values" @@ -201,7 +199,6 @@ 
vllm_output_post_proc=model_utils.fuyu_vllm_to_hf_output, num_logprobs=10, image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)], - marks=[large_gpu_mark(min_gb=48)], ), "glm4": VLMTestInfo( models=["THUDM/glm-4v-9b"], diff --git a/tests/multimodal/test_processing.py b/tests/multimodal/test_processing.py index 1b2847ed0f534..81278cde264ff 100644 --- a/tests/multimodal/test_processing.py +++ b/tests/multimodal/test_processing.py @@ -528,7 +528,7 @@ def _rand_audio( def _test_processing_cache_correctness( model_id: str, - modalities: set[str], + modalities: dict[str, bool], hit_rate: float, num_batches: int, simplify_rate: float, @@ -583,9 +583,8 @@ def _test_processing_cache_correctness( partial(_rand_audio, rng, min_len=256, max_len=512, sr=16000), } input_max_count = { - "image": 3, - "video": 3, - "audio": 3, + modality: 3 if supports_multi else 1 + for modality, supports_multi in modalities.items() } for batch_idx in range(num_batches): @@ -624,12 +623,16 @@ def _test_processing_cache_correctness( # yapf: disable @pytest.mark.parametrize(("model_id", "modalities"), [ - ("llava-hf/llava-1.5-7b-hf", {"image"}), - ("TIGER-Lab/Mantis-8B-siglip-llama3", {"image"}), - ("mistral-community/pixtral-12b", {"image"}), - ("Qwen/Qwen2-VL-2B-Instruct", {"image", "video"}), - ("Qwen/Qwen2-Audio-7B-Instruct", {"audio"}), - ("fixie-ai/ultravox-v0_3", {"audio"}), + ("rhymes-ai/Aria", {"image": True}), + ("Salesforce/blip2-opt-2.7b", {"image": False}), + ("facebook/chameleon-7b", {"image": True}), + ("adept/fuyu-8b", {"image": False}), + ("llava-hf/llava-1.5-7b-hf", {"image": True}), + ("TIGER-Lab/Mantis-8B-siglip-llama3", {"image": True}), + ("mistral-community/pixtral-12b", {"image": True}), + ("Qwen/Qwen2-VL-2B-Instruct", {"image": True, "video": True}), + ("Qwen/Qwen2-Audio-7B-Instruct", {"audio": True}), + ("fixie-ai/ultravox-v0_3", {"audio": True}), ]) @pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0]) @pytest.mark.parametrize("num_batches", [32]) @@ -637,7 +640,7 @@ def _test_processing_cache_correctness( # yapf: enable def test_processing_cache_correctness( model_id: str, - modalities: set[str], + modalities: dict[str, bool], hit_rate: float, num_batches: int, simplify_rate: float, @@ -653,7 +656,7 @@ def test_processing_cache_correctness( # yapf: disable @pytest.mark.parametrize(("model_id", "modalities"), [ - ("microsoft/Phi-3-vision-128k-instruct", {"image"}), + ("microsoft/Phi-3-vision-128k-instruct", {"image": True}), ]) @pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0]) @pytest.mark.parametrize("num_batches", [32]) @@ -661,7 +664,7 @@ def test_processing_cache_correctness( # yapf: enable def test_processing_cache_correctness_phi3v( model_id: str, - modalities: set[str], + modalities: dict[str, bool], hit_rate: float, num_batches: int, simplify_rate: float, diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py index 9437ad9688422..4ad6e859f4d93 100644 --- a/vllm/model_executor/models/aria.py +++ b/vllm/model_executor/models/aria.py @@ -1,15 +1,15 @@ -import math -from typing import Iterable, List, Optional, Set, Tuple, TypedDict, Union +from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict, + Union) import torch import torch.nn as nn from torch.nn.init import trunc_normal_ -from transformers import LlamaConfig +from transformers import BatchFeature, PretrainedConfig from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, QuantizationConfig, VllmConfig from vllm.distributed import 
get_tensor_model_parallel_rank -from vllm.inputs import INPUT_REGISTRY, token_inputs +from vllm.inputs import InputContext from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -17,30 +17,27 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( get_compressed_tensors_cache_scale) -from vllm.model_executor.layers.sampler import (Sampler, SamplerOutput, - SamplingMetadata) +from vllm.model_executor.layers.sampler import (SamplerOutput, + SamplingMetadata, get_sampler) from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.models.idefics2_vision_model import ( - Idefics2VisionTransformer) -from vllm.model_executor.models.interfaces import SupportsMultiModal -from vllm.model_executor.models.llama import (LlamaDecoderLayer, LlamaMLP, - LlamaModel) -from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper, - is_pp_missing_parameter, - maybe_prefix, - merge_multimodal_embeddings) from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.image import cached_get_image_processor -from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors -from vllm.multimodal.utils import (cached_get_tokenizer, - repeat_and_pad_placeholder_tokens) +from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, + NestedTensors) +from vllm.multimodal.processing import (BaseMultiModalProcessor, + MultiModalDataItems, ProcessorInputs, + PromptReplacement) from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.aria import (AriaMoELMConfig, AriaVisionConfig) -from .utils import flatten_bn +from .idefics2_vision_model import Idefics2VisionTransformer +from .interfaces import SupportsMultiModal +from .llama import LlamaDecoderLayer, LlamaMLP, LlamaModel +from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, + is_pp_missing_parameter, maybe_prefix, + merge_multimodal_embeddings) class AriaImagePixelInputs(TypedDict): @@ -251,7 +248,7 @@ def forward(self, x, attn_mask=None): class AriaFusedMoE(FusedMoE): def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, - shard_id: str) -> Set[str]: + shard_id: str) -> None: # Override the weight_loader to handle the expert weights in the Aria # model, which are already packed with experts, and merge the gate and # up weights for each expert. 
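Note on the `tests/multimodal/test_processing.py` hunk above: `modalities` changes from a `set[str]` to a `dict[str, bool]`, where the value records whether the model under test accepts more than one item of that modality. A minimal sketch of how that flag becomes the per-modality item budget used by the test, mirroring the `3 if supports_multi else 1` comprehension in the hunk; the example parametrizations below are taken from the new test list, everything else is illustrative:

```python
# Sketch: derive the per-modality item limit from the multi-item flag,
# matching the `3 if supports_multi else 1` rule in the test hunk above.
def max_items_per_modality(modalities: dict[str, bool]) -> dict[str, int]:
    return {
        modality: 3 if supports_multi else 1
        for modality, supports_multi in modalities.items()
    }

# From the new parametrization: BLIP-2 and Fuyu are flagged single-image,
# while Aria and Qwen2-VL accept multiple items per prompt.
assert max_items_per_modality({"image": False}) == {"image": 1}
assert max_items_per_modality({"image": True, "video": True}) == {
    "image": 3,
    "video": 3,
}
```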
@@ -346,7 +343,7 @@ class MoEDecoderLayer(LlamaDecoderLayer): def __init__( self, - config: LlamaConfig, + config: AriaMoELMConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", @@ -434,7 +431,7 @@ def load_weights(self, weights: Iterable[Tuple[str, return loaded_params -def build_mm_projector(config): +def build_mm_projector(config: PretrainedConfig): return AriaProjector( patch_to_query_dict=config.projector_patch_to_query_dict, embed_dim=config.vision_config.hidden_size, @@ -445,75 +442,70 @@ def build_mm_projector(config): ) -def get_max_multimodal_tokens(ctx): - return max(ctx.model_config.hf_config.image_size2tokens.values()) - - -def input_mapper_for_aria(ctx, data): - return MultiModalKwargs(data) +def get_max_aria_image_tokens(ctx: InputContext): + hf_config = ctx.get_hf_config() + return max(hf_config.projector_patch_to_query_dict.values()) -def input_processor(ctx, llm_inputs): - multi_modal_data = llm_inputs.get("multi_modal_data") - # if it is pure text input, use it as is - if multi_modal_data is None or "image" not in multi_modal_data: - return llm_inputs +class AriaMultiModalProcessor(BaseMultiModalProcessor): - model_config = ctx.model_config - - tokenizer = cached_get_tokenizer(model_config.tokenizer) - image_processor = cached_get_image_processor( - model_config.model, trust_remote_code=model_config.trust_remote_code) - hf_config = model_config.hf_config + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict( + pixel_values=MultiModalFieldConfig.batched("image"), + pixel_mask=MultiModalFieldConfig.batched("image"), + ) - # prepare image tokens, the max_image_size is used to determine the number - # of patch_size for every image - max_image_size = multi_modal_data.pop("max_image_size", 980) - _split_image = multi_modal_data.pop("split_image", False) + def _get_prompt_replacements( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> list[PromptReplacement]: + hf_config = self.ctx.get_hf_config() + image_token_id = hf_config.image_token_index + + max_image_tokens = get_max_aria_image_tokens(self.ctx) + + return [ + PromptReplacement( + modality="image", + target=[image_token_id], + replacement=[image_token_id] * max_image_tokens, + ) + ] - assert isinstance(max_image_size, - (int, float)), "max_image_size should be float or int" - images = (multi_modal_data["image"] if isinstance( - multi_modal_data["image"], list) else [multi_modal_data["image"]]) + def _get_dummy_mm_inputs( + self, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + hf_config = self.ctx.get_hf_config() + vision_config: AriaVisionConfig = hf_config.vision_config + + max_image_size = vision_config.image_size + num_images = mm_counts.get("image", 0) + + mm_data = { + "image": + self._get_dummy_images(width=max_image_size, + height=max_image_size, + num_images=num_images) + } - image_inputs = image_processor.preprocess(images, - max_image_size=max_image_size, - split_image=_split_image, - return_tensors="pt").data - image_inputs['pixel_values'] = image_inputs['pixel_values'].to( - ctx.model_config.dtype) - num_crops = image_inputs.pop("num_crops") + hf_processor = self._get_hf_processor() + image_token: str = hf_processor.image_token # type: ignore - prompt_token_ids = llm_inputs["prompt_token_ids"] - if num_crops.sum().item() > 0: - _, 
prompt_token_ids, _ = repeat_and_pad_placeholder_tokens( - tokenizer, - None, - prompt_token_ids, - placeholder_token_id=hf_config.image_token_index, - repeat_count=num_crops, + return ProcessorInputs( + prompt_text=image_token * num_images, + mm_data=mm_data, ) - repeat_count = [hf_config.image_size2tokens[max_image_size] - ] * sum(num_crops).item() - new_prompt, new_token_ids, _ = repeat_and_pad_placeholder_tokens( - tokenizer, - None, - prompt_token_ids, - placeholder_token_id=hf_config.image_token_index, - repeat_count=repeat_count, - ) - - return token_inputs( - prompt_token_ids=new_token_ids, - prompt=new_prompt, - multi_modal_data={"image": image_inputs}, - ) - -@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_multimodal_tokens) -@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_aria) -@INPUT_REGISTRY.register_input_processor(input_processor) +@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_aria_image_tokens) +@MULTIMODAL_REGISTRY.register_processor(AriaMultiModalProcessor) class AriaForConditionalGeneration(nn.Module, SupportsMultiModal): """ Aria model for conditional generation tasks. @@ -540,12 +532,6 @@ def __init__( config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - # prepare the image_size to tokens mapping for the image preprocess, see - # input_processor - config.image_size2tokens = { - int(math.sqrt(k) * config.vision_config.patch_size): v - for k, v in config.projector_patch_to_query_dict.items() - } self.config = config self.vision_tower = AriaVisionModel(config.vision_config) self.multi_modal_projector = build_mm_projector(config) @@ -566,7 +552,7 @@ def __init__( logit_scale = getattr(config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.vocab_size, logit_scale) - self.sampler = Sampler() + self.sampler = get_sampler() def _validate_image_sizes( self, images: List[torch.Tensor]) -> List[torch.Tensor]: @@ -588,7 +574,12 @@ def _parse_and_validate_image_input( pixel_values = self._validate_image_sizes(pixel_values) pixel_values = flatten_bn(pixel_values, concat=True) + if pixel_mask is not None: + if not isinstance(pixel_mask, (torch.Tensor, list)): + raise ValueError("Incorrect type of pixel mask. 
" + f"Got type: {type(pixel_mask)}") + pixel_mask = flatten_bn(pixel_mask, concat=True) return AriaImagePixelInputs( diff --git a/vllm/model_executor/models/blip.py b/vllm/model_executor/models/blip.py index 42a239cadac46..987dfaf44f228 100644 --- a/vllm/model_executor/models/blip.py +++ b/vllm/model_executor/models/blip.py @@ -4,22 +4,16 @@ import torch import torch.nn as nn -from PIL import Image from transformers import Blip2VisionConfig, BlipVisionConfig from vllm.attention.layer import MultiHeadAttention -from vllm.config import ModelConfig from vllm.distributed import divide, get_tensor_model_parallel_world_size -from vllm.inputs import DecoderOnlyInputs, token_inputs from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.multimodal.utils import (cached_get_tokenizer, - repeat_and_pad_placeholder_tokens) -from vllm.sequence import SequenceData def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int: @@ -33,92 +27,6 @@ def get_blip_num_patches(*, image_size: int, patch_size: int) -> int: return grid_length * grid_length -def get_blip_image_feature_size( - hf_config: Union[BlipVisionConfig, Blip2VisionConfig]) -> int: - return get_blip_num_patches(image_size=hf_config.image_size, - patch_size=hf_config.patch_size) - - -def get_max_blip_image_tokens( - hf_config: Union[BlipVisionConfig, Blip2VisionConfig]) -> int: - return get_blip_image_feature_size(hf_config) - - -def dummy_seq_data_for_blip( - hf_config: Union[BlipVisionConfig, Blip2VisionConfig], - seq_len: int, - num_images: int, - *, - image_token_id: int, - image_feature_size_override: Optional[int] = None, -): - if image_feature_size_override is None: - image_feature_size = get_blip_image_feature_size(hf_config) - else: - image_feature_size = image_feature_size_override - - return SequenceData.from_prompt_token_counts( - (image_token_id, image_feature_size * num_images), - (0, seq_len - image_feature_size * num_images), - ) - - -def dummy_image_for_blip( - hf_config: Union[BlipVisionConfig, Blip2VisionConfig], - num_images: int, - *, - image_width_override: Optional[int] = None, - image_height_override: Optional[int] = None, -): - width = height = hf_config.image_size - if image_width_override is not None: - width = image_width_override - if image_height_override is not None: - height = image_height_override - - image = Image.new("RGB", (width, height), color=0) - return {"image": image if num_images == 1 else [image] * num_images} - - -def input_processor_for_blip( - model_config: ModelConfig, - hf_config: Union[BlipVisionConfig, Blip2VisionConfig], - inputs: DecoderOnlyInputs, - *, - image_token_id: int, - image_feature_size_override: Optional[int] = None, -): - multi_modal_data = inputs.get("multi_modal_data") - if multi_modal_data is None or "image" not in multi_modal_data: - return inputs - - if "multi_modal_placeholders" in inputs and "image" in inputs[ - "multi_modal_placeholders"]: - # The inputs already have placeholders. 
- return inputs - - tokenizer = cached_get_tokenizer(model_config.tokenizer) - - if image_feature_size_override is None: - image_feature_size = get_blip_image_feature_size(hf_config) - else: - image_feature_size = image_feature_size_override - - new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens( - tokenizer, - inputs.get("prompt"), - inputs["prompt_token_ids"], - placeholder_token_id=image_token_id, - repeat_count=image_feature_size, - ) - - # NOTE: Create a defensive copy of the original inputs - return token_inputs(prompt_token_ids=new_token_ids, - prompt=new_prompt, - multi_modal_data=multi_modal_data, - multi_modal_placeholders={"image": ranges}) - - # Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/blip/modeling_blip.py#L164 # noqa class BlipVisionEmbeddings(nn.Module): diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index 76b8505ee1c2a..bf70f5d904f5b 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -4,32 +4,33 @@ import torch import torch.nn as nn -from transformers import (Blip2Config, Blip2QFormerConfig, Blip2VisionConfig, - apply_chunking_to_forward) +from transformers import (BatchFeature, Blip2Config, Blip2Processor, + Blip2QFormerConfig, apply_chunking_to_forward) from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, VllmConfig -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, - InputContext, token_inputs) +from vllm.inputs import InputContext from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import NestedTensors -from vllm.multimodal.utils import consecutive_placeholder_ranges -from vllm.sequence import IntermediateTensors, SequenceData - -from .blip import (BlipVisionModel, dummy_image_for_blip, - get_max_blip_image_tokens) +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalInputsV2, MultiModalKwargs, + NestedTensors, PlaceholderRange) +from vllm.multimodal.processing import (BaseMultiModalProcessor, + MultiModalDataItems, ProcessorInputs, + PromptReplacement) +from vllm.sequence import IntermediateTensors + +from .blip import BlipVisionModel from .interfaces import SupportsMultiModal, SupportsPP from .utils import (AutoWeightsLoader, init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings) # We use this internally as placeholders since there is no image token # defined on the HuggingFace repo -BLIP2_IMAGE_TOKEN = "" -BLIP2_IMAGE_TOKEN_ID = 50265 +_IMAGE_TOKEN_ID = 50265 class Blip2ImagePixelInputs(TypedDict): @@ -396,92 +397,87 @@ def forward( return sequence_output -def get_blip2_image_feature_size(hf_config: Blip2Config) -> int: - return hf_config.num_query_tokens - - def get_max_blip2_image_tokens(ctx: InputContext): hf_config = ctx.get_hf_config(Blip2Config) - vision_config = hf_config.vision_config - - if isinstance(vision_config, Blip2VisionConfig): - return get_max_blip_image_tokens(vision_config) - - msg = f"Unsupported vision config: {type(vision_config)}" - raise NotImplementedError(msg) - - -def dummy_seq_data_for_blip2( - hf_config: Blip2Config, - seq_len: int, - num_images: int, - *, - image_token_id: int, - 
image_feature_size_override: Optional[int] = None, -): - if image_feature_size_override is None: - image_feature_size = get_blip2_image_feature_size(hf_config) - else: - image_feature_size = image_feature_size_override - - return SequenceData.from_prompt_token_counts( - (image_token_id, image_feature_size * num_images), - (0, seq_len - image_feature_size * num_images), - ), { - "image": - consecutive_placeholder_ranges(num_items=num_images, - item_size=image_feature_size) - } - - -def dummy_data_for_blip2(ctx: InputContext, seq_len: int, - mm_counts: Mapping[str, int]): - hf_config = ctx.get_hf_config(Blip2Config) - vision_config = hf_config.vision_config - num_images = mm_counts["image"] - - seq_data, ranges = dummy_seq_data_for_blip2( - hf_config, - seq_len, - num_images, - image_token_id=BLIP2_IMAGE_TOKEN_ID, - ) - - if isinstance(vision_config, Blip2VisionConfig): - mm_data = dummy_image_for_blip(vision_config, num_images) - - return DummyData(seq_data, mm_data, ranges) - - msg = f"Unsupported vision config: {type(vision_config)}" - raise NotImplementedError(msg) + return hf_config.num_query_tokens -def input_processor_for_blip2(ctx: InputContext, inputs: DecoderOnlyInputs): - multi_modal_data = inputs.get("multi_modal_data") - if multi_modal_data is None or "image" not in multi_modal_data: - return inputs +class Blip2MultiModalProcessor(BaseMultiModalProcessor): - hf_config = ctx.get_hf_config(Blip2Config) - image_feature_size = get_blip2_image_feature_size(hf_config) + def _get_hf_processor(self) -> Blip2Processor: + return self.ctx.get_hf_processor(Blip2Processor) - # The original model places image tokens at the front - # https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/blip_2/modeling_blip_2.py#L1514 - new_token_ids = [BLIP2_IMAGE_TOKEN_ID] * image_feature_size - new_token_ids += inputs["prompt_token_ids"] + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict( + pixel_values=MultiModalFieldConfig.batched("image"), + image_embeds=MultiModalFieldConfig.batched("image"), + ) - new_prompt = inputs.get("prompt") - if new_prompt is not None: - new_prompt = BLIP2_IMAGE_TOKEN * image_feature_size + new_prompt + def _get_prompt_replacements( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> list[PromptReplacement]: + max_image_tokens = get_max_blip2_image_tokens(self.ctx) + + return [ + PromptReplacement( + modality="image", + target="", + replacement="" * max_image_tokens + "", + ) + ] - return token_inputs(prompt_token_ids=new_token_ids, - prompt=new_prompt, - multi_modal_data=multi_modal_data) + def apply( + self, + prompt_text: str, + mm_data: MultiModalDataDict, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> MultiModalInputsV2: + result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs) + + # Only tokens should be considered as placeholders, + # so we ignore the trailing bos_token + result["mm_placeholders"] = { + modality: [ + PlaceholderRange(offset=p["offset"], length=p["length"] - 1) + for p in ps + ] + for modality, ps in result["mm_placeholders"].items() + } + + return result + + def _get_dummy_mm_inputs( + self, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + hf_config = self.ctx.get_hf_config(Blip2Config) + vision_config = hf_config.vision_config + + max_image_size = vision_config.image_size + num_images = 
mm_counts.get("image", 0) + + mm_data = { + "image": + self._get_dummy_images(width=max_image_size, + height=max_image_size, + num_images=num_images) + } + + return ProcessorInputs( + prompt_text="", + mm_data=mm_data, + ) -@MULTIMODAL_REGISTRY.register_image_input_mapper() @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_blip2_image_tokens) -@INPUT_REGISTRY.register_dummy_data(dummy_data_for_blip2) -@INPUT_REGISTRY.register_input_processor(input_processor_for_blip2) +@MULTIMODAL_REGISTRY.register_processor(Blip2MultiModalProcessor) class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -627,7 +623,7 @@ def get_input_embeddings( if multimodal_embeddings is not None: inputs_embeds = merge_multimodal_embeddings( input_ids, inputs_embeds, multimodal_embeddings, - BLIP2_IMAGE_TOKEN_ID) + _IMAGE_TOKEN_ID) return inputs_embeds def forward( diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index a40c321ce0a58..85fca23b05746 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -3,16 +3,15 @@ Tuple, TypedDict, Union) import torch +import torch.nn as nn import torch.nn.functional as F -from PIL import Image -from torch import nn -from transformers import ChameleonConfig, ChameleonVQVAEConfig +from transformers import (BatchFeature, ChameleonConfig, ChameleonProcessor, + ChameleonVQVAEConfig) from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, - InputContext, token_inputs) +from vllm.inputs import InputContext from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, @@ -29,11 +28,13 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import NestedTensors -from vllm.multimodal.utils import (cached_get_tokenizer, - consecutive_placeholder_ranges, - repeat_and_pad_placeholder_tokens) -from vllm.sequence import IntermediateTensors, SequenceData +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalInputsV2, MultiModalKwargs, + NestedTensors, PlaceholderRange) +from vllm.multimodal.processing import (BaseMultiModalProcessor, + MultiModalDataItems, ProcessorInputs, + PromptReplacement) +from vllm.sequence import IntermediateTensors from vllm.utils import print_warning_once from .interfaces import SupportsMultiModal, SupportsPP @@ -45,10 +46,6 @@ # and processor files, so we hardcode them in the model file for now. 
CHAMELEON_CROP_SIZE_HEIGHT = CHAMELEON_CROP_SIZE_WIDTH = 512 CHAMELEON_IMAGE_SEQ_LENGTH = 1024 -CHAMELEON_IMAGE_TOKEN_ID = 8711 -CHAMELEON_IMAGE_START_TOKEN_ID = 8197 -CHAMELEON_IMAGE_END_TOKEN_ID = 8196 -CHAMELEON_SEP_TOKEN_ID = 8710 class ChameleonImagePixelInputs(TypedDict): @@ -61,99 +58,75 @@ def get_max_chameleon_image_tokens(ctx: InputContext): return CHAMELEON_IMAGE_SEQ_LENGTH -def dummy_seq_data_for_chameleon( - seq_len: int, - num_images: int, - *, - image_token_id: int, - image_feature_size_override: Optional[int] = None, -): - if image_feature_size_override is None: - image_feature_size = CHAMELEON_IMAGE_SEQ_LENGTH - else: - image_feature_size = image_feature_size_override - - return SequenceData.from_prompt_token_counts( - (image_token_id, image_feature_size * num_images), - (0, seq_len - image_feature_size * num_images), - ), { - "image": - consecutive_placeholder_ranges(num_items=num_images, - item_size=image_feature_size) - } - - -def dummy_image_for_chameleon( - num_images: int, - *, - image_width_override: Optional[int] = None, - image_height_override: Optional[int] = None, -): - width = CHAMELEON_CROP_SIZE_WIDTH - height = CHAMELEON_CROP_SIZE_HEIGHT - if image_width_override is not None: - width = image_width_override - if image_height_override is not None: - height = image_height_override - - image = Image.new("RGB", (width, height), color=0) - return {"image": image if num_images == 1 else [image] * num_images} - - -def dummy_data_for_chameleon(ctx: InputContext, seq_len: int, - mm_counts: Mapping[str, int]): - num_images = mm_counts["image"] - - seq_data, ranges = dummy_seq_data_for_chameleon( - seq_len, - num_images, - image_token_id=CHAMELEON_IMAGE_TOKEN_ID, - ) - - mm_data = dummy_image_for_chameleon(num_images) - return DummyData(seq_data, mm_data, ranges) - - -def input_processor_for_chameleon(ctx: InputContext, - inputs: DecoderOnlyInputs): +class ChameleonMultiModalProcessor(BaseMultiModalProcessor): - """ - Processing input prompt to insert required tokens for image placeholder. - - See https://github.com/huggingface/transformers/blob/0fdea8607d7e01eb0e38a1ebeb7feee30a22f0cf/src/transformers/models/chameleon/processing_chameleon.py#L58 - """ # noqa - - multi_modal_data = inputs.get("multi_modal_data") - if multi_modal_data is None or "image" not in multi_modal_data: - return inputs - - if "multi_modal_placeholders" in inputs and "image" in inputs[ - "multi_modal_placeholders"]: - # The inputs already have placeholders. 
- return inputs - - model_config = ctx.model_config - tokenizer = cached_get_tokenizer(model_config.tokenizer) - new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens( - tokenizer, - inputs.get("prompt"), - inputs["prompt_token_ids"], - placeholder_token_id=CHAMELEON_IMAGE_TOKEN_ID, - repeat_count=CHAMELEON_IMAGE_SEQ_LENGTH, - pad_token_left=CHAMELEON_IMAGE_START_TOKEN_ID, - pad_token_right=CHAMELEON_IMAGE_END_TOKEN_ID, - ) - - # Appending sep token for chat mode to follow default processor - # behavior - if new_prompt is not None: - new_prompt += tokenizer.sep_token - new_token_ids += [CHAMELEON_SEP_TOKEN_ID] - - # NOTE: Create a defensive copy of the original inputs - return token_inputs(prompt_token_ids=new_token_ids, - prompt=new_prompt, - multi_modal_data=multi_modal_data) + def _get_hf_processor(self) -> ChameleonProcessor: + return self.ctx.get_hf_processor(ChameleonProcessor) + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict(pixel_values=MultiModalFieldConfig.batched("image")) + + def _get_prompt_replacements( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> list[PromptReplacement]: + processor = self._get_hf_processor() + + return [ + PromptReplacement( + modality="image", + target="", + replacement="".join([ + processor.image_start_token, + processor.image_token * CHAMELEON_IMAGE_SEQ_LENGTH, + processor.image_end_token, + ]), + ) + ] + + def _get_dummy_mm_inputs( + self, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + num_images = mm_counts.get("image", 0) + + mm_data = { + "image": + self._get_dummy_images(width=CHAMELEON_CROP_SIZE_WIDTH, + height=CHAMELEON_CROP_SIZE_HEIGHT, + num_images=num_images) + } + + return ProcessorInputs( + prompt_text="" * num_images, + mm_data=mm_data, + ) + + def apply( + self, + prompt_text: str, + mm_data: MultiModalDataDict, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> MultiModalInputsV2: + result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs) + + # Only tokens should be considered as placeholders, + # so we ignore the image_start_token and image_end_token + result["mm_placeholders"] = { + modality: [ + PlaceholderRange(offset=p["offset"] + 1, + length=p["length"] - 2) for p in ps + ] + for modality, ps in result["mm_placeholders"].items() + } + + return result class ChameleonLayerNorm(nn.LayerNorm): @@ -736,7 +709,7 @@ def forward(self, pixel_values: torch.Tensor): for i_level in range(self.num_resolutions): for i_block in range(self.num_res_blocks): hidden_state = self.down[i_level].block[i_block]( - hidden_states[-1], ) + hidden_states[-1]) if len(self.down[i_level].attn) > 0: hidden_state = self.down[i_level].attn[i_block]( hidden_state) @@ -925,10 +898,8 @@ def forward( return hidden_states -@MULTIMODAL_REGISTRY.register_image_input_mapper() @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_chameleon_image_tokens) -@INPUT_REGISTRY.register_dummy_data(dummy_data_for_chameleon) -@INPUT_REGISTRY.register_input_processor(input_processor_for_chameleon) +@MULTIMODAL_REGISTRY.register_processor(ChameleonMultiModalProcessor) class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index 6e86900326c4b..8c14866f20b92 100644 --- a/vllm/model_executor/models/fuyu.py +++ 
b/vllm/model_executor/models/fuyu.py @@ -15,32 +15,30 @@ # limitations under the License. """ PyTorch Fuyu model.""" import math -from array import array from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, TypedDict) import torch import torch.nn as nn -import torch.utils.checkpoint -from PIL import Image -from transformers import FuyuImageProcessor +from transformers import (BatchFeature, FuyuConfig, FuyuImageProcessor, + FuyuProcessor) from vllm.attention import AttentionMetadata from vllm.config import VllmConfig -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, - InputContext, token_inputs) +from vllm.inputs import InputContext from vllm.model_executor.layers.linear import ColumnParallelLinear from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.models.persimmon import PersimmonForCausalLM from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs -from vllm.multimodal.image import cached_get_image_processor -from vllm.multimodal.inputs import NestedTensors -from vllm.multimodal.utils import (cached_get_tokenizer, - consecutive_placeholder_ranges) -from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, - SequenceData) -from vllm.utils import is_list_of +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalInputsV2, MultiModalKwargs, + NestedTensors, PlaceholderRange) +from vllm.multimodal.parse import ImageProcessorItems +from vllm.multimodal.processing import (BaseMultiModalProcessor, + MultiModalDataItems, ProcessorInputs, + PromptReplacement) +from vllm.sequence import IntermediateTensors from .interfaces import SupportsMultiModal, SupportsPP from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix, @@ -54,178 +52,193 @@ MAX_IMAGE_FEATURE_SIZE_WIDTH = 1920 -class FuyuImagePixelInputs(TypedDict): - type: Literal["pixel_values"] +class FuyuImagePatchInputs(TypedDict): + type: Literal["image_patches"] data: torch.Tensor """ Shape: - (batch_size, num_patches, patch_size_x * patch_size_y * num_channels) + `(batch_size * num_patches, patch_size_x * patch_size_y * num_channels)` + """ + + patches_per_image: List[int] + """ + List of number of total patches for each image in the batch. + This is used to restore the first two dimensions of `data`. """ -def _calculate_num_image_tokens( - height: int, - width: int, +def _get_fuyu_num_image_tokens( + image_height: int, + image_width: int, ) -> Tuple[int, int]: """ - calculate number of image tokens needed for a given image size - The expected Fuyu image prompts is in format: - (image_token * ncols + newline_token) * nrows - args: - image_size: Tuple[int, int] - (width, height) of the image - returns: - ncols: int - number of image tokens in x direction - nrows: int - number of image tokens in y direction - """ - ncol = math.ceil(width / 30) - nrow = math.ceil(height / 30) - return ncol, nrow + Calculate the number of image tokens needed for a given image size. + The expected Fuyu image prompts can be expressed as: -def get_max_fuyu_image_feature_size(): + .. 
code-block:: + (image_token * ncols + newline_token) * nrows - return _calculate_num_image_tokens( - height=MAX_IMAGE_FEATURE_SIZE_HEIGHT, - width=MAX_IMAGE_FEATURE_SIZE_WIDTH, - ) + Args: + image_size: Tuple[int, int] - `(width, height)` of the image + + Returns: + ncols: int - number of image tokens in `x` direction + nrows: int - number of image tokens in `y` direction + """ + ncols = math.ceil(image_width / 30) + nrows = math.ceil(image_height / 30) + return ncols, nrows def get_max_fuyu_image_tokens(ctx: InputContext): - ncol, nrow = get_max_fuyu_image_feature_size() - return (ncol + 1) * nrow - - -def dummy_seq_data_for_fuyu(ctx: InputContext, seq_len: int, num_images: int): - ncol, nrow = get_max_fuyu_image_feature_size() - image_feature_size = get_max_fuyu_image_tokens(ctx) - - image_token_ids = ( - array(VLLM_TOKEN_ID_ARRAY_TYPE, [_IMAGE_TOKEN_ID]) * ncol + - array(VLLM_TOKEN_ID_ARRAY_TYPE, [_NEWLINE_TOKEN_ID])) * nrow - token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, image_token_ids) * num_images - token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, - [0]) * (seq_len - image_feature_size * num_images) - return SequenceData(token_ids), { - "image": - consecutive_placeholder_ranges(num_items=num_images, - item_size=image_feature_size) - } - - -def dummy_image_for_fuyu( - num_images: int, - *, - image_width: int, - image_height: int, -): - image = Image.new("RGB", (image_width, image_height), color=0) - return {"image": image if num_images == 1 else [image] * num_images} - - -def dummy_data_for_fuyu(ctx: InputContext, seq_len: int, - mm_counts: Mapping[str, int]): - num_images = mm_counts["image"] - seq_data, ranges = dummy_seq_data_for_fuyu(ctx, seq_len, num_images) - mm_data = dummy_image_for_fuyu(num_images, - image_width=MAX_IMAGE_FEATURE_SIZE_WIDTH, - image_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT) - return DummyData(seq_data, mm_data, ranges) - - -def _fuyu_image_preprocess(image_processor: FuyuImageProcessor, - data: List[Image.Image]): - image_encoding = image_processor.preprocess(data, return_tensors="pt") - batch_images = torch.stack([img[0] for img in image_encoding["images"] - ]).unsqueeze(1) - image_unpadded_heights = torch.tensor( - image_encoding["image_unpadded_heights"]) - image_unpadded_widths = torch.tensor( - image_encoding["image_unpadded_widths"]) - - batch_size = len(image_encoding["images"]) - image_present = torch.ones(batch_size, 1, 1) - model_image_input = image_processor.preprocess_with_tokenizer_info( - image_input=batch_images, - image_present=image_present, - image_unpadded_h=image_unpadded_heights, - image_unpadded_w=image_unpadded_widths, - image_placeholder_id=_IMAGE_TOKEN_ID, - image_newline_id=_NEWLINE_TOKEN_ID, - variable_sized=True, + ncols, nrows = _get_fuyu_num_image_tokens( + image_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT, + image_width=MAX_IMAGE_FEATURE_SIZE_WIDTH, ) - return model_image_input - - -def input_processor_for_fuyu(ctx: InputContext, inputs: DecoderOnlyInputs): - multi_modal_data = inputs.get("multi_modal_data") - if multi_modal_data is None or "image" not in multi_modal_data: - return inputs - - model_config = ctx.model_config - image_data = multi_modal_data["image"] - new_multi_modal_data = {} - image_list = image_data if isinstance(image_data, list) else [image_data] - - # process image data - if is_list_of(image_list, Image.Image): - # Fuyu's image_processor can also finish token padding - image_processor: FuyuImageProcessor = cached_get_image_processor( - model_config.model) - - model_image_input = _fuyu_image_preprocess(image_processor, 
image_data) - image_patches = torch.cat([ - image_patch[0] - for image_patch in model_image_input["image_patches"] - ]) - new_multi_modal_data["image"] = image_patches - - elif is_list_of(image_list, torch.Tensor): - raise NotImplementedError("Embeddings input is not supported yet") - else: - raise TypeError(f"Invalid image type: {type(image_data)}") - - # process prompts - prompt = inputs.get("prompt") - prompt_token_ids = inputs["prompt_token_ids"] - tokenizer = cached_get_tokenizer(model_config.model) - # dim0 is batch_size, dim1 is subseq_size which will always be 1 - image_input_ids: List[List[ - torch.Tensor]] = model_image_input["image_input_ids"] - image_input_ids = image_input_ids[0][0].tolist() - bos_token = tokenizer.encode("", add_special_tokens=False)[1:] - boa_token = tokenizer.encode("\x04", add_special_tokens=False)[1:] - - new_prompt = prompt + "\x04" - new_prompt_token_ids = image_input_ids + bos_token + prompt_token_ids[ - 1:] + boa_token - - return token_inputs(prompt=new_prompt, - prompt_token_ids=new_prompt_token_ids, - multi_modal_data=new_multi_modal_data) - - -def input_mapper_for_fuyu(ctx: InputContext, data: object): - model_config = ctx.model_config - data_list = data if isinstance(data, list) else [data] - if is_list_of(data_list, Image.Image): - # Fuyu's image_processor can also finish token padding - image_processor: FuyuImageProcessor = cached_get_image_processor( - model_config.model) - - model_image_input = _fuyu_image_preprocess(image_processor, data_list) - data = torch.stack([ - image_patch[0] - for image_patch in model_image_input["image_patches"] - ]) - - # image has been processed with prompt in input processor - return MultiModalKwargs({"pixel_values": data}) - - -@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_fuyu) + + return (ncols + 1) * nrows + + +class FuyuMultiModalProcessor(BaseMultiModalProcessor): + + def _get_hf_processor(self) -> FuyuProcessor: + return self.ctx.get_hf_processor(FuyuProcessor) + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + ) -> BatchFeature: + + if not mm_data: + # Avoid warning from HF logger for text-only input + # Input_ids format: bos_token_id + prompt_token_ids + boa_token_id + # Tokenizer won't add boa_token_id by default, we add it manually. 
+ tokenizer = self._get_tokenizer() + boa_token_id: int = tokenizer.vocab["<0x04>"] # type: ignore + prompt_ids = tokenizer.encode(prompt) + [boa_token_id] + return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") + + processed_outputs = super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + ) + + image_patches = processed_outputs.get("image_patches") + if image_patches is not None: + images = mm_data["images"] + assert isinstance(images, list) + + # Original output: (1, num_images, Pn, Px * Py * C) + # New output: (num_images, Pn, Px * Py * C) + assert (isinstance(image_patches, list) + and len(image_patches) == 1) + assert (isinstance(image_patches[0], torch.Tensor) + and len(image_patches[0]) == len(images)) + + processed_outputs["image_patches"] = image_patches[0] + + return processed_outputs + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict(image_patches=MultiModalFieldConfig.batched("image")) + + def _get_prompt_replacements( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> list[PromptReplacement]: + hf_config = self.ctx.get_hf_config(FuyuConfig) + bos_token_id = hf_config.bos_token_id + + tokenizer = self._get_tokenizer() + eot_token_id = tokenizer.bos_token_id + assert isinstance(eot_token_id, int) + + hf_processor = self._get_hf_processor() + image_processor: FuyuImageProcessor = hf_processor.image_processor + target_size = image_processor.size + target_height, target_width = (target_size["height"], + target_size["width"]) + + def get_replacement_fuyu(item_idx: int): + images = mm_items.get_items("image", ImageProcessorItems) + image_size = images.get_image_size(item_idx) + width, height = image_size.width, image_size.height + if not (width <= target_width and height <= target_height): + height_scale_factor = target_height / height + width_scale_factor = target_width / width + optimal_scale_factor = min(height_scale_factor, + width_scale_factor) + + height = int(height * optimal_scale_factor) + width = int(width * optimal_scale_factor) + + ncols, nrows = _get_fuyu_num_image_tokens( + image_width=width, + image_height=height, + ) + + return (([_IMAGE_TOKEN_ID] * ncols + [_NEWLINE_TOKEN_ID]) * nrows + + [bos_token_id]) + + return [ + PromptReplacement( + modality="image", + target=[eot_token_id], + replacement=get_replacement_fuyu, + ) + ] + + def apply( + self, + prompt_text: str, + mm_data: MultiModalDataDict, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> MultiModalInputsV2: + result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs) + + # Only |SPEAKER| (image) tokens should be considered as placeholders, + # so we ignore the trailing bos_token_id + result["mm_placeholders"] = { + modality: [ + PlaceholderRange(offset=p["offset"], length=p["length"] - 1) + for p in ps + ] + for modality, ps in result["mm_placeholders"].items() + } + + return result + + def _get_dummy_mm_inputs( + self, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + num_images = mm_counts.get("image", 0) + + mm_data = { + "image": + self._get_dummy_images(width=MAX_IMAGE_FEATURE_SIZE_WIDTH, + height=MAX_IMAGE_FEATURE_SIZE_HEIGHT, + num_images=num_images) + } + + return ProcessorInputs( + prompt_text="", + mm_data=mm_data, + ) + + @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_fuyu_image_tokens) 
-@INPUT_REGISTRY.register_dummy_data(dummy_data_for_fuyu) -@INPUT_REGISTRY.register_input_processor(input_processor_for_fuyu) +@MULTIMODAL_REGISTRY.register_processor(FuyuMultiModalProcessor) class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -280,28 +293,32 @@ def _validate_shape(d: torch.Tensor): return data.to(self.vision_embed_tokens.weight.dtype) def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[FuyuImagePixelInputs]: - pixel_values = kwargs.pop("pixel_values", None) - - if pixel_values is not None: - if not isinstance(pixel_values, (torch.Tensor, list)): + self, **kwargs: object) -> Optional[FuyuImagePatchInputs]: + image_patches = kwargs.pop("image_patches", None) + if image_patches is not None: + if not isinstance(image_patches, (torch.Tensor, list)): raise ValueError("Incorrect type of image patches. " - f"Got type: {type(pixel_values)}") + f"Got type: {type(image_patches)}") - return FuyuImagePixelInputs( - type="pixel_values", + image_patches_flat = flatten_bn(image_patches) + + return FuyuImagePatchInputs( + type="image_patches", data=self._validate_pixel_values( - flatten_bn(pixel_values, concat=True)), + flatten_bn(image_patches_flat, concat=True)), + patches_per_image=[x.size(0) for x in image_patches_flat], ) return None def _process_image_input( - self, image_input: FuyuImagePixelInputs) -> torch.Tensor: + self, image_input: FuyuImagePatchInputs) -> NestedTensors: + image_patches = image_input["data"] + patches_per_image = image_input["patches_per_image"] assert self.vision_embed_tokens is not None - vision_embeddings, _ = self.vision_embed_tokens(image_input["data"]) - return vision_embeddings + vision_embeddings, _ = self.vision_embed_tokens(image_patches) + return vision_embeddings.split(patches_per_image, dim=0) def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: image_input = self._parse_and_validate_image_input(**kwargs) diff --git a/vllm/model_executor/models/idefics2_vision_model.py b/vllm/model_executor/models/idefics2_vision_model.py index e430a158d869a..4e42a4b6f9e64 100644 --- a/vllm/model_executor/models/idefics2_vision_model.py +++ b/vllm/model_executor/models/idefics2_vision_model.py @@ -69,7 +69,8 @@ def forward(self, patch_attention_mask: torch.BoolTensor, tgt_sizes: Optional[torch.IntTensor] = None) -> torch.Tensor: batch_size, _, max_im_h, max_im_w = pixel_values.shape - patch_embeds = self.patch_embedding(pixel_values) + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding(pixel_values.to(target_dtype)) embeddings = patch_embeds.flatten(2).transpose(1, 2) max_nb_patches_h, max_nb_patches_w = ( max_im_h // self.patch_size, @@ -309,7 +310,8 @@ def forward( hidden_states = self.embeddings( pixel_values=pixel_values, patch_attention_mask=patch_attention_mask, - tgt_sizes=tgt_sizes) + tgt_sizes=tgt_sizes, + ) encoder_outputs = self.encoder(hidden_states) last_hidden_state = self.post_layernorm(encoder_outputs) return last_hidden_state diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 1d6ee2a0be72e..34dc7fa31ce6f 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -144,8 +144,8 @@ def _call_hf_processor( # Original output: (1, num_images, C, H, W) # New output: (num_images, C, H, W) assert (isinstance(pixel_values, list) - and len(pixel_values) == 1 - and isinstance(pixel_values[0], list) + and len(pixel_values) 
== 1) + assert (isinstance(pixel_values[0], list) and len(pixel_values[0]) == len(images)) processed_outputs["pixel_values"] = pixel_values[0] diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index a39f2f4124d05..5e70c11363c83 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -528,10 +528,8 @@ def _process_image_pixels( stacked_image_features = self._image_pixels_to_features( self.vision_tower, stacked_pixel_values) - return [ - self.multi_modal_projector(image_features) for image_features in - torch.split(stacked_image_features, num_patches_per_batch) - ] + return torch.split(self.multi_modal_projector(stacked_image_features), + num_patches_per_batch) def _process_image_input( self, diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 22d29f5bbc50c..2bce13792a88d 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -1,8 +1,8 @@ +import math from dataclasses import dataclass, fields from functools import cached_property from typing import Iterable, List, Mapping, Optional, Set, Tuple, Union -import numpy import torch import torch.nn as nn import torch.nn.functional as F @@ -306,7 +306,7 @@ def _parse_and_validate_image_input( images: Optional[Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor]] = None, image_tokens: Optional[torch.Tensor] = None, - ) -> Optional[List[torch.Tensor]]: + ) -> Tuple[Optional[List[torch.Tensor]], Optional[torch.Tensor]]: if images is None: return None, None @@ -604,11 +604,11 @@ def max_patches_per_side(self) -> int: return self.args.image_size // self.args.patch_size @property - def device(self) -> torch.device: + def device(self) -> torch.types.Device: return next(self.parameters()).device @property - def dtype(self) -> torch.device: + def dtype(self) -> torch.dtype: return next(self.parameters()).dtype @property @@ -741,8 +741,8 @@ def get_pixtral_hf_image_feature_size(hf_config: PixtralVisionConfig, ratio = max(image_width / max_width, image_height / max_height) if ratio > 1: - image_width = int(numpy.ceil(image_width / ratio)) - image_height = int(numpy.ceil(image_height / ratio)) + image_width = int(math.ceil(image_width / ratio)) + image_height = int(math.ceil(image_height / ratio)) num_height_tokens, num_width_tokens = _get_pixtral_hf_num_image_tokens( (image_height, image_width), diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index e3d43b017f894..de55bc6bcc123 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -23,7 +23,6 @@ from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict, Union) -import numpy as np import torch import torch.nn as nn from transformers import BatchFeature @@ -177,16 +176,19 @@ def _get_dummy_mm_inputs( mm_counts: Mapping[str, int], ) -> ProcessorInputs: feature_extractor = self._get_feature_extractor() + sampling_rate = feature_extractor.sampling_rate audio_len = feature_extractor.chunk_length * sampling_rate + num_audios = mm_counts.get("audio", 0) - audio_count = mm_counts.get("audio", 0) - audio = np.zeros(audio_len) - data = {"audio": [audio] * audio_count} + mm_data = { + "audio": + self._get_dummy_audios(length=audio_len, num_audios=num_audios) + } return ProcessorInputs( - prompt_text="<|AUDIO|>" * audio_count, - mm_data=data, + prompt_text="<|AUDIO|>" * num_audios, + mm_data=mm_data, ) diff 
--git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 6181fe3dd13d8..1e485f87bb7a4 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -29,7 +29,6 @@ import torch.nn as nn import torch.nn.functional as F from einops import rearrange, repeat -from PIL import Image from transformers import BatchFeature from transformers.models.qwen2_vl import (Qwen2VLImageProcessor, Qwen2VLProcessor) @@ -882,12 +881,10 @@ def _get_dummy_mm_inputs( self, mm_counts: Mapping[str, int], ) -> ProcessorInputs: - num_images = mm_counts.get("image", 0) hf_processor = self._get_hf_processor() - image_token: str = hf_processor.image_token image_processor = _get_image_processor(hf_processor) - data = {} + image_token: str = hf_processor.image_token resized_height, resized_width = smart_resize( height=9999999, width=9999999, @@ -895,14 +892,18 @@ def _get_dummy_mm_inputs( min_pixels=image_processor.min_pixels, max_pixels=image_processor.max_pixels, ) + num_images = mm_counts.get("image", 0) - dummy_image = Image.new("RGB", (resized_width, resized_height), - color=0) - data["image"] = [dummy_image] * num_images + mm_data = { + "image": + self._get_dummy_images(width=resized_width, + height=resized_height, + num_images=num_images) + } return ProcessorInputs( prompt_text=image_token * num_images, - mm_data=data, + mm_data=mm_data, ) diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 7e853e5b90096..54be7fed3f2be 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -188,16 +188,19 @@ def _get_dummy_mm_inputs( mm_counts: Mapping[str, int], ) -> ProcessorInputs: feature_extractor = self._get_feature_extractor() + sampling_rate = feature_extractor.sampling_rate audio_len = feature_extractor.chunk_length * sampling_rate + num_audios = mm_counts.get("audio", 0) - audio_count = mm_counts.get("audio", 0) - audio = np.zeros(audio_len) - data = {"audio": [audio] * audio_count} + mm_data = { + "audio": + self._get_dummy_audios(length=audio_len, num_audios=num_audios) + } return ProcessorInputs( - prompt_text="<|audio|>" * audio_count, - mm_data=data, + prompt_text="<|audio|>" * num_audios, + mm_data=mm_data, ) diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 180489166b407..7712c3bcebe20 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -1,15 +1,17 @@ import pickle import re from abc import ABC, abstractmethod +from collections import defaultdict from collections.abc import Callable, ItemsView, Iterable, Mapping, Sequence from dataclasses import dataclass, field from functools import lru_cache from typing import Any, NamedTuple, Optional, Protocol, TypeVar, Union import numpy as np +import numpy.typing as npt import torch from blake3 import blake3 -from PIL.Image import Image +from PIL import Image from transformers import BatchFeature, ProcessorMixin from vllm.inputs import DummyData, InputProcessingContext @@ -353,13 +355,13 @@ def _replace_matches( ) -> list[_S]: out_seqs = list[_S]() prev_end_idx = 0 - next_idx_by_modality = {modality: 0 for modality in mm_item_counts} + next_idx_by_modality = defaultdict[str, int](lambda: 0) for match in _resolve_matches(prompt, matches): modality = match.modality item_idx = next_idx_by_modality[modality] - if item_idx >= mm_item_counts[modality]: + if item_idx >= mm_item_counts.get(modality, 0): continue start_idx = match.start_idx @@ -513,7 +515,7 @@ 
def _serialize_item(self, obj: object) -> bytes: return obj.encode("utf-8") if isinstance(obj, bytes): return obj - if isinstance(obj, Image): + if isinstance(obj, Image.Image): return obj.tobytes() # Convertible to NumPy arrays @@ -673,10 +675,14 @@ def _get_prompt_replacements( Given the original multi-modal items for this modality and HF-processed data, output the replacements to perform. - Note: - Even when the HF processor already performs replacement for us, - we still use this replacement information to determine - the placeholder token positions for each multi-modal item. + Notes: + - You should not assume that HF processor always performs prompt + replacement: in :meth:`_apply_hf_processor_missing`, this method + is called on text-only and multimodal-only inputs separately, + instead of passing them in the same call. + - The replacement information returned by this method is also used + to determine the placeholder token positions for each multi-modal + item. """ raise NotImplementedError @@ -710,6 +716,10 @@ def _call_hf_processor( mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], ) -> BatchFeature: + """ + Call the HF processor on the prompt text and + associated multi-modal data. + """ return self.ctx.call_hf_processor( self._get_hf_processor(**mm_kwargs), dict(text=prompt, **mm_data), @@ -723,7 +733,8 @@ def _apply_hf_processor( hf_processor_mm_kwargs: Mapping[str, object], ) -> tuple[list[int], MultiModalKwargs]: """ - Apply the HF processor on the full prompt text and multi-modal data. + Wrapper of :meth:`_call_hf_processor` that applies + additional pre-processing and post-processing. """ processor_data, passthrough_data = self._get_hf_mm_data(mm_items) @@ -754,10 +765,11 @@ def _apply_hf_processor_missing( Apply the HF processor on the full prompt text, but only on the multi-modal data that are missing from the cache. - Note: We pass prompt text and multi-modal data into the HF processor - in separate calls to avoid HF prompt replacement being done for - cached items; instead, we rely on our own prompt replacement logic - for the full text. + Note: + We pass prompt text and multi-modal data into the HF processor + in separate calls to avoid HF prompt replacement being done for + cached items; instead, we rely on our own prompt replacement logic + (:meth:`_get_prompt_replacements`) for the full text. 
""" mm_missing_counts = mm_missing_data_items.get_all_counts() @@ -1010,6 +1022,36 @@ def apply( mm_placeholders=mm_placeholders, ) + def _get_dummy_audios( + self, + *, + length: int, + num_audios: int, + ) -> list[npt.NDArray]: + audio = np.zeros((length, )) + return [audio] * num_audios + + def _get_dummy_images( + self, + *, + width: int, + height: int, + num_images: int, + ) -> list[Image.Image]: + image = Image.new("RGB", (width, height), color=0) + return [image] * num_images + + def _get_dummy_videos( + self, + *, + width: int, + height: int, + num_frames: int, + num_videos: int, + ) -> list[npt.NDArray]: + video = np.zeros((num_frames, width, height, 3)) + return [video] * num_videos + @abstractmethod def _get_dummy_mm_inputs( self, diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py index 87b12a6fb33c1..7b6ded6a27084 100644 --- a/vllm/multimodal/utils.py +++ b/vllm/multimodal/utils.py @@ -400,15 +400,19 @@ def repeat_and_pad_placeholder_tokens( placeholder_token_idx = 0 for i, token in enumerate(prompt_token_ids): if token == placeholder_token_id: + curr_repeat_count = repeat_count[placeholder_token_idx] replacement_ids = repeat_and_pad_token( placeholder_token_id, - repeat_count=repeat_count[placeholder_token_idx], + repeat_count=curr_repeat_count, pad_token_left=pad_token_left, pad_token_right=pad_token_right, ) + offset = len(new_token_ids) + if pad_token_left is not None: + offset += 1 placeholder_ranges.append({ - "offset": len(new_token_ids), - "length": len(replacement_ids) + "offset": offset, + "length": curr_repeat_count, }) new_token_ids.extend(replacement_ids) placeholder_token_idx += 1 diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 509771b7e2e5a..a08a86d4007dc 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -647,10 +647,23 @@ def profile_run(self) -> None: self.mm_registry.get_max_tokens_per_item_by_modality( self.model_config).values()) - max_num_mm_items = min( + max_num_mm_items_encoder_budget = min( self.max_num_encoder_input_tokens, self.encoder_cache_size) // max_tokens_per_mm_item + max_mm_items_per_req = max( + self.mm_registry.get_mm_limits_per_prompt( + self.model_config).values()) + + # NOTE: We do not consider max_num_batched_tokens on purpose + # because the multimodal embeddings can be generated in advance + # and chunked prefilled. + max_num_mm_items_decoder_budget = self.max_num_reqs * \ + max_mm_items_per_req + + max_num_mm_items = min(max_num_mm_items_encoder_budget, + max_num_mm_items_decoder_budget) + # Dummy data definition in V0 may contain multiple multimodal items # (e.g, multiple images) for a single request, therefore here we # always replicate first item by max_num_mm_items times since in V1