From 5f940fc7d1974b9be17650b40098c3ca28f1d748 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Tue, 7 Jan 2025 16:57:28 +0000 Subject: [PATCH 1/7] Reorganize profiling/processing-related code Signed-off-by: DarkLight1337 --- .../processing/test_llava_next.py | 41 +- .../processing/test_llava_onevision.py | 41 +- tests/multimodal/test_processing.py | 52 +- .../vllm_add_dummy_model/my_llava.py | 10 +- vllm/inputs/preprocess.py | 2 +- vllm/inputs/registry.py | 4 +- vllm/model_executor/models/aria.py | 49 +- vllm/model_executor/models/blip2.py | 41 +- vllm/model_executor/models/chameleon.py | 47 +- vllm/model_executor/models/fuyu.py | 82 +-- vllm/model_executor/models/llava.py | 179 ++--- vllm/model_executor/models/llava_next.py | 52 +- .../model_executor/models/llava_next_video.py | 111 +-- vllm/model_executor/models/llava_onevision.py | 136 ++-- vllm/model_executor/models/phi3v.py | 81 +-- vllm/model_executor/models/qwen2_audio.py | 50 +- vllm/model_executor/models/qwen2_vl.py | 106 +-- vllm/model_executor/models/ultravox.py | 46 +- vllm/multimodal/processing.py | 686 ++---------------- vllm/multimodal/processor.py | 562 ++++++++++++++ vllm/multimodal/profiler.py | 133 ++++ vllm/multimodal/profiling.py | 67 +- vllm/multimodal/registry.py | 71 +- 23 files changed, 1363 insertions(+), 1286 deletions(-) create mode 100644 vllm/multimodal/processor.py create mode 100644 vllm/multimodal/profiler.py diff --git a/tests/models/decoder_only/vision_language/processing/test_llava_next.py b/tests/models/decoder_only/vision_language/processing/test_llava_next.py index 9fa6a8a10a0f9..737a8c8c78d76 100644 --- a/tests/models/decoder_only/vision_language/processing/test_llava_next.py +++ b/tests/models/decoder_only/vision_language/processing/test_llava_next.py @@ -4,24 +4,17 @@ import pytest from PIL import Image from pqdm.threads import pqdm -from transformers import AutoTokenizer -from vllm.inputs import InputProcessingContext +from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.parse import ImageSize +from vllm.multimodal.processor import BaseMultiModalProcessor +from vllm.multimodal.utils import cached_get_tokenizer from ....utils import build_model_context -# Fixtures lazy import to avoid initializing CUDA during test collection -@pytest.fixture() -def processor_for_llava_next(): - from vllm.model_executor.models.llava_next import ( - LlavaNextMultiModalProcessor) - return LlavaNextMultiModalProcessor - - def _validate_image_prompt_replacements_one( - processor, + processor: BaseMultiModalProcessor, num_imgs: int, failed_size_excs: list[tuple[ImageSize, Exception]], image_size: ImageSize, @@ -78,20 +71,17 @@ def _test_image_prompt_replacements( @pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"]) @pytest.mark.parametrize("num_imgs", [1, 2]) -def test_processor_prompt_replacements_regression( - processor_for_llava_next, - model_id: str, - num_imgs: int, -): +def test_processor_prompt_replacements_regression(model_id, num_imgs): ctx = build_model_context( model_name=model_id, tokenizer_name=model_id, mm_processor_kwargs=None, limit_mm_per_prompt={"image": num_imgs}, ) - tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) - ctx = InputProcessingContext(ctx.model_config, tokenizer) - processor = processor_for_llava_next(ctx) + processor = MULTIMODAL_REGISTRY.create_processor( + ctx.model_config, + tokenizer=cached_get_tokenizer(ctx.model_config.tokenizer), + ) image_ratios = [(171, 152), (184, 161), (198, 176), (333, 296), (369, 328), (488, 
183), (2560, 1669)] @@ -111,20 +101,17 @@ def test_processor_prompt_replacements_regression( "Comment this out to run it manually.") @pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"]) @pytest.mark.parametrize("num_imgs", [1]) -def test_processor_prompt_replacements_all( - processor_for_llava_next, - model_id: str, - num_imgs: int, -): +def test_processor_prompt_replacements_all(model_id, num_imgs): ctx = build_model_context( model_name=model_id, tokenizer_name=model_id, mm_processor_kwargs=None, limit_mm_per_prompt={"image": num_imgs}, ) - tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) - ctx = InputProcessingContext(ctx.model_config, tokenizer) - processor = processor_for_llava_next(ctx) + processor = MULTIMODAL_REGISTRY.create_processor( + ctx.model_config, + tokenizer=cached_get_tokenizer(ctx.model_config.tokenizer), + ) seen_aspect_ratios = set[float]() image_sizes = list[ImageSize]() diff --git a/tests/models/decoder_only/vision_language/processing/test_llava_onevision.py b/tests/models/decoder_only/vision_language/processing/test_llava_onevision.py index d4cdffa210b6d..21765932be2a6 100644 --- a/tests/models/decoder_only/vision_language/processing/test_llava_onevision.py +++ b/tests/models/decoder_only/vision_language/processing/test_llava_onevision.py @@ -4,24 +4,17 @@ import pytest from PIL import Image from pqdm.threads import pqdm -from transformers import AutoTokenizer -from vllm.inputs import InputProcessingContext +from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.parse import ImageSize +from vllm.multimodal.processor import BaseMultiModalProcessor +from vllm.multimodal.utils import cached_get_tokenizer from ....utils import build_model_context -# Fixtures lazy import to avoid initializing CUDA during test collection -@pytest.fixture() -def processor_for_llava_onevision(): - from vllm.model_executor.models.llava_onevision import ( - LlavaOnevisionMultiModalProcessor) - return LlavaOnevisionMultiModalProcessor - - def _validate_image_prompt_replacements_one( - processor, + processor: BaseMultiModalProcessor, num_imgs: int, failed_size_excs: list[tuple[ImageSize, Exception]], image_size: ImageSize, @@ -77,20 +70,17 @@ def _test_image_prompt_replacements( @pytest.mark.parametrize("model_id", ["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"]) @pytest.mark.parametrize("num_imgs", [1, 2]) -def test_processor_prompt_replacements_regression( - processor_for_llava_onevision, - model_id: str, - num_imgs: int, -): +def test_processor_prompt_replacements_regression(model_id, num_imgs): ctx = build_model_context( model_name=model_id, tokenizer_name=model_id, mm_processor_kwargs=None, limit_mm_per_prompt={"image": num_imgs}, ) - tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) - ctx = InputProcessingContext(ctx.model_config, tokenizer) - processor = processor_for_llava_onevision(ctx) + processor = MULTIMODAL_REGISTRY.create_processor( + ctx.model_config, + tokenizer=cached_get_tokenizer(ctx.model_config.tokenizer), + ) image_ratios = [(171, 152), (184, 161), (198, 176), (333, 296), (369, 328), (488, 183), (2560, 1669)] @@ -111,20 +101,17 @@ def test_processor_prompt_replacements_regression( @pytest.mark.parametrize("model_id", ["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"]) @pytest.mark.parametrize("num_imgs", [1]) -def test_processor_prompt_replacements_all( - processor_for_llava_onevision, - model_id: str, - num_imgs: int, -): +def test_processor_prompt_replacements_all(model_id, num_imgs): ctx = 
build_model_context( model_name=model_id, tokenizer_name=model_id, mm_processor_kwargs=None, limit_mm_per_prompt={"image": num_imgs}, ) - tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) - ctx = InputProcessingContext(ctx.model_config, tokenizer) - processor = processor_for_llava_onevision(ctx) + processor = MULTIMODAL_REGISTRY.create_processor( + ctx.model_config, + tokenizer=cached_get_tokenizer(ctx.model_config.tokenizer), + ) seen_aspect_ratios = set[float]() image_sizes = list[ImageSize]() diff --git a/tests/multimodal/test_processing.py b/tests/multimodal/test_processing.py index 75d878217b657..003cc177c57c4 100644 --- a/tests/multimodal/test_processing.py +++ b/tests/multimodal/test_processing.py @@ -10,12 +10,17 @@ from vllm.config import ModelConfig from vllm.inputs import InputProcessingContext from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.processing import (ProcessingCache, PromptReplacement, - _PlaceholderInfo, find_mm_placeholders, +# yapf conflicts with isort for this block +# yapf: disable +from vllm.multimodal.processing import (PlaceholderInfo, ProcessingCache, + PromptReplacement, + find_mm_placeholders, find_text_matches, find_token_matches, iter_token_matches, replace_text_matches, replace_token_matches) +# yapf: enable +from vllm.multimodal.profiler import MultiModalProfiler from vllm.multimodal.utils import cached_get_tokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import full_groupby @@ -431,7 +436,7 @@ def test_find_replace_tokens( [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918], { "pattern_1": [ - _PlaceholderInfo( + PlaceholderInfo( modality="pattern_1", item_idx=0, start_idx=6, @@ -445,13 +450,13 @@ def test_find_replace_tokens( [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550], { "pattern_1": [ - _PlaceholderInfo( + PlaceholderInfo( modality="pattern_1", item_idx=0, start_idx=1, replacement=[32000, 32000], ), - _PlaceholderInfo( + PlaceholderInfo( modality="pattern_1", item_idx=1, start_idx=5, @@ -459,7 +464,7 @@ def test_find_replace_tokens( ), ], "pattern_3": [ - _PlaceholderInfo( + PlaceholderInfo( modality="pattern_3", item_idx=0, start_idx=7, @@ -472,13 +477,13 @@ def test_find_replace_tokens( [1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550], { "pattern_1": [ - _PlaceholderInfo( + PlaceholderInfo( modality="pattern_1", item_idx=0, start_idx=1, replacement=[32000, 32000], ), - _PlaceholderInfo( + PlaceholderInfo( modality="pattern_1", item_idx=1, start_idx=3, @@ -486,7 +491,7 @@ def test_find_replace_tokens( ), ], "pattern_3": [ - _PlaceholderInfo( + PlaceholderInfo( modality="pattern_3", item_idx=0, start_idx=6, @@ -577,19 +582,15 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid): revision=None, limit_mm_per_prompt=limit_mm_per_prompt, ) - model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config) - processor_factory = MULTIMODAL_REGISTRY._processor_factories[model_cls] - ctx = InputProcessingContext( + processor = MULTIMODAL_REGISTRY.create_processor( model_config, tokenizer=cached_get_tokenizer(model_config.tokenizer), ) - - processor = processor_factory(ctx, cache=None) - profiler = processor.profiling_info + profiler = MultiModalProfiler(processor) mock_supported_mm_limits = MagicMock(return_value={"image": num_supported}) - profiler.get_supported_mm_limits = mock_supported_mm_limits + processor.info.get_supported_mm_limits = mock_supported_mm_limits if is_valid: exc_ctx = nullcontext() @@ -597,7 +598,7 @@ 
def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid): exc_ctx = pytest.raises(ValueError, match="this model only supports") with exc_ctx: - profiler.get_mm_limits() + profiler.get_dummy_data(model_config.max_model_len) @pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"]) @@ -620,16 +621,12 @@ def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid): revision=None, limit_mm_per_prompt=limit_mm_per_prompt, ) - model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config) - processor_factory = MULTIMODAL_REGISTRY._processor_factories[model_cls] - ctx = InputProcessingContext( + processor = MULTIMODAL_REGISTRY.create_processor( model_config, tokenizer=cached_get_tokenizer(model_config.tokenizer), ) - processor = processor_factory(ctx, cache=None) - rng = np.random.RandomState(0) image = _rand_img(rng, min_wh=128, max_wh=256) if num_images == 0: @@ -681,9 +678,9 @@ def _test_processing_cache_correctness( hf_overrides=hf_overrides, limit_mm_per_prompt=limit_mm_per_prompt, ) - model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config) - processor_factory = MULTIMODAL_REGISTRY._processor_factories[model_cls] + model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config) + factories = MULTIMODAL_REGISTRY._processor_factories[model_cls] ctx = InputProcessingContext( model_config, tokenizer=cached_get_tokenizer(model_config.tokenizer), @@ -691,8 +688,9 @@ def _test_processing_cache_correctness( # Ensure that it can fit all of the data cache = ProcessingCache(capacity=1 << 30) - baseline_processor = processor_factory(ctx, cache=None) - cached_processor = processor_factory(ctx, cache=cache) + baseline_processor = factories.build_processor(ctx, cache=None) + cached_processor = factories.build_processor(ctx, cache=cache) + dummy_data_builder = baseline_processor.dummy_data_builder rng = np.random.RandomState(0) @@ -724,7 +722,7 @@ def _test_processing_cache_correctness( } mm_counts = {k: len(vs) for k, vs in mm_data.items()} - prompt = baseline_processor.profiling_info.get_dummy_processor_inputs( + prompt = dummy_data_builder.get_dummy_processor_inputs( model_config.max_model_len, mm_counts, ).prompt_text diff --git a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py index 06dfebbb95527..e273c4cbf2ea2 100644 --- a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py +++ b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py @@ -2,13 +2,17 @@ import torch -from vllm.model_executor.models.llava import (LlavaForConditionalGeneration, - LlavaMultiModalProcessor) +from vllm.model_executor.models.llava import (LlavaDummyDataBuilder, + LlavaForConditionalGeneration, + LlavaMultiModalProcessor, + LlavaProcessingInfo) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -@MULTIMODAL_REGISTRY.register_processor(LlavaMultiModalProcessor) +@MULTIMODAL_REGISTRY.register_processor(LlavaMultiModalProcessor, + info=LlavaProcessingInfo, + dummy_data=LlavaDummyDataBuilder) class MyLlava(LlavaForConditionalGeneration): def compute_logits( diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index b362ee0cac328..6ddc1eb76f10d 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -7,7 +7,7 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry -from 
vllm.multimodal.processing import MultiModalDataDict, MultiModalInputsV2 +from vllm.multimodal.inputs import MultiModalDataDict, MultiModalInputsV2 from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup from vllm.utils import print_info_once, print_warning_once diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 2d9d024e03e80..3fa9a8c14842d 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -323,6 +323,7 @@ def dummy_data_for_profiling( # Avoid circular import from vllm.model_executor.model_loader import get_model_architecture from vllm.multimodal import MultiModalKwargs + from vllm.multimodal.profiler import MultiModalProfiler from vllm.multimodal.utils import cached_get_tokenizer if mm_registry.has_processor(model_config): @@ -331,7 +332,8 @@ def dummy_data_for_profiling( trust_remote_code=model_config.trust_remote_code, ) processor = mm_registry.create_processor(model_config, tokenizer) - dummy_data = processor.get_dummy_data(seq_len) + profiler = MultiModalProfiler(processor) + dummy_data = profiler.get_dummy_data(seq_len) else: model_cls, _ = get_model_architecture(model_config) if is_encoder_data: diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py index 2e649f10c0765..88cf73d109ee2 100644 --- a/vllm/model_executor/models/aria.py +++ b/vllm/model_executor/models/aria.py @@ -23,10 +23,10 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, NestedTensors) -from vllm.multimodal.processing import (BaseMultiModalProcessor, - MultiModalDataItems, ProcessingMixin, - PromptReplacement) -from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs +from vllm.multimodal.parse import MultiModalDataItems +from vllm.multimodal.processing import BaseProcessingInfo, PromptReplacement +from vllm.multimodal.processor import BaseMultiModalProcessor +from vllm.multimodal.profiling import BaseDummyDataBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.aria import (AriaMoELMConfig, AriaVisionConfig) @@ -445,33 +445,33 @@ def build_mm_projector(config: PretrainedConfig): ) -class AriaProcessingMixin(ProcessingMixin): +class AriaProcessingInfo(BaseProcessingInfo): - def _get_hf_config(self): + def get_hf_config(self): return self.ctx.get_hf_config() - def _get_vision_config(self) -> AriaVisionConfig: - return self._get_hf_config().vision_config - - def _get_num_image_tokens(self) -> int: - hf_config = self._get_hf_config() - return max(hf_config.projector_patch_to_query_dict.values()) - - -class AriaProfilingInfo(AriaProcessingMixin, BaseProfilingInfo): + def get_vision_config(self) -> AriaVisionConfig: + return self.get_hf_config().vision_config def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None} def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: - return {"image": self._get_num_image_tokens()} + return {"image": self.get_num_image_tokens()} + + def get_num_image_tokens(self) -> int: + hf_config = self.get_hf_config() + return max(hf_config.projector_patch_to_query_dict.values()) + + +class AriaDummyDataBuilder(BaseDummyDataBuilder[AriaProcessingInfo]): def get_dummy_processor_inputs( self, seq_len: int, mm_counts: Mapping[str, int], ) -> ProcessorInputs: - vision_config = self._get_vision_config() + vision_config = self.info.get_vision_config() max_image_size = 
vision_config.image_size num_images = mm_counts.get("image", 0) @@ -483,7 +483,7 @@ def get_dummy_processor_inputs( num_images=num_images) } - hf_processor = self._get_hf_processor() + hf_processor = self.info.get_hf_processor() image_token: str = hf_processor.image_token # type: ignore return ProcessorInputs( @@ -492,10 +492,7 @@ def get_dummy_processor_inputs( ) -class AriaMultiModalProcessor(AriaProcessingMixin, BaseMultiModalProcessor): - - def _get_profiling_info(self) -> BaseProfilingInfo: - return AriaProfilingInfo(self.ctx) +class AriaMultiModalProcessor(BaseMultiModalProcessor[AriaProcessingInfo]): def _get_mm_fields_config( self, @@ -513,10 +510,10 @@ def _get_prompt_replacements( hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, ) -> list[PromptReplacement]: - hf_config = self._get_hf_config() + hf_config = self.info.get_hf_config() image_token_id = hf_config.image_token_index - num_image_tokens = self._get_num_image_tokens() + num_image_tokens = self.info.get_num_image_tokens() return [ PromptReplacement( @@ -527,7 +524,9 @@ def _get_prompt_replacements( ] -@MULTIMODAL_REGISTRY.register_processor(AriaMultiModalProcessor) +@MULTIMODAL_REGISTRY.register_processor(AriaMultiModalProcessor, + info=AriaProcessingInfo, + dummy_data=AriaDummyDataBuilder) class AriaForConditionalGeneration(nn.Module, SupportsMultiModal): """ Aria model for conditional generation tasks. diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index fd45783f167b4..5db1af556ce92 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -17,10 +17,10 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalInputsV2, MultiModalKwargs, NestedTensors, PlaceholderRange) -from vllm.multimodal.processing import (BaseMultiModalProcessor, - MultiModalDataItems, ProcessingMixin, - PromptReplacement) -from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs +from vllm.multimodal.parse import MultiModalDataItems +from vllm.multimodal.processing import BaseProcessingInfo, PromptReplacement +from vllm.multimodal.processor import BaseMultiModalProcessor +from vllm.multimodal.profiling import BaseDummyDataBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors from .blip import BlipVisionModel @@ -397,30 +397,30 @@ def forward( return sequence_output -class Blip2ProcessingMixin(ProcessingMixin): +class Blip2ProcessingInfo(BaseProcessingInfo): - def _get_hf_config(self): + def get_hf_config(self): return self.ctx.get_hf_config(Blip2Config) - def _get_num_image_tokens(self) -> int: - hf_config = self._get_hf_config() - return hf_config.num_query_tokens - - -class Blip2ProfilingInfo(Blip2ProcessingMixin, BaseProfilingInfo): - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": 1} def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: - return {"image": self._get_num_image_tokens()} + return {"image": self.get_num_image_tokens()} + + def get_num_image_tokens(self) -> int: + hf_config = self.get_hf_config() + return hf_config.num_query_tokens + + +class Blip2DummyDataBuilder(BaseDummyDataBuilder[Blip2ProcessingInfo]): def get_dummy_processor_inputs( self, seq_len: int, mm_counts: Mapping[str, int], ) -> ProcessorInputs: - hf_config = self._get_hf_config() + hf_config = self.info.get_hf_config() vision_config = hf_config.vision_config max_image_size = vision_config.image_size @@ -439,10 +439,7 @@ def 
get_dummy_processor_inputs( ) -class Blip2MultiModalProcessor(Blip2ProcessingMixin, BaseMultiModalProcessor): - - def _get_profiling_info(self) -> BaseProfilingInfo: - return Blip2ProfilingInfo(self.ctx) +class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]): def _get_mm_fields_config( self, @@ -460,7 +457,7 @@ def _get_prompt_replacements( hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, ) -> list[PromptReplacement]: - num_image_tokens = self._get_num_image_tokens() + num_image_tokens = self.info.get_num_image_tokens() return [ PromptReplacement( @@ -491,7 +488,9 @@ def apply( return result -@MULTIMODAL_REGISTRY.register_processor(Blip2MultiModalProcessor) +@MULTIMODAL_REGISTRY.register_processor(Blip2MultiModalProcessor, + info=Blip2ProcessingInfo, + dummy_data=Blip2DummyDataBuilder) class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index 73ed73b61ebf9..29cb489d58a2e 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -30,10 +30,10 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalInputsV2, MultiModalKwargs, NestedTensors, PlaceholderRange) -from vllm.multimodal.processing import (BaseMultiModalProcessor, - MultiModalDataItems, ProcessingMixin, - PromptReplacement) -from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs +from vllm.multimodal.parse import MultiModalDataItems +from vllm.multimodal.processing import BaseProcessingInfo, PromptReplacement +from vllm.multimodal.processor import BaseMultiModalProcessor +from vllm.multimodal.profiling import BaseDummyDataBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors from vllm.utils import print_warning_once @@ -49,33 +49,33 @@ class ChameleonImagePixelInputs(TypedDict): """Shape: `(batch_size * num_images, num_channels, height, width)`""" -class ChameleonProcessingMixin(ProcessingMixin): +class ChameleonProcessingInfo(BaseProcessingInfo): - def _get_hf_config(self): + def get_hf_config(self): return self.ctx.get_hf_config(ChameleonConfig) - def _get_hf_processor(self): + def get_hf_processor(self): return self.ctx.get_hf_processor(ChameleonProcessor) - def _get_num_image_tokens(self) -> int: - processor = self._get_hf_processor() - return processor.image_seq_length - - -class ChameleonProfilingInfo(ChameleonProcessingMixin, BaseProfilingInfo): - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": 1} def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: - return {"image": self._get_num_image_tokens()} + return {"image": self.get_num_image_tokens()} + + def get_num_image_tokens(self) -> int: + processor = self.get_hf_processor() + return processor.image_seq_length + + +class ChameleonDummyDataBuilder(BaseDummyDataBuilder[ChameleonProcessingInfo]): def get_dummy_processor_inputs( self, seq_len: int, mm_counts: Mapping[str, int], ) -> ProcessorInputs: - config = self._get_hf_config() + config = self.info.get_hf_config() width = height = config.vq_config.resolution num_images = mm_counts.get("image", 0) @@ -93,11 +93,8 @@ def get_dummy_processor_inputs( ) -class ChameleonMultiModalProcessor(ChameleonProcessingMixin, - BaseMultiModalProcessor): - - def _get_profiling_info(self) -> BaseProfilingInfo: - return 
ChameleonProfilingInfo(self.ctx) +class ChameleonMultiModalProcessor( + BaseMultiModalProcessor[ChameleonProcessingInfo]): def _get_mm_fields_config( self, @@ -112,7 +109,7 @@ def _get_prompt_replacements( hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, ) -> list[PromptReplacement]: - processor = self._get_hf_processor(**hf_processor_mm_kwargs) + processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) return [ PromptReplacement( @@ -120,7 +117,7 @@ def _get_prompt_replacements( target="", replacement="".join([ processor.image_start_token, - processor.image_token * self._get_num_image_tokens(), + processor.image_token * self.info.get_num_image_tokens(), processor.image_end_token, ]), ) @@ -916,7 +913,9 @@ def forward( return hidden_states -@MULTIMODAL_REGISTRY.register_processor(ChameleonMultiModalProcessor) +@MULTIMODAL_REGISTRY.register_processor(ChameleonMultiModalProcessor, + info=ChameleonProcessingInfo, + dummy_data=ChameleonDummyDataBuilder) class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index c937fcb0978b9..972d47c1633c8 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -33,11 +33,11 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalInputsV2, MultiModalKwargs, NestedTensors, PlaceholderRange) -from vllm.multimodal.parse import ImageProcessorItems, ImageSize -from vllm.multimodal.processing import (BaseMultiModalProcessor, - MultiModalDataItems, ProcessingMixin, - PromptReplacement) -from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs +from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, + MultiModalDataItems) +from vllm.multimodal.processing import BaseProcessingInfo, PromptReplacement +from vllm.multimodal.processor import BaseMultiModalProcessor +from vllm.multimodal.profiling import BaseDummyDataBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors from .interfaces import SupportsMultiModal, SupportsPP @@ -64,24 +64,38 @@ class FuyuImagePatchInputs(TypedDict): """ -class FuyuProcessingMixin(ProcessingMixin): +class FuyuProcessingInfo(BaseProcessingInfo): - def _get_hf_config(self): + def get_hf_config(self): return self.ctx.get_hf_config(FuyuConfig) - def _get_hf_processor(self): + def get_hf_processor(self): return self.ctx.get_hf_processor(FuyuProcessor) - def _get_image_processor(self) -> FuyuImageProcessor: - return self._get_hf_processor().image_processor + def get_image_processor(self) -> FuyuImageProcessor: + return self.get_hf_processor().image_processor - def _get_image_feature_grid_size( + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": 1} + + def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: + target_width, target_height = self.get_image_size_with_most_features() + + max_ncols, max_nrows = self.get_image_feature_grid_size( + image_width=target_width, + image_height=target_height, + ) + max_image_tokens = (max_ncols + 1) * max_nrows + + return {"image": max_image_tokens} + + def get_image_feature_grid_size( self, *, image_width: int, image_height: int, ) -> tuple[int, int]: - image_processor = self._get_image_processor() + image_processor = self.get_image_processor() target_width = image_processor.size["width"] target_height = image_processor.size["height"] @@ -97,34 +111,21 @@ def _get_image_feature_grid_size( nrows = 
math.ceil(image_height / 30) return ncols, nrows - -class FuyuProfilingInfo(FuyuProcessingMixin, BaseProfilingInfo): - - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: - return {"image": 1} - - def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: - target_width, target_height = self._get_image_size_with_most_features() - - max_ncols, max_nrows = self._get_image_feature_grid_size( - image_width=target_width, - image_height=target_height, - ) - max_image_tokens = (max_ncols + 1) * max_nrows - - return {"image": max_image_tokens} - - def _get_image_size_with_most_features(self) -> ImageSize: - image_processor = self._get_image_processor() + def get_image_size_with_most_features(self) -> ImageSize: + image_processor = self.get_image_processor() return ImageSize(width=image_processor.size["width"], height=image_processor.size["height"]) + +class FuyuDummyDataBuilder(BaseDummyDataBuilder[FuyuProcessingInfo]): + def get_dummy_processor_inputs( self, seq_len: int, mm_counts: Mapping[str, int], ) -> ProcessorInputs: - target_width, target_height = self._get_image_size_with_most_features() + target_width, target_height = \ + self.info.get_image_size_with_most_features() num_images = mm_counts.get("image", 0) mm_data = { @@ -140,10 +141,7 @@ def get_dummy_processor_inputs( ) -class FuyuMultiModalProcessor(FuyuProcessingMixin, BaseMultiModalProcessor): - - def _get_profiling_info(self) -> BaseProfilingInfo: - return FuyuProfilingInfo(self.ctx) +class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]): def _call_hf_processor( self, @@ -156,7 +154,7 @@ def _call_hf_processor( # Avoid warning from HF logger for text-only input # Input_ids format: bos_token_id + prompt_token_ids + boa_token_id # Tokenizer won't add boa_token_id by default, we add it manually. 
- tokenizer = self._get_tokenizer() + tokenizer = self.info.get_tokenizer() boa_token_id: int = tokenizer.vocab["<0x04>"] # type: ignore prompt_ids = tokenizer.encode(prompt) + [boa_token_id] return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") @@ -196,10 +194,10 @@ def _get_prompt_replacements( hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, ) -> list[PromptReplacement]: - hf_config = self._get_hf_config() + hf_config = self.info.get_hf_config() bos_token_id = hf_config.bos_token_id - tokenizer = self._get_tokenizer() + tokenizer = self.info.get_tokenizer() eot_token_id = tokenizer.bos_token_id assert isinstance(eot_token_id, int) @@ -207,7 +205,7 @@ def get_replacement_fuyu(item_idx: int): images = mm_items.get_items("image", ImageProcessorItems) image_size = images.get_image_size(item_idx) - ncols, nrows = self._get_image_feature_grid_size( + ncols, nrows = self.info.get_image_feature_grid_size( image_width=image_size.width, image_height=image_size.height, ) @@ -244,7 +242,9 @@ def apply( return result -@MULTIMODAL_REGISTRY.register_processor(FuyuMultiModalProcessor) +@MULTIMODAL_REGISTRY.register_processor(FuyuMultiModalProcessor, + info=FuyuProcessingInfo, + dummy_data=FuyuDummyDataBuilder) class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 4299af8cd03a2..a2521c7e8514e 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -1,7 +1,7 @@ -from abc import ABC, abstractmethod +from abc import abstractmethod from functools import cached_property from typing import (Final, Iterable, List, Literal, Mapping, Optional, - Protocol, Set, Tuple, TypedDict, Union) + Protocol, Set, Tuple, TypedDict, TypeVar, Union) import torch import torch.nn as nn @@ -25,11 +25,11 @@ MultiModalInputsV2, MultiModalKwargs, NestedTensors) from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, - ImageSize) -from vllm.multimodal.processing import (BaseMultiModalProcessor, - MultiModalDataItems, ProcessingCache, - ProcessingMixin, PromptReplacement) -from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs + ImageSize, MultiModalDataItems) +from vllm.multimodal.processing import (BaseProcessingInfo, ProcessingCache, + PromptReplacement) +from vllm.multimodal.processor import BaseMultiModalProcessor +from vllm.multimodal.profiling import BaseDummyDataBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors from .clip import CLIPVisionModel @@ -105,36 +105,25 @@ class LlavaLikeProcessor(Protocol): image_token: Final[str] -class BaseLlavaProcessingMixin(ProcessingMixin, ABC): +class BaseLlavaProcessingInfo(BaseProcessingInfo): - def _get_hf_config(self) -> LlavaLikeConfig: + def get_hf_config(self) -> LlavaLikeConfig: return self.ctx.get_hf_config(LlavaConfig) - def _get_vision_encoder_info(self): - return get_vision_encoder_info(self._get_hf_config()) + def get_vision_encoder_info(self): + return get_vision_encoder_info(self.get_hf_config()) @abstractmethod - def _get_hf_processor(self) -> LlavaLikeProcessor: + def get_hf_processor(self) -> LlavaLikeProcessor: raise NotImplementedError - def _get_num_image_tokens( - self, - *, - image_width: int, - image_height: int, - ) -> int: - hf_config = self._get_hf_config() - vision_encoder_info = self._get_vision_encoder_info() + def get_supported_mm_limits(self) -> 
Mapping[str, Optional[int]]: + return {"image": None} - return self._apply_feature_select_strategy( - hf_config.vision_feature_select_strategy, - vision_encoder_info.get_num_image_tokens( - image_width=image_width, - image_height=image_height, - ), - ) + def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: + return {"image": self.get_max_image_tokens()} - def _apply_feature_select_strategy( + def apply_feature_select_strategy( self, strategy: str, encoder_num_image_tokens: int, @@ -147,28 +136,42 @@ def _apply_feature_select_strategy( msg = f"Unexpected feature select strategy: {strategy!r}" raise NotImplementedError(msg) + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + ) -> int: + hf_config = self.get_hf_config() + vision_encoder_info = self.get_vision_encoder_info() -class BaseLlavaProfilingInfo(BaseLlavaProcessingMixin, BaseProfilingInfo): - - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: - return {"image": None} - - def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: - return {"image": self._get_max_image_tokens()} + return self.apply_feature_select_strategy( + hf_config.vision_feature_select_strategy, + vision_encoder_info.get_num_image_tokens( + image_width=image_width, + image_height=image_height, + ), + ) - def _get_image_size_with_most_features(self) -> ImageSize: - vision_encoder_info = self._get_vision_encoder_info() + def get_image_size_with_most_features(self) -> ImageSize: + vision_encoder_info = self.get_vision_encoder_info() width = height = vision_encoder_info.get_image_size() return ImageSize(width=width, height=height) - def _get_max_image_tokens(self) -> int: - target_width, target_height = self._get_image_size_with_most_features() + def get_max_image_tokens(self) -> int: + target_width, target_height = self.get_image_size_with_most_features() - return self._get_num_image_tokens( + return self.get_num_image_tokens( image_width=target_width, image_height=target_height, ) + +_I = TypeVar("_I", bound=BaseLlavaProcessingInfo) + + +class LlavaDummyDataBuilder(BaseDummyDataBuilder[_I]): + def get_dummy_processor_inputs( self, seq_len: int, @@ -176,9 +179,10 @@ def get_dummy_processor_inputs( ) -> ProcessorInputs: num_images = mm_counts.get("image", 0) - processor = self._get_hf_processor() + processor = self.info.get_hf_processor() image_token = processor.image_token - target_width, target_height = self._get_image_size_with_most_features() + target_width, target_height = \ + self.info.get_image_size_with_most_features() mm_data = { "image": @@ -193,23 +197,13 @@ def get_dummy_processor_inputs( ) -class LlavaProcessingMixin(BaseLlavaProcessingMixin): +class LlavaProcessingInfo(BaseLlavaProcessingInfo): - def _get_hf_processor(self): + def get_hf_processor(self): return self.ctx.get_hf_processor(LlavaProcessor) -class LlavaProfilingInfo(LlavaProcessingMixin, BaseLlavaProfilingInfo): - pass - - -class BaseLlavaMultiModalProcessor(LlavaProcessingMixin, - BaseMultiModalProcessor): - - # Copied from BaseMultiModalProcessor - @abstractmethod - def _get_profiling_info(self) -> BaseProfilingInfo: - raise NotImplementedError +class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor[_I]): # Copied from BaseMultiModalProcessor @abstractmethod @@ -226,7 +220,7 @@ def _get_prompt_replacements( hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, ) -> list[PromptReplacement]: - hf_config = self._get_hf_config() + hf_config = self.info.get_hf_config() image_token_id = 
hf_config.image_token_index def get_replacement(item_idx: int): @@ -237,7 +231,7 @@ def get_replacement(item_idx: int): num_image_tokens = images.get_feature_size(item_idx) else: image_size = images.get_image_size(item_idx) - num_image_tokens = self._get_num_image_tokens( + num_image_tokens = self.info.get_num_image_tokens( image_width=image_size.width, image_height=image_size.height, ) @@ -253,10 +247,8 @@ def get_replacement(item_idx: int): ] -class LlavaMultiModalProcessor(BaseLlavaMultiModalProcessor): - - def _get_profiling_info(self) -> BaseProfilingInfo: - return LlavaProfilingInfo(self.ctx) +class LlavaMultiModalProcessor( + BaseLlavaMultiModalProcessor[LlavaProcessingInfo]): def _get_mm_fields_config( self, @@ -269,21 +261,14 @@ def _get_mm_fields_config( ) -class PixtralHFProcessingMixin(BaseLlavaProcessingMixin): +class PixtralHFProcessingInfo(BaseLlavaProcessingInfo): - def _get_hf_processor(self): + def get_hf_processor(self): return self.ctx.get_hf_processor(PixtralProcessor) -class PixtralHFProfilingInfo(PixtralHFProcessingMixin, BaseLlavaProfilingInfo): - pass - - -class PixtralHFMultiModalProcessor(PixtralHFProcessingMixin, - BaseMultiModalProcessor): - - def _get_profiling_info(self) -> BaseProfilingInfo: - return PixtralHFProfilingInfo(self.ctx) +class PixtralHFMultiModalProcessor( + BaseMultiModalProcessor[PixtralHFProcessingInfo]): def _call_hf_processor( self, @@ -328,10 +313,10 @@ def _get_prompt_replacements( hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, ) -> list[PromptReplacement]: - hf_config = self._get_hf_config() + hf_config = self.info.get_hf_config() image_token_id = hf_config.image_token_index - processor = self._get_hf_processor() + processor = self.info.get_hf_processor() image_token = processor.image_token image_break_token = processor.image_break_token image_end_token = processor.image_end_token @@ -363,26 +348,40 @@ def get_replacement(item_idx: int): ] +def _build_llava_or_pixtral_hf_info( + ctx: InputProcessingContext, ) -> BaseLlavaProcessingInfo: + hf_config = ctx.get_hf_config(LlavaConfig) + + if isinstance(hf_config.vision_config, PixtralVisionConfig): + return PixtralHFProcessingInfo(ctx) + + return LlavaProcessingInfo(ctx) + + def _build_llava_or_pixtral_hf_processor( - ctx: InputProcessingContext, + info: _I, + dummy_data_builder: BaseDummyDataBuilder[_I], *, cache: Optional[ProcessingCache] = None, enable_sanity_checks: bool = True, ) -> BaseMultiModalProcessor: - hf_config = ctx.get_hf_config(LlavaConfig) - - if isinstance(hf_config.vision_config, PixtralVisionConfig): + if isinstance(info, PixtralHFProcessingInfo): return PixtralHFMultiModalProcessor( - ctx, + info, + dummy_data_builder, # type: ignore + cache=cache, + enable_sanity_checks=enable_sanity_checks, + ) + + if isinstance(info, LlavaProcessingInfo): + return LlavaMultiModalProcessor( + info, + dummy_data_builder, # type: ignore cache=cache, enable_sanity_checks=enable_sanity_checks, ) - return LlavaMultiModalProcessor( - ctx, - cache=cache, - enable_sanity_checks=enable_sanity_checks, - ) + raise NotImplementedError(type(info)) def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int: @@ -460,7 +459,9 @@ def init_vision_tower_for_llava( raise NotImplementedError(msg) -@MULTIMODAL_REGISTRY.register_processor(_build_llava_or_pixtral_hf_processor) +@MULTIMODAL_REGISTRY.register_processor(_build_llava_or_pixtral_hf_processor, + info=_build_llava_or_pixtral_hf_info, + dummy_data=LlavaDummyDataBuilder) class LlavaForConditionalGeneration(nn.Module, 
SupportsMultiModal, SupportsPP): # BitandBytes specific attributes bitsandbytes_stacked_params_mapping = { @@ -721,11 +722,11 @@ def apply( mm_data: MultiModalDataDict, hf_processor_mm_kwargs: Mapping[str, object], ) -> MultiModalInputsV2: - hf_config = self._get_hf_config() + hf_config = self.info.get_hf_config() image_token_id = hf_config.image_token_index # Assume that it doesn't depend on the image size - num_image_tokens = self._get_num_image_tokens( + num_image_tokens = self.info.get_num_image_tokens( image_width=-1, image_height=-1, ) @@ -790,6 +791,8 @@ def get_replacement_mantis(item_idx: int): # To use this model, please use # `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` -@MULTIMODAL_REGISTRY.register_processor(MantisMultiModalProcessor) +@MULTIMODAL_REGISTRY.register_processor(MantisMultiModalProcessor, + info=LlavaProcessingInfo, + dummy_data=LlavaDummyDataBuilder) class MantisForConditionalGeneration(LlavaForConditionalGeneration): pass diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index 815456dac2a2f..6af8acd392e5e 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -1,6 +1,6 @@ from functools import cached_property from typing import (Final, Iterable, List, Literal, Mapping, Optional, - Protocol, Set, Tuple, TypedDict, Union) + Protocol, Set, Tuple, TypedDict, TypeVar, Union) import torch import torch.nn as nn @@ -16,13 +16,12 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import MultiModalFieldConfig, NestedTensors from vllm.multimodal.parse import ImageSize -from vllm.multimodal.profiling import BaseProfilingInfo from vllm.sequence import IntermediateTensors from .clip import CLIPVisionModel from .interfaces import SupportsMultiModal, SupportsPP -from .llava import (BaseLlavaMultiModalProcessor, BaseLlavaProcessingMixin, - BaseLlavaProfilingInfo, LlavaLikeConfig, +from .llava import (BaseLlavaMultiModalProcessor, BaseLlavaProcessingInfo, + LlavaDummyDataBuilder, LlavaLikeConfig, LlavaMultiModalProjector, init_vision_tower_for_llava) from .siglip import SiglipVisionModel from .utils import (AutoWeightsLoader, embed_multimodal, flatten_bn, @@ -65,25 +64,25 @@ class LlavaNextLikeConfig(LlavaLikeConfig, Protocol): image_grid_pinpoints: Final[list[list[int]]] -class LlavaNextProcessingMixin(BaseLlavaProcessingMixin): +class LlavaNextProcessingInfo(BaseLlavaProcessingInfo): - def _get_hf_config(self) -> LlavaNextLikeConfig: + def get_hf_config(self) -> LlavaNextLikeConfig: return self.ctx.get_hf_config(LlavaNextConfig) - def _get_hf_processor(self): + def get_hf_processor(self): return self.ctx.get_hf_processor(LlavaNextProcessor) # Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L113 - def _get_num_image_tokens( + def get_num_image_tokens( self, *, image_width: int, image_height: int, ) -> int: - hf_config = self._get_hf_config() - vision_encoder_info = self._get_vision_encoder_info() + hf_config = self.get_hf_config() + vision_encoder_info = self.get_vision_encoder_info() - base_feature_size = self._apply_feature_select_strategy( + base_feature_size = self.apply_feature_select_strategy( hf_config.vision_feature_select_strategy, vision_encoder_info.get_num_image_tokens( image_width=image_width, @@ -100,7 +99,7 @@ def _get_num_image_tokens( ( unpadded_feature_size, newline_feature_size, - ) = self._get_num_unpadded_features( + ) = 
self.get_num_unpadded_features( original_height=image_height, original_width=image_width, npatches=vision_encoder_info.get_patch_grid_length(), @@ -111,7 +110,7 @@ def _get_num_image_tokens( return unpadded_feature_size + newline_feature_size + base_feature_size # Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L86 - def _get_num_unpadded_features( + def get_num_unpadded_features( self, *, original_height: int, @@ -140,16 +139,13 @@ def _get_num_unpadded_features( return (unpadded_features, newline_features) - -class LlavaNextProfilingInfo(LlavaNextProcessingMixin, BaseLlavaProfilingInfo): - - def _get_image_size_with_most_features(self) -> ImageSize: - hf_config = self._get_hf_config() + def get_image_size_with_most_features(self) -> ImageSize: + hf_config = self.get_hf_config() largest_feature_size, largest_feature_pinpoint = 0, None for (height, width) in hf_config.image_grid_pinpoints: - feat_size = self._get_num_image_tokens(image_width=width, - image_height=height) + feat_size = self.get_num_image_tokens(image_width=width, + image_height=height) if feat_size > largest_feature_size: largest_feature_size = feat_size largest_feature_pinpoint = ImageSize(width=width, @@ -161,11 +157,10 @@ def _get_image_size_with_most_features(self) -> ImageSize: return largest_feature_pinpoint -class LlavaNextMultiModalProcessor(LlavaNextProcessingMixin, - BaseLlavaMultiModalProcessor): +_I = TypeVar("_I", bound=LlavaNextProcessingInfo) + - def _get_profiling_info(self) -> BaseProfilingInfo: - return LlavaNextProfilingInfo(self.ctx) +class BaseLlavaNextMultiModalProcessor(BaseLlavaMultiModalProcessor[_I]): def _get_mm_fields_config( self, @@ -179,7 +174,14 @@ def _get_mm_fields_config( ) -@MULTIMODAL_REGISTRY.register_processor(LlavaNextMultiModalProcessor) +class LlavaNextMultiModalProcessor( + BaseLlavaNextMultiModalProcessor[LlavaNextProcessingInfo]): + pass + + +@MULTIMODAL_REGISTRY.register_processor(LlavaNextMultiModalProcessor, + info=LlavaNextProcessingInfo, + dummy_data=LlavaDummyDataBuilder) class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py index 6e82cee1c95a4..881d71ed9ab5c 100644 --- a/vllm/model_executor/models/llava_next_video.py +++ b/vllm/model_executor/models/llava_next_video.py @@ -17,12 +17,11 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, NestedTensors) -from vllm.multimodal.parse import (ImageSize, VideoEmbeddingItems, - VideoProcessorItems) -from vllm.multimodal.processing import (BaseMultiModalProcessor, - MultiModalDataItems, ProcessingMixin, - PromptReplacement) -from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs +from vllm.multimodal.parse import (ImageSize, MultiModalDataItems, + VideoEmbeddingItems, VideoProcessorItems) +from vllm.multimodal.processing import BaseProcessingInfo, PromptReplacement +from vllm.multimodal.processor import BaseMultiModalProcessor +from vllm.multimodal.profiling import BaseDummyDataBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of @@ -47,77 +46,73 @@ class LlavaNextVideoPixelInputs(TypedDict): """ -class LlavaNextVideoProcessingMixin(ProcessingMixin): +class LlavaNextVideoProcessingInfo(BaseProcessingInfo): - def _get_hf_config(self): + def get_hf_config(self): return 
self.ctx.get_hf_config(LlavaNextVideoConfig) - def _get_vision_encoder_info(self): - return get_vision_encoder_info(self._get_hf_config()) + def get_vision_encoder_info(self): + return get_vision_encoder_info(self.get_hf_config()) - def _get_hf_processor(self): + def get_hf_processor(self): return self.ctx.get_hf_processor(LlavaNextVideoProcessor) - def _get_num_frame_tokens( + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"video": 1} + + def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: + target_width, target_height = self.get_image_size_with_most_features() + + max_video_tokens = self.get_num_video_tokens( + image_width=target_width, + image_height=target_height, + num_frames=self.get_max_num_frames(seq_len), + ) + + return {"video": max_video_tokens} + + def get_image_size_with_most_features(self) -> ImageSize: + vision_encoder_info = self.get_vision_encoder_info() + width = height = vision_encoder_info.get_image_size() + return ImageSize(width=width, height=height) + + def get_num_frame_tokens( self, *, image_width: int, image_height: int, ) -> int: - hf_config = self._get_hf_config() + hf_config = self.get_hf_config() spatial_pool_stride = hf_config.spatial_pool_stride - vision_encoder_info = self._get_vision_encoder_info() + vision_encoder_info = self.get_vision_encoder_info() patch_grid_length = vision_encoder_info.get_patch_grid_length() pooled_grid_length = math.ceil(patch_grid_length / spatial_pool_stride) return pooled_grid_length * pooled_grid_length - def _get_num_video_tokens( + def get_num_video_tokens( self, *, image_width: int, image_height: int, num_frames: int, ) -> int: - num_frame_tokens = self._get_num_frame_tokens( + num_frame_tokens = self.get_num_frame_tokens( image_width=image_width, image_height=image_height, ) return num_frame_tokens * num_frames - -class LlavaNextVideoProfilingInfo(LlavaNextVideoProcessingMixin, - BaseProfilingInfo): - - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: - return {"video": 1} - - def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: - target_width, target_height = self._get_image_size_with_most_features() - - max_video_tokens = self._get_num_video_tokens( - image_width=target_width, - image_height=target_height, - num_frames=self._get_dummy_num_frames(seq_len), - ) - - return {"video": max_video_tokens} - - def _get_image_size_with_most_features(self) -> ImageSize: - vision_encoder_info = self._get_vision_encoder_info() - width = height = vision_encoder_info.get_image_size() - return ImageSize(width=width, height=height) - - def _get_max_video_frames(self, max_tokens: int) -> int: - target_width, target_height = self._get_image_size_with_most_features() + def get_max_video_frames(self, max_tokens: int) -> int: + target_width, target_height = self.get_image_size_with_most_features() num_frames = 0 while True: next_num_frames = num_frames + 1 - next_max_tokens = self._get_num_video_tokens( + next_max_tokens = self.get_num_video_tokens( image_width=target_width, image_height=target_height, num_frames=next_num_frames, @@ -130,14 +125,18 @@ def _get_max_video_frames(self, max_tokens: int) -> int: return num_frames - def _get_dummy_num_frames(self, seq_len: int) -> int: + def get_max_num_frames(self, seq_len: int) -> int: mm_config = self.ctx.get_mm_config() max_videos = mm_config.limit_per_prompt.get("video", 1) - max_total_frames = self._get_max_video_frames(seq_len) + max_total_frames = self.get_max_video_frames(seq_len) return 
max(max_total_frames // max(max_videos, 1), 1) + +class LlavaNextVideoDummyDataBuilder( + BaseDummyDataBuilder[LlavaNextVideoProcessingInfo]): + def get_dummy_processor_inputs( self, seq_len: int, @@ -145,16 +144,17 @@ def get_dummy_processor_inputs( ) -> ProcessorInputs: num_videos = mm_counts.get("video", 0) - processor = self._get_hf_processor() + processor = self.info.get_hf_processor() video_token = processor.video_token - target_width, target_height = self._get_image_size_with_most_features() + target_width, target_height = \ + self.info.get_image_size_with_most_features() mm_data = { "video": self._get_dummy_videos( width=target_width, height=target_height, - num_frames=self._get_dummy_num_frames(seq_len), + num_frames=self.info.get_max_num_frames(seq_len), num_videos=num_videos, ) } @@ -165,11 +165,8 @@ def get_dummy_processor_inputs( ) -class LlavaNextVideoMultiModalProcessor(LlavaNextVideoProcessingMixin, - BaseMultiModalProcessor): - - def _get_profiling_info(self) -> BaseProfilingInfo: - return LlavaNextVideoProfilingInfo(self.ctx) +class LlavaNextVideoMultiModalProcessor( + BaseMultiModalProcessor[LlavaNextVideoProcessingInfo]): def _get_mm_fields_config( self, @@ -184,7 +181,7 @@ def _get_prompt_replacements( hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, ) -> list[PromptReplacement]: - hf_config = self._get_hf_config() + hf_config = self.info.get_hf_config() video_token_id = hf_config.video_token_index def get_replacement(item_idx: int): @@ -195,7 +192,7 @@ def get_replacement(item_idx: int): num_video_tokens = videos.get_feature_size(item_idx) else: image_size = videos.get_frame_size(item_idx) - num_video_tokens = self._get_num_video_tokens( + num_video_tokens = self.info.get_num_video_tokens( image_width=image_size.width, image_height=image_size.height, num_frames=videos.get_num_frames(item_idx), @@ -269,7 +266,11 @@ def forward(self, image_features: torch.Tensor) -> torch.Tensor: return hidden_states -@MULTIMODAL_REGISTRY.register_processor(LlavaNextVideoMultiModalProcessor) +@MULTIMODAL_REGISTRY.register_processor( + LlavaNextVideoMultiModalProcessor, + info=LlavaNextVideoProcessingInfo, + dummy_data=LlavaNextVideoDummyDataBuilder, +) class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py index b5e3edba1f01c..6622e3a150e64 100644 --- a/vllm/model_executor/models/llava_onevision.py +++ b/vllm/model_executor/models/llava_onevision.py @@ -17,19 +17,20 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors +from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, + NestedTensors) from vllm.multimodal.parse import (ImageSize, MultiModalDataItems, VideoEmbeddingItems, VideoProcessorItems) -from vllm.multimodal.processing import MultiModalFieldConfig, PromptReplacement -from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs +from vllm.multimodal.processing import PromptReplacement +from vllm.multimodal.profiling import ProcessorInputs from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of from .clip import CLIPVisionModel from .interfaces import SupportsMultiModal, SupportsPP -from .llava import BaseLlavaProfilingInfo, 
init_vision_tower_for_llava -from .llava_next import (LlavaNextLikeConfig, LlavaNextMultiModalProcessor, - LlavaNextProcessingMixin) +from .llava import LlavaDummyDataBuilder, init_vision_tower_for_llava +from .llava_next import (BaseLlavaNextMultiModalProcessor, LlavaNextLikeConfig, + LlavaNextProcessingInfo) from .siglip import SiglipVisionModel from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings) @@ -89,17 +90,26 @@ class LlavaOnevisionLikeConfig(LlavaNextLikeConfig, Protocol): video_token_index: Final[int] -class LlavaOnevisionProcessingMixin(LlavaNextProcessingMixin): +class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo): - def _get_hf_config(self) -> LlavaOnevisionLikeConfig: + def get_hf_config(self) -> LlavaOnevisionLikeConfig: return self.ctx.get_hf_config(LlavaOnevisionConfig) - def _get_hf_processor(self): + def get_hf_processor(self): return self.ctx.get_hf_processor(LlavaOnevisionProcessor) + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None, "video": None} + + def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: + return { + "image": self.get_max_image_tokens(), + "video": self.get_max_video_tokens(seq_len), + } + # Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L86 # with additional logic afterwards taken from LlavaOnevisionProcessor - def _get_num_unpadded_features( + def get_num_unpadded_features( self, *, original_height: int, @@ -135,72 +145,59 @@ def _get_num_unpadded_features( return (unpadded_features, newline_features) - def _get_num_frame_tokens( + def get_image_size_with_most_features(self) -> ImageSize: + hf_config = self.get_hf_config() + largest_feature_size, largest_feature_pinpoint = 0, None + for (height, width) in hf_config.image_grid_pinpoints: + feat_size = self.get_num_image_tokens(image_width=width, + image_height=height) + if feat_size > largest_feature_size: + largest_feature_size = feat_size + largest_feature_pinpoint = ImageSize(width=width, + height=height) + + if largest_feature_size == 0 or largest_feature_pinpoint is None: + raise ValueError("Cannot have a largest feature size of 0!") + + return largest_feature_pinpoint + + def get_num_frame_tokens( self, *, image_width: int, image_height: int, ) -> int: - hf_config = self._get_hf_config() + hf_config = self.get_hf_config() spatial_pool_stride = getattr(hf_config, "spatial_pool_stride", 2) - vision_encoder_info = self._get_vision_encoder_info() + vision_encoder_info = self.get_vision_encoder_info() patch_grid_length = vision_encoder_info.get_patch_grid_length() pooled_grid_length = math.ceil(patch_grid_length / spatial_pool_stride) return pooled_grid_length * pooled_grid_length - def _get_num_video_tokens( + def get_num_video_tokens( self, *, image_width: int, image_height: int, num_frames: int, ) -> int: - num_frame_tokens = self._get_num_frame_tokens( + num_frame_tokens = self.get_num_frame_tokens( image_width=image_width, image_height=image_height, ) return num_frame_tokens * num_frames + 1 # Newline token - -class LlavaOnevisionProfilingInfo(LlavaOnevisionProcessingMixin, - BaseLlavaProfilingInfo): - - def _get_image_size_with_most_features(self) -> ImageSize: - hf_config = self._get_hf_config() - largest_feature_size, largest_feature_pinpoint = 0, None - for (height, width) in hf_config.image_grid_pinpoints: - feat_size = self._get_num_image_tokens(image_width=width, - 
image_height=height) - if feat_size > largest_feature_size: - largest_feature_size = feat_size - largest_feature_pinpoint = ImageSize(width=width, - height=height) - - if largest_feature_size == 0 or largest_feature_pinpoint is None: - raise ValueError("Cannot have a largest feature size of 0!") - - return largest_feature_pinpoint - - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: - return {"image": None, "video": None} - - def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: - return { - "image": self._get_max_image_tokens(), - "video": self._get_max_video_tokens(seq_len), - } - - def _get_max_video_frames(self, max_tokens: int) -> int: - target_width, target_height = self._get_image_size_with_most_features() + def get_max_video_frames(self, max_tokens: int) -> int: + target_width, target_height = self.get_image_size_with_most_features() num_frames = 0 while True: next_num_frames = num_frames + 1 - next_max_tokens = self._get_num_video_tokens( + next_max_tokens = self.get_num_video_tokens( image_width=target_width, image_height=target_height, num_frames=next_num_frames, @@ -213,28 +210,32 @@ def _get_max_video_frames(self, max_tokens: int) -> int: return num_frames - def _get_dummy_num_frames(self, seq_len: int) -> int: + def get_max_num_frames(self, seq_len: int) -> int: mm_config = self.ctx.get_mm_config() max_images = mm_config.limit_per_prompt.get("image", 1) max_videos = mm_config.limit_per_prompt.get("video", 1) - max_image_tokens = self._get_max_image_tokens() * max_images - max_total_frames = self._get_max_video_frames(seq_len - - max_image_tokens) + max_image_tokens = self.get_max_image_tokens() * max_images + max_total_frames = self.get_max_video_frames(seq_len - + max_image_tokens) max_frames_per_video = min(max_total_frames // max(max_videos, 1), _MAX_FRAMES_PER_VIDEO) return max(max_frames_per_video, 1) - def _get_max_video_tokens(self, seq_len: int) -> int: - target_width, target_height = self._get_image_size_with_most_features() + def get_max_video_tokens(self, seq_len: int) -> int: + target_width, target_height = self.get_image_size_with_most_features() - return self._get_num_video_tokens( + return self.get_num_video_tokens( image_width=target_width, image_height=target_height, - num_frames=self._get_dummy_num_frames(seq_len), + num_frames=self.get_max_num_frames(seq_len), ) + +class LlavaOnevisionDummyDataBuilder( + LlavaDummyDataBuilder[LlavaOnevisionProcessingInfo]): + def get_dummy_processor_inputs( self, seq_len: int, @@ -243,10 +244,11 @@ def get_dummy_processor_inputs( num_images = mm_counts.get("image", 0) num_videos = mm_counts.get("video", 0) - processor = self._get_hf_processor() + processor = self.info.get_hf_processor() image_token = processor.image_token video_token = processor.video_token - target_width, target_height = self._get_image_size_with_most_features() + target_width, target_height = \ + self.info.get_image_size_with_most_features() mm_data = { "image": @@ -257,7 +259,7 @@ def get_dummy_processor_inputs( self._get_dummy_videos( width=target_width, height=target_height, - num_frames=self._get_dummy_num_frames(seq_len), + num_frames=self.info.get_max_num_frames(seq_len), num_videos=num_videos, ) } @@ -268,11 +270,8 @@ def get_dummy_processor_inputs( ) -class LlavaOnevisionMultiModalProcessor(LlavaOnevisionProcessingMixin, - LlavaNextMultiModalProcessor): - - def _get_profiling_info(self) -> BaseProfilingInfo: - return LlavaOnevisionProfilingInfo(self.ctx) +class LlavaOnevisionMultiModalProcessor( + 
BaseLlavaNextMultiModalProcessor[LlavaOnevisionProcessingInfo]): def _get_mm_fields_config( self, @@ -303,7 +302,7 @@ def _call_hf_processor( mm_kwargs=mm_kwargs, ) - processor = self._get_hf_processor() + processor = self.info.get_hf_processor() video_token = processor.video_token # LLaVA-OneVision processor doesn't support multiple videos @@ -345,7 +344,7 @@ def _get_prompt_replacements( out_mm_kwargs=out_mm_kwargs, ) - hf_config = self._get_hf_config() + hf_config = self.info.get_hf_config() video_token_id = hf_config.video_token_index def get_video_replacement(item_idx: int): @@ -356,7 +355,7 @@ def get_video_replacement(item_idx: int): num_video_tokens = videos.get_feature_size(item_idx) else: image_size = videos.get_frame_size(item_idx) - num_video_tokens = self._get_num_video_tokens( + num_video_tokens = self.info.get_num_video_tokens( image_width=image_size.width, image_height=image_size.height, num_frames=videos.get_num_frames(item_idx), @@ -393,7 +392,10 @@ def forward(self, image_features: torch.Tensor) -> torch.Tensor: return hidden_states -@MULTIMODAL_REGISTRY.register_processor(LlavaOnevisionMultiModalProcessor) +@MULTIMODAL_REGISTRY.register_processor( + LlavaOnevisionMultiModalProcessor, + info=LlavaOnevisionProcessingInfo, + dummy_data=LlavaOnevisionDummyDataBuilder) class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index c8418c14e5fdf..387fccecbf848 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -34,13 +34,12 @@ MultiModalInputsV2, MultiModalKwargs, NestedTensors, PlaceholderRange) from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, - ImageSize) -from vllm.multimodal.processing import (BaseMultiModalProcessor, - MultiModalDataItems, ProcessingMixin, - PromptReplacement, - _BoundPromptReplacement, - _PlaceholderInfo) -from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs + ImageSize, MultiModalDataItems) +from vllm.multimodal.processing import (BaseProcessingInfo, + BoundPromptReplacement, + PlaceholderInfo, PromptReplacement) +from vllm.multimodal.processor import BaseMultiModalProcessor +from vllm.multimodal.profiling import BaseDummyDataBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of @@ -302,9 +301,9 @@ def add_image_newline(self, image_features_hd): return image_features_hd_newline -class Phi3VProcessingMixin(ProcessingMixin): +class Phi3VProcessingInfo(BaseProcessingInfo): - def _get_hf_processor( + def get_hf_processor( self, *, num_crops: Optional[int] = None, @@ -314,39 +313,39 @@ def _get_hf_processor( return self.ctx.get_hf_processor() - def _get_num_image_tokens( - self, - *, - image_width: int, - image_height: int, - ) -> int: - processor = self._get_hf_processor() - - return processor.calc_num_image_tokens_from_image_size( # type: ignore - width=image_width, - height=image_height, - ) - - -class Phi3VProfilingInfo(Phi3VProcessingMixin, BaseProfilingInfo): - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None} def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: - target_width, target_height = self._get_image_size_with_most_features() + target_width, target_height = self.get_image_size_with_most_features() - max_image_tokens = self._get_num_image_tokens( + max_image_tokens = self.get_num_image_tokens( image_width=target_width, 
image_height=target_height, ) return {"image": max_image_tokens} - def _get_image_size_with_most_features(self) -> ImageSize: + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + ) -> int: + processor = self.get_hf_processor() + + return processor.calc_num_image_tokens_from_image_size( # type: ignore + width=image_width, + height=image_height, + ) + + def get_image_size_with_most_features(self) -> ImageSize: # Result in the max possible feature size (h:w = 16:1) return ImageSize(height=8000, width=50) + +class Phi3VDummyDataBuilder(BaseDummyDataBuilder[Phi3VProcessingInfo]): + def get_dummy_processor_inputs( self, seq_len: int, @@ -354,7 +353,8 @@ def get_dummy_processor_inputs( ) -> ProcessorInputs: num_images = mm_counts.get("image", 0) - target_width, target_height = self._get_image_size_with_most_features() + target_width, target_height = \ + self.info.get_image_size_with_most_features() mm_data = { "image": @@ -363,7 +363,7 @@ def get_dummy_processor_inputs( num_images=num_images) } - hf_processor = self._get_hf_processor() + hf_processor = self.info.get_hf_processor() image_tokens: list[str] = hf_processor.img_tokens # type: ignore return ProcessorInputs( @@ -372,10 +372,7 @@ def get_dummy_processor_inputs( ) -class Phi3VMultiModalProcessor(Phi3VProcessingMixin, BaseMultiModalProcessor): - - def _get_profiling_info(self) -> BaseProfilingInfo: - return Phi3VProfilingInfo(self.ctx) +class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]): def _call_hf_processor( self, @@ -416,10 +413,10 @@ def _get_prompt_replacements( hf_processor_mm_kwargs: Mapping[str, Any], out_mm_kwargs: MultiModalKwargs, ) -> list[PromptReplacement]: - hf_processor = self._get_hf_processor(**hf_processor_mm_kwargs) + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) image_tokens: list[str] = hf_processor.img_tokens # type: ignore - tokenizer = self._get_tokenizer() + tokenizer = self.info.get_tokenizer() bos_token_id = tokenizer.bos_token_id assert isinstance(bos_token_id, int) @@ -431,7 +428,7 @@ def get_replacement_phi3v(item_idx: int): num_image_tokens = images.get_feature_size(item_idx) else: image_size = images.get_image_size(item_idx) - num_image_tokens = self._get_num_image_tokens( + num_image_tokens = self.info.get_num_image_tokens( image_width=image_size.width, image_height=image_size.height, ) @@ -451,9 +448,9 @@ def get_replacement_phi3v(item_idx: int): def _apply_prompt_replacements( self, token_ids: list[int], - mm_prompt_repls: Mapping[str, Sequence[_BoundPromptReplacement]], + mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]], mm_item_counts: Mapping[str, int], - ) -> tuple[list[int], str, Mapping[str, list[_PlaceholderInfo]]]: + ) -> tuple[list[int], str, Mapping[str, list[PlaceholderInfo]]]: token_ids, text, placeholders = super()._apply_prompt_replacements( token_ids=token_ids, mm_prompt_repls=mm_prompt_repls, @@ -466,7 +463,7 @@ def _apply_prompt_replacements( token_ids = [token_ids[0], *token_ids[2:]] placeholders = { modality: [ - _PlaceholderInfo( + PlaceholderInfo( modality=p.modality, item_idx=p.item_idx, start_idx=p.start_idx - 1, @@ -499,7 +496,9 @@ def apply( return result -@MULTIMODAL_REGISTRY.register_processor(Phi3VMultiModalProcessor) +@MULTIMODAL_REGISTRY.register_processor(Phi3VMultiModalProcessor, + info=Phi3VProcessingInfo, + dummy_data=Phi3VDummyDataBuilder) class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ diff --git 
a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index 7012ddc66cd9c..aea7dc8dd8fea 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -38,11 +38,11 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, NestedTensors) -from vllm.multimodal.parse import AudioProcessorItems, MultiModalDataParser -from vllm.multimodal.processing import (BaseMultiModalProcessor, - MultiModalDataItems, ProcessingMixin, - PromptReplacement) -from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs +from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems, + MultiModalDataParser) +from vllm.multimodal.processing import BaseProcessingInfo, PromptReplacement +from vllm.multimodal.processor import BaseMultiModalProcessor +from vllm.multimodal.profiling import BaseDummyDataBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors from .interfaces import SupportsMultiModal, SupportsPP @@ -80,12 +80,12 @@ def _get_feat_extract_output_lengths(input_lengths: torch.Tensor): return feat_lengths, output_lengths -class Qwen2AudioProcessingMixin(ProcessingMixin): +class Qwen2AudioProcessingInfo(BaseProcessingInfo): - def _get_hf_config(self): + def get_hf_config(self): return self.ctx.get_hf_config(Qwen2AudioConfig) - def _get_hf_processor( + def get_hf_processor( self, *, # Ignored in initialization @@ -93,36 +93,37 @@ def _get_hf_processor( ) -> Qwen2AudioProcessor: return self.ctx.get_hf_processor(Qwen2AudioProcessor) - def _get_feature_extractor( + def get_feature_extractor( self, *, # Ignored in initialization sampling_rate: Optional[int] = None, ) -> WhisperFeatureExtractor: - hf_processor = self._get_hf_processor(sampling_rate=sampling_rate) + hf_processor = self.get_hf_processor(sampling_rate=sampling_rate) feature_extractor = hf_processor.feature_extractor # type: ignore assert isinstance(feature_extractor, WhisperFeatureExtractor) return feature_extractor - -class Qwen2AudioProfilingInfo(Qwen2AudioProcessingMixin, BaseProfilingInfo): - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"audio": None} def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: - hf_config = self._get_hf_config() + hf_config = self.get_hf_config() max_source_positions = hf_config.audio_config.max_source_positions max_output_lengths = (max_source_positions - 2) // 2 + 1 return {"audio": max_output_lengths} + +class Qwen2AudioDummyDataBuilder(BaseDummyDataBuilder[Qwen2AudioProcessingInfo] + ): + def get_dummy_processor_inputs( self, seq_len: int, mm_counts: Mapping[str, int], ) -> ProcessorInputs: - feature_extractor = self._get_feature_extractor() + feature_extractor = self.info.get_feature_extractor() sampling_rate = feature_extractor.sampling_rate audio_len = feature_extractor.chunk_length * sampling_rate @@ -139,14 +140,11 @@ def get_dummy_processor_inputs( ) -class Qwen2AudioMultiModalProcessor(Qwen2AudioProcessingMixin, - BaseMultiModalProcessor): - - def _get_profiling_info(self) -> BaseProfilingInfo: - return Qwen2AudioProfilingInfo(self.ctx) +class Qwen2AudioMultiModalProcessor( + BaseMultiModalProcessor[Qwen2AudioProcessingInfo]): def _get_data_parser(self) -> MultiModalDataParser: - feature_extractor = self._get_feature_extractor() + feature_extractor = self.info.get_feature_extractor() return MultiModalDataParser(target_sr=feature_extractor.sampling_rate) def 
_call_hf_processor( @@ -161,7 +159,7 @@ def _call_hf_processor( if audios: mm_data["audios"] = audios - feature_extractor = self._get_feature_extractor(**mm_kwargs) + feature_extractor = self.info.get_feature_extractor(**mm_kwargs) mm_kwargs = dict( **mm_kwargs, sampling_rate=feature_extractor.sampling_rate, @@ -194,7 +192,7 @@ def _get_prompt_replacements( hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, ) -> list[PromptReplacement]: - hf_config = self._get_hf_config() + hf_config = self.info.get_hf_config() placeholder = hf_config.audio_token_index feature_attention_mask = out_mm_kwargs.get("feature_attention_mask") @@ -234,10 +232,12 @@ def _always_apply_prompt_replacements(self) -> bool: # has already performed processing for multi-audio input when the input # audios are short (the corresponding placeholders may take up fewer # tokens than the number of audio items) - return not hasattr(self._get_hf_processor(), "audio_token") + return not hasattr(self.info.get_hf_processor(), "audio_token") -@MULTIMODAL_REGISTRY.register_processor(Qwen2AudioMultiModalProcessor) +@MULTIMODAL_REGISTRY.register_processor(Qwen2AudioMultiModalProcessor, + info=Qwen2AudioProcessingInfo, + dummy_data=Qwen2AudioDummyDataBuilder) class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index a5c2fb9e84df3..d60656d140bf8 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -57,11 +57,10 @@ MultiModalFieldConfig, MultiModalKwargs, NestedTensors, VideoItem) from vllm.multimodal.parse import (ImageSize, ModalityDataItems, - MultiModalDataParser) -from vllm.multimodal.processing import (BaseMultiModalProcessor, - MultiModalDataItems, ProcessingMixin, - PromptReplacement) -from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs + MultiModalDataItems, MultiModalDataParser) +from vllm.multimodal.processing import BaseProcessingInfo, PromptReplacement +from vllm.multimodal.processor import BaseMultiModalProcessor +from vllm.multimodal.profiling import BaseDummyDataBuilder, ProcessorInputs from vllm.platforms import _Backend from vllm.sequence import IntermediateTensors from vllm.transformers_utils.config import uses_mrope @@ -709,12 +708,12 @@ def _parse_video_data( return super()._parse_video_data(data) -class Qwen2VLProcessingMixin(ProcessingMixin): +class Qwen2VLProcessingInfo(BaseProcessingInfo): - def _get_hf_config(self): + def get_hf_config(self): return self.ctx.get_hf_config(Qwen2VLConfig) - def _get_hf_processor( + def get_hf_processor( self, *, min_pixels: Optional[int] = None, @@ -736,18 +735,27 @@ def _get_hf_processor( return hf_processor - def _get_image_processor( + def get_image_processor( self, *, min_pixels: Optional[int] = None, max_pixels: Optional[int] = None, ): - hf_processor = self._get_hf_processor(min_pixels=min_pixels, - max_pixels=max_pixels) + hf_processor = self.get_hf_processor(min_pixels=min_pixels, + max_pixels=max_pixels) image_processor = hf_processor.image_processor # type: ignore assert isinstance(image_processor, Qwen2VLImageProcessor) return image_processor + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None, "video": None} + + def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: + return { + "image": self.get_max_image_tokens(), + "video": self.get_max_video_tokens(seq_len), + } + def _get_vision_info( self, 
*, @@ -756,13 +764,13 @@ def _get_vision_info( num_frames: int = 1, do_resize: bool = True, ) -> tuple[ImageSize, int]: - hf_config = self._get_hf_config() + hf_config = self.get_hf_config() vision_config = hf_config.vision_config patch_size = vision_config.patch_size merge_size = vision_config.spatial_merge_size temporal_patch_size = vision_config.temporal_patch_size - image_processor = self._get_image_processor() + image_processor = self.get_image_processor() if do_resize: resized_height, resized_width = smart_resize( @@ -787,7 +795,7 @@ def _get_vision_info( return preprocessed_size, num_vision_tokens - def _get_num_image_tokens( + def get_num_image_tokens( self, *, image_width: int, @@ -799,7 +807,7 @@ def _get_num_image_tokens( ) return num_image_tokens - def _get_num_video_tokens( + def get_num_video_tokens( self, *, image_width: int, @@ -813,41 +821,29 @@ def _get_num_video_tokens( ) return num_video_tokens - -class Qwen2VLProfilingInfo(Qwen2VLProcessingMixin, BaseProfilingInfo): - - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: - return {"image": None, "video": None} - - def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: - return { - "image": self._get_max_image_tokens(), - "video": self._get_max_video_tokens(seq_len), - } - - def _get_image_size_with_most_features(self) -> ImageSize: + def get_image_size_with_most_features(self) -> ImageSize: max_image_size, _ = self._get_vision_info( image_width=9999999, image_height=9999999, ) return max_image_size - def _get_max_image_tokens(self) -> int: - target_width, target_height = self._get_image_size_with_most_features() + def get_max_image_tokens(self) -> int: + target_width, target_height = self.get_image_size_with_most_features() - return self._get_num_image_tokens( + return self.get_num_image_tokens( image_width=target_width, image_height=target_height, ) - def _get_max_video_frames(self, max_tokens: int) -> int: - target_width, target_height = self._get_image_size_with_most_features() + def get_max_video_frames(self, max_tokens: int) -> int: + target_width, target_height = self.get_image_size_with_most_features() num_frames = 0 while True: next_num_frames = num_frames + 1 - next_max_tokens = self._get_num_video_tokens( + next_max_tokens = self.get_num_video_tokens( image_width=target_width, image_height=target_height, num_frames=next_num_frames, @@ -860,14 +856,14 @@ def _get_max_video_frames(self, max_tokens: int) -> int: return num_frames - def _get_dummy_num_frames(self, seq_len: int) -> int: + def get_max_num_frames(self, seq_len: int) -> int: mm_config = self.ctx.get_mm_config() max_images = mm_config.limit_per_prompt.get("image", 1) max_videos = mm_config.limit_per_prompt.get("video", 1) - max_image_tokens = self._get_max_image_tokens() * max_images - max_total_frames = self._get_max_video_frames(seq_len - - max_image_tokens) + max_image_tokens = self.get_max_image_tokens() * max_images + max_total_frames = self.get_max_video_frames(seq_len - + max_image_tokens) num_frames = max(max_total_frames // max(max_videos, 1), 1) @@ -877,15 +873,18 @@ def _get_dummy_num_frames(self, seq_len: int) -> int: return num_frames - def _get_max_video_tokens(self, seq_len: int) -> int: - target_width, target_height = self._get_image_size_with_most_features() + def get_max_video_tokens(self, seq_len: int) -> int: + target_width, target_height = self.get_image_size_with_most_features() - return self._get_num_video_tokens( + return self.get_num_video_tokens( image_width=target_width, 
image_height=target_height, - num_frames=self._get_dummy_num_frames(seq_len), + num_frames=self.get_max_num_frames(seq_len), ) + +class Qwen2VLDummyDataBuilder(BaseDummyDataBuilder[Qwen2VLProcessingInfo]): + def get_dummy_processor_inputs( self, seq_len: int, @@ -894,10 +893,11 @@ def get_dummy_processor_inputs( num_images = mm_counts.get("image", 0) num_videos = mm_counts.get("video", 0) - hf_processor = self._get_hf_processor() + hf_processor = self.info.get_hf_processor() image_token: str = hf_processor.image_token video_token: str = hf_processor.video_token - target_width, target_height = self._get_image_size_with_most_features() + target_width, target_height = \ + self.info.get_image_size_with_most_features() mm_data = { "image": @@ -908,7 +908,7 @@ def get_dummy_processor_inputs( self._get_dummy_videos( width=target_width, height=target_height, - num_frames=self._get_dummy_num_frames(seq_len), + num_frames=self.info.get_max_num_frames(seq_len), num_videos=num_videos, ) } @@ -919,11 +919,8 @@ def get_dummy_processor_inputs( ) -class Qwen2VLMultiModalProcessor(Qwen2VLProcessingMixin, - BaseMultiModalProcessor): - - def _get_profiling_info(self) -> BaseProfilingInfo: - return Qwen2VLProfilingInfo(self.ctx) +class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo] + ): def _get_data_parser(self) -> MultiModalDataParser: return Qwen2MultiModalDataParser() @@ -934,8 +931,9 @@ def _get_prompt_replacements( hf_processor_mm_kwargs: Mapping[str, Any], out_mm_kwargs: MultiModalKwargs, ) -> list[PromptReplacement]: - hf_processor = self._get_hf_processor(**hf_processor_mm_kwargs) - image_processor = self._get_image_processor(**hf_processor_mm_kwargs) + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + image_processor = self.info.get_image_processor( + **hf_processor_mm_kwargs) # NOTE: Only Qwen2VLProcessor in transformers 4.47.0 has # image_token and video_token registered @@ -991,7 +989,9 @@ def _get_mm_fields_config( ) -@MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor) +@MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor, + info=Qwen2VLProcessingInfo, + dummy_data=Qwen2VLDummyDataBuilder) class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP): packed_modules_mapping = { diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index ecafd157b1d61..1a403941f803f 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -24,11 +24,10 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, NestedTensors) -from vllm.multimodal.parse import MultiModalDataParser -from vllm.multimodal.processing import (BaseMultiModalProcessor, - MultiModalDataItems, ProcessingMixin, - PromptReplacement) -from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs +from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser +from vllm.multimodal.processing import BaseProcessingInfo, PromptReplacement +from vllm.multimodal.processor import BaseMultiModalProcessor +from vllm.multimodal.profiling import BaseDummyDataBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.ultravox import UltravoxConfig @@ -59,9 +58,9 @@ class UltravoxAudioEmbeddingInputs(TypedDict): UltravoxAudioEmbeddingInputs] -class UltravoxProcessingMixin(ProcessingMixin): +class 
UltravoxProcessingInfo(BaseProcessingInfo): - def _get_hf_processor( + def get_hf_processor( self, *, # Ignored in initialization @@ -76,37 +75,37 @@ def _get_hf_processor( hf_processor.audio_token_replacement = _AUDIO_PLACEHOLDER_OVERRIDE return hf_processor - def _get_feature_extractor( + def get_feature_extractor( self, *, # Ignored in initialization sampling_rate: Optional[int] = None, ) -> WhisperFeatureExtractor: - hf_processor = self._get_hf_processor(sampling_rate=sampling_rate) + hf_processor = self.get_hf_processor(sampling_rate=sampling_rate) audio_processor = hf_processor.audio_processor # type: ignore feature_extractor = audio_processor.feature_extractor # type: ignore assert isinstance(feature_extractor, WhisperFeatureExtractor) return feature_extractor - -class UltravoxProfilingInfo(UltravoxProcessingMixin, BaseProfilingInfo): - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"audio": None} def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: - feature_extractor = self._get_feature_extractor() + feature_extractor = self.get_feature_extractor() max_audio_tokens = math.ceil(feature_extractor.chunk_length * _AUDIO_TOKENS_PER_SECOND) return {"audio": max_audio_tokens} + +class UltravoxDummyDataBuilder(BaseDummyDataBuilder[UltravoxProcessingInfo]): + def get_dummy_processor_inputs( self, seq_len: int, mm_counts: Mapping[str, int], ) -> ProcessorInputs: - feature_extractor = self._get_feature_extractor() + feature_extractor = self.info.get_feature_extractor() sampling_rate = feature_extractor.sampling_rate audio_len = feature_extractor.chunk_length * sampling_rate @@ -123,14 +122,11 @@ def get_dummy_processor_inputs( ) -class UltravoxMultiModalProcessor(UltravoxProcessingMixin, - BaseMultiModalProcessor): - - def _get_profiling_info(self) -> BaseProfilingInfo: - return UltravoxProfilingInfo(self.ctx) +class UltravoxMultiModalProcessor( + BaseMultiModalProcessor[UltravoxProcessingInfo]): def _get_data_parser(self) -> MultiModalDataParser: - feature_extractor = self._get_feature_extractor() + feature_extractor = self.info.get_feature_extractor() return MultiModalDataParser(target_sr=feature_extractor.sampling_rate) def _call_hf_processor( @@ -141,7 +137,7 @@ def _call_hf_processor( ) -> BatchFeature: # Text-only input not supported in composite processor if not mm_data: - tokenizer = self._get_tokenizer() + tokenizer = self.info.get_tokenizer() prompt_ids = tokenizer.encode( prompt, @@ -160,7 +156,7 @@ def _call_hf_processor( mm_kwargs=mm_kwargs, ) - feature_extractor = self._get_feature_extractor() + feature_extractor = self.info.get_feature_extractor() mm_kwargs = dict( **mm_kwargs, sampling_rate=feature_extractor.sampling_rate, @@ -208,7 +204,7 @@ def _get_prompt_replacements( hf_processor_mm_kwargs: Mapping[str, Any], out_mm_kwargs: MultiModalKwargs, ) -> list[PromptReplacement]: - hf_processor = self._get_hf_processor(**hf_processor_mm_kwargs) + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) placeholder = hf_processor.audio_token_replacement # type: ignore def get_replacement_ultravox(item_idx: int): @@ -342,7 +338,9 @@ def forward( return hidden_states -@MULTIMODAL_REGISTRY.register_processor(UltravoxMultiModalProcessor) +@MULTIMODAL_REGISTRY.register_processor(UltravoxMultiModalProcessor, + info=UltravoxProcessingInfo, + dummy_data=UltravoxDummyDataBuilder) class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP): hf_to_vllm_mapper = WeightsMapper( diff --git a/vllm/multimodal/processing.py 
b/vllm/multimodal/processing.py index 41113cd85bd16..5571f9fbc61b7 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -4,23 +4,18 @@ from collections.abc import Callable, ItemsView, Iterable, Mapping, Sequence from dataclasses import dataclass, field from functools import lru_cache -from typing import Any, NamedTuple, Optional, Protocol, TypeVar, Union +from typing import NamedTuple, Optional, Protocol, TypeVar, Union -from transformers import BatchFeature, PretrainedConfig, ProcessorMixin +from transformers import PretrainedConfig, ProcessorMixin -from vllm import envs -from vllm.inputs import DummyData, InputProcessingContext +from vllm.inputs import InputProcessingContext from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens, encode_tokens) from vllm.utils import LRUCache, flatten_2d_lists, full_groupby from .hasher import MultiModalHasher -from .inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalInputsV2, MultiModalKwargs, - MultiModalKwargsItem, PlaceholderRange) -from .parse import MultiModalDataItems, MultiModalDataParser -from .profiling import BaseProfilingInfo +from .inputs import MultiModalKwargsItem, PlaceholderRange logger = init_logger(__name__) @@ -46,8 +41,8 @@ class PromptReplacement: if it does not depend on the input. """ - def bind(self, tokenizer: AnyTokenizer) -> "_BoundPromptReplacement": - return _BoundPromptReplacement( + def bind(self, tokenizer: AnyTokenizer) -> "BoundPromptReplacement": + return BoundPromptReplacement( tokenizer=tokenizer, modality=self.modality, _target=self.target, @@ -128,7 +123,7 @@ def token_ids(self) -> list[int]: @dataclass -class _BoundPromptReplacement: +class BoundPromptReplacement: tokenizer: AnyTokenizer = field(repr=False) modality: str @@ -207,7 +202,7 @@ def iter_token_matches( @dataclass(repr=False) class _PromptReplacementMatch(ABC): - prompt_repl: _BoundPromptReplacement + prompt_repl: BoundPromptReplacement @property def modality(self) -> str: @@ -255,7 +250,7 @@ def end_idx(self) -> int: @dataclass -class _PlaceholderInfo: +class PlaceholderInfo: modality: str item_idx: int start_idx: int @@ -274,7 +269,7 @@ def to_range(self) -> PlaceholderRange: def find_token_matches( prompt: list[int], - prompt_repls: Sequence[_BoundPromptReplacement], + prompt_repls: Sequence[BoundPromptReplacement], ) -> list[_PromptReplacementTokenMatch]: """Return each target of :code:`prompt_repls` found in :code:`prompt`.""" return [ @@ -286,7 +281,7 @@ def find_token_matches( def find_text_matches( prompt: str, - prompt_repls: Sequence[_BoundPromptReplacement], + prompt_repls: Sequence[BoundPromptReplacement], ) -> list[_PromptReplacementTextMatch]: """Return each target of :code:`prompt_repls` found in :code:`prompt`.""" return [ @@ -390,9 +385,9 @@ def replace_text_matches( def _iter_modality_placeholders( prompt: list[int], modality: str, - modality_repls: Sequence[_BoundPromptReplacement], + modality_repls: Sequence[BoundPromptReplacement], modal_item_count: int, -) -> Iterable[_PlaceholderInfo]: +) -> Iterable[PlaceholderInfo]: if modal_item_count == 0: return @@ -413,7 +408,7 @@ def _iter_modality_placeholders( continue if prompt[start_idx:end_idx] == repl_tokens: - yield _PlaceholderInfo( + yield PlaceholderInfo( modality=modality, item_idx=item_idx, start_idx=start_idx, @@ -434,10 +429,10 @@ def _iter_modality_placeholders( def _iter_placeholders( - mm_prompt_repls: Mapping[str, Sequence[_BoundPromptReplacement]], + mm_prompt_repls: 
Mapping[str, Sequence[BoundPromptReplacement]], prompt: list[int], mm_item_counts: Mapping[str, int], -) -> Iterable[_PlaceholderInfo]: +) -> Iterable[PlaceholderInfo]: """ For each modality, yield each set of placeholder tokens found in :code:`prompt`. @@ -455,10 +450,10 @@ def _iter_placeholders( def find_mm_placeholders( - mm_prompt_repls: Mapping[str, Sequence[_BoundPromptReplacement]], + mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]], prompt: list[int], mm_item_counts: Mapping[str, int], -) -> Mapping[str, list[_PlaceholderInfo]]: +) -> Mapping[str, list[PlaceholderInfo]]: it = _iter_placeholders(mm_prompt_repls, prompt, mm_item_counts) return dict(full_groupby_modality(it)) @@ -524,645 +519,50 @@ def put( self._cache.put(cache_key, output_kwargs) -class ProcessingMixin: - """ - Contains helper functions to perform processing. +class BaseProcessingInfo: + """Base class containing information to perform processing.""" - Not to be confused with :class:`transformers.ProcessorMixin`. - """ - ctx: InputProcessingContext + def __init__(self, ctx: InputProcessingContext) -> None: + super().__init__() + + self.ctx = ctx + + @property + def model_id(self) -> str: + return self.ctx.model_config.model - def _get_tokenizer(self) -> AnyTokenizer: + def get_tokenizer(self) -> AnyTokenizer: return self.ctx.tokenizer - def _get_hf_config(self) -> PretrainedConfig: + def get_hf_config(self) -> PretrainedConfig: return self.ctx.get_hf_config() - def _get_hf_processor(self, **kwargs: object) -> ProcessorMixin: + def get_hf_processor(self, **kwargs: object) -> ProcessorMixin: """ Subclasses can override this method to handle specific kwargs from model config or user inputs. """ return self.ctx.get_hf_processor(**kwargs) - -class BaseMultiModalProcessor(ProcessingMixin, ABC): - """ - Abstract base class to process multi-modal inputs to be used in vLLM. - - Not to be confused with :class:`transformers.ProcessorMixin`. - """ - - def __init__(self, - ctx: InputProcessingContext, - *, - cache: Optional[ProcessingCache] = None, - enable_sanity_checks: bool = True) -> None: - super().__init__() - - self.ctx = ctx - self.cache = cache - self.enable_sanity_checks = enable_sanity_checks - - self.data_parser = self._get_data_parser() - self.profiling_info = self._get_profiling_info() - - def __call__( - self, - prompt: str, - mm_data: MultiModalDataDict, - hf_processor_mm_kwargs: Mapping[str, object], - ) -> MultiModalInputsV2: - return self.apply(prompt, mm_data, hf_processor_mm_kwargs) - - def _get_data_parser(self) -> MultiModalDataParser: + @abstractmethod + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: """ - Construct a parser to preprocess multi-modal data items - before passing them to :meth:`_get_hf_mm_data`. + Return the maximum supported number of items for each modality. - You can support additional modalities by creating a subclass - of :class:`MultiModalDataParser` that has additional subparsers. - """ - return MultiModalDataParser() + A value of `None` means unlimited number of items. - def _get_profiling_info(self) -> BaseProfilingInfo: - """ - Get the profiling information to find the worst-case memory usage of - the model. + Omitting a modality from the returned dictionary means that + it is not supported at all. """ raise NotImplementedError - def _to_mm_items( - self, - mm_data: MultiModalDataDict, - ) -> MultiModalDataItems: - """ - Normalize :class:`MultiModalDataDict` to :class:`MultiModalDataItems` - before passing them to :meth:`_get_hf_mm_data`. 
- """ - mm_items = self.data_parser.parse_mm_data(mm_data) - - mm_limits = self.ctx.get_mm_config().limit_per_prompt - for modality, items in mm_items.items(): - limit = mm_limits.get(modality, 1) - if len(items) > limit: - raise ValueError( - f"You set {modality}={limit} (or defaulted to 1) in " - f"`--limit-mm-per-prompt`, but passed {len(items)} " - f"{modality} items in the same prompt.") - - return mm_items - - @abstractmethod - def _get_mm_fields_config( - self, - hf_inputs: BatchFeature, - hf_processor_mm_kwargs: Mapping[str, object], - ) -> Mapping[str, MultiModalFieldConfig]: - """Given the HF-processed data, output the metadata of each field.""" - raise NotImplementedError - @abstractmethod - def _get_prompt_replacements( - self, - mm_items: MultiModalDataItems, - hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, - ) -> list[PromptReplacement]: - """ - Given the original multi-modal items for this modality - and HF-processed data, output the replacements to perform. - - Notes: - - You should not assume that HF processor always performs prompt - replacement: in :meth:`_apply_hf_processor_missing`, this method - is called on text-only and multimodal-only inputs separately, - instead of passing them in the same call. - - The replacement information returned by this method is also used - to determine the placeholder token positions for each multi-modal - item. - """ - raise NotImplementedError - - def _find_mm_placeholders( - self, - mm_prompt_repls: Mapping[str, Sequence[_BoundPromptReplacement]], - new_token_ids: list[int], - mm_item_counts: Mapping[str, int], - ) -> Mapping[str, list[_PlaceholderInfo]]: - return find_mm_placeholders(mm_prompt_repls, new_token_ids, - mm_item_counts) - - def _get_hf_mm_data( - self, - mm_items: MultiModalDataItems, - ) -> tuple[dict[str, Any], dict[str, Any]]: - processor_data = dict[str, Any]() - passthrough_data = dict[str, Any]() - - for items in mm_items.values(): - processor_data.update(items.get_processor_data()) - passthrough_data.update(items.get_passthrough_data()) - - return processor_data, passthrough_data - - def _call_hf_processor( - self, - prompt: str, - # Not to be confused with `mm_data` in `self.apply`. - # This refers to the data to be passed to HF processor. - mm_data: Mapping[str, object], - mm_kwargs: Mapping[str, object], - ) -> BatchFeature: - """ - Call the HF processor on the prompt text and - associated multi-modal data. - """ - return self.ctx.call_hf_processor( - self._get_hf_processor(**mm_kwargs), - dict(text=prompt, **mm_data), - mm_kwargs, - ) - - def _apply_hf_processor( - self, - prompt_text: str, - mm_items: MultiModalDataItems, - hf_processor_mm_kwargs: Mapping[str, object], - ) -> tuple[list[int], MultiModalKwargs]: - """ - Wrapper of :meth:`_call_hf_processor` that applies - additional pre-processing and post-processing. 
- """ - processor_data, passthrough_data = self._get_hf_mm_data(mm_items) - - processed_data = self._call_hf_processor( - prompt=prompt_text, - mm_data=processor_data, - mm_kwargs=hf_processor_mm_kwargs, - ) - processed_data.update(passthrough_data) - - prompt_ids, = processed_data.pop("input_ids").tolist() - - mm_kwargs = MultiModalKwargs.from_hf_inputs( - processed_data, - self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs), - ) - - return prompt_ids, mm_kwargs - - def _apply_hf_processor_missing( - self, - prompt_text: str, - mm_missing_data_items: MultiModalDataItems, - hf_processor_mm_kwargs: Mapping[str, object], - ): - """ - Apply the HF processor on the full prompt text, but only on the - multi-modal data that are missing from the cache. - - Note: - We pass prompt text and multi-modal data into the HF processor - in separate calls to avoid HF prompt replacement being done for - cached items; instead, we rely on our own prompt replacement logic - (:meth:`_get_prompt_replacements`) for the full text. - """ - mm_missing_counts = mm_missing_data_items.get_all_counts() - - prompt_ids, _ = self._apply_hf_processor( - prompt_text=prompt_text, - mm_items=MultiModalDataItems({}), - hf_processor_mm_kwargs={}, - ) - - # Some HF processors (e.g. Qwen2-VL) expect corresponding - # multi-modal tokens to be in the prompt text - dummy_inputs = self.profiling_info.get_dummy_processor_inputs( - self.ctx.model_config.max_model_len, - mm_missing_counts, - ) - - _, mm_missing_kwargs = self._apply_hf_processor( - prompt_text=dummy_inputs.prompt_text, - mm_items=mm_missing_data_items, - hf_processor_mm_kwargs=hf_processor_mm_kwargs, - ) - - return prompt_ids, mm_missing_kwargs - - def _cached_apply_hf_processor( - self, - prompt_text: str, - mm_data_items: MultiModalDataItems, - hf_processor_mm_kwargs: Mapping[str, object], - ) -> tuple[list[int], MultiModalKwargs]: - """ - Apply the HF processor on the full prompt text, - caching the results and reusing cached results. 
- """ - cache = self.cache - model_id = self.ctx.model_config.model - - _, passthrough_data = self._get_hf_mm_data(mm_data_items) - if cache is None or passthrough_data: - return self._apply_hf_processor( - prompt_text=prompt_text, - mm_items=mm_data_items, - hf_processor_mm_kwargs=hf_processor_mm_kwargs, - ) - - mm_maybe_cached_kw_items = { - modality: [ - cache.get(model_id, modality, item, hf_processor_mm_kwargs) - for item in items - ] - for modality, items in mm_data_items.items() - } - - mm_missing_idxs = { - modality: - [idx for idx, item in enumerate(kw_items) if item is None] - for modality, kw_items in mm_maybe_cached_kw_items.items() - } - mm_missing_data = { - modality: [mm_data_items[modality][idx] for idx in idxs] - for modality, idxs in mm_missing_idxs.items() - } - mm_missing_data_items = self._to_mm_items(mm_missing_data) - - prompt_ids, mm_missing_kwargs = self._apply_hf_processor_missing( - prompt_text=prompt_text, - mm_missing_data_items=mm_missing_data_items, - hf_processor_mm_kwargs=hf_processor_mm_kwargs, - ) - - mm_missing_next_idx = { - modality: 0 - for modality in mm_missing_data_items - } - - merged_kw_items = list[MultiModalKwargsItem]() - for modality, kw_items in mm_maybe_cached_kw_items.items(): - for idx, kw_item in enumerate(kw_items): - if kw_item is None: - kw_item = mm_missing_kwargs.get_item( - modality, - mm_missing_next_idx[modality], - ) - - cache.put( - model_id, - modality, - mm_data_items[modality][idx], - hf_processor_mm_kwargs, - kw_item, - ) - - mm_missing_next_idx[modality] += 1 - - merged_kw_items.append(kw_item) - - if self.enable_sanity_checks: - mm_missing_counts = mm_missing_data_items.get_all_counts() - assert all( - item_count == mm_missing_counts[modality] - for modality, item_count in mm_missing_next_idx.items()), dict( - mm_missing_next_idx=mm_missing_next_idx, - mm_missing_counts=mm_missing_counts) - - mm_kwargs = MultiModalKwargs.from_items(merged_kw_items) - - return prompt_ids, mm_kwargs - - def _bind_and_group_repls( - self, - prompt_repls: list[PromptReplacement], - ) -> dict[str, list[_BoundPromptReplacement]]: - tokenizer = self._get_tokenizer() - - it = (prompt_repl.bind(tokenizer) for prompt_repl in prompt_repls) - return dict(full_groupby_modality(it)) - - def _always_apply_prompt_replacements(self) -> bool: - """ - A flag which can be overridden so that - :meth:`_apply_prompt_replacements` is always called even if we - detect that HF has performed processing via - :meth:`_find_placeholders_by_modality`. - - This is useful in cases where :meth:`_find_placeholders_by_modality` - cannot be reliably used to detect whether HF has performed processing. - """ - return False - - def _apply_prompt_replacements( - self, - token_ids: list[int], - mm_prompt_repls: Mapping[str, Sequence[_BoundPromptReplacement]], - mm_item_counts: Mapping[str, int], - ) -> tuple[list[int], str, Mapping[str, list[_PlaceholderInfo]]]: - tokenizer = self._get_tokenizer() - - mm_token_matches = { - modality: find_token_matches(token_ids, prompt_repls) - for modality, prompt_repls in mm_prompt_repls.items() - } - mm_match_counts = { - modality: len(matches) - for modality, matches in mm_token_matches.items() - } - - # If the search text does not represent a special token, - # it may have different token IDs in the prompt, because - # the tokens may go across the boundaries of the search text. - # ---- - # e.g. 
when searching for "foo" in "food", if "food" itself makes - # up a token, then the token ID of "foo" will not appear at all - # ---- - # Since it is inefficient to search for all possible tokenizations - # of the search text in the prompt, we instead perform string - # replacement on the decoded token IDs, then encode them back. - if all( - mm_match_counts.get(modality, 0) >= item_count - for modality, item_count in mm_item_counts.items() - ): # yapf: disable - token_ids = replace_token_matches( - token_ids, - mm_token_matches, - mm_item_counts, - ) - - text = decode_tokens(tokenizer, token_ids) - matched_repls = { - modality: [match.prompt_repl for match in token_matches] - for modality, token_matches in mm_token_matches.items() - } - else: - text = decode_tokens(tokenizer, token_ids) - - mm_text_matches = { - modality: find_text_matches(text, prompt_repls) - for modality, prompt_repls in mm_prompt_repls.items() - } - text = replace_text_matches( - text, - mm_text_matches, - mm_item_counts, - ) - - token_ids = encode_tokens(tokenizer, - text, - add_special_tokens=False) - matched_repls = { - modality: [match.prompt_repl for match in token_matches] - for modality, token_matches in mm_text_matches.items() - } - - placeholders = self._find_mm_placeholders( - matched_repls, - token_ids, - mm_item_counts, - ) - - return token_ids, text, placeholders - - def _validate_mm_kwargs( - self, - mm_kwargs: MultiModalKwargs, - mm_item_counts: Mapping[str, int], - ) -> None: - for modality, item_count in mm_item_counts.items(): - if modality in mm_kwargs.modalities: - items = mm_kwargs.get_items(modality) - else: - items = [] - - if len(items) != item_count: - raise RuntimeError( - f"Expected there to be {item_count} {modality} items in " - f"keyword arguments corresponding to {item_count} " - f"{modality} data items, but only found {len(items)}! " - "There is likely a problem with your " - "implementation of merged multi-modal processor for this " - "model (usually arising from an inconsistency between " - "`_call_hf_processor` and `_get_mm_fields_config`).") - - def _validate_mm_placeholders( - self, - mm_placeholders: Mapping[str, list[_PlaceholderInfo]], - mm_item_counts: Mapping[str, int], - *, - allow_missing: bool = False, - ) -> Mapping[str, int]: - missing_repl_counts = dict[str, int]() - - for modality, item_count in mm_item_counts.items(): - placeholders = mm_placeholders.get(modality, []) - - if len(placeholders) != item_count and not allow_missing: - raise RuntimeError( - f"Expected there to be {item_count} prompt replacements " - f"corresponding to {item_count} {modality} items, but only " - f"found {len(placeholders)} prompt replacements! Either " - "the prompt text has missing/incorrect tokens for " - "multi-modal inputs, or there is a problem with your " - "implementation of merged multi-modal processor for this " - "model (usually arising from an inconsistency between " - "`_call_hf_processor` and `_get_prompt_replacements`).") - - missing_repl_counts[modality] = item_count - len(placeholders) - - return missing_repl_counts - - def apply( - self, - prompt_text: str, - mm_data: MultiModalDataDict, - hf_processor_mm_kwargs: Mapping[str, object], - ) -> MultiModalInputsV2: + def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: """ - Process multi-modal inputs to be used in vLLM. + Get the maximum possible number of tokens per data item + for each modality. - The main steps are: - - 1. 
Apply HF Processor on prompt text and multi-modal data together, - outputting token IDs and processed tensors. - 2. Find and replace sequences in the token IDs with placeholder tokens. - The number of placeholder tokens equals the feature size of the - multi-modal data outputted by the multi-modal encoder. - 3. Extract information about the placeholder tokens from the - processed token IDs. + The dictionary returned by this method should have the same + keys as that returned by :meth:`get_supported_mm_limits`. """ - mm_items = self._to_mm_items(mm_data) - - # Create MM hashes (only used in V1) - # TODO: Use these hash keys for caching operations in apply_hf_processor - # instead of rehashing. - - if envs.VLLM_USE_V1: - model_id = self.ctx.model_config.model - mm_hashes = { - modality: [ - MultiModalHasher.hash_kwargs(model_id=model_id, - **{modality: item}, - **hf_processor_mm_kwargs) - for item in items - ] - for modality, items in mm_items.items() - } - else: - mm_hashes = None - - prompt_ids, mm_kwargs = self._cached_apply_hf_processor( - prompt_text, - mm_items, - hf_processor_mm_kwargs, - ) - - unbound_prompt_repls = self._get_prompt_replacements( - mm_items, - hf_processor_mm_kwargs, - mm_kwargs, - ) - mm_prompt_repls = self._bind_and_group_repls(unbound_prompt_repls) - - mm_item_counts = mm_items.get_all_counts() - self._validate_mm_kwargs(mm_kwargs, mm_item_counts) - - hf_mm_placeholders = self._find_mm_placeholders( - mm_prompt_repls, - prompt_ids, - mm_item_counts, - ) - - if self._always_apply_prompt_replacements(): - mm_missing_repl_counts = mm_item_counts - mm_missing_repls = dict(mm_prompt_repls) - else: - mm_missing_repl_counts = self._validate_mm_placeholders( - hf_mm_placeholders, - mm_item_counts, - allow_missing=True, - ) - - mm_missing_repls = dict[str, list[_BoundPromptReplacement]]() - for modality, missing_repl_count in mm_missing_repl_counts.items(): - if missing_repl_count == 0: - mm_missing_repls[modality] = [] - elif missing_repl_count == mm_item_counts.get(modality, 0): - mm_missing_repls[modality] = mm_prompt_repls[modality] - else: - raise ValueError("Partial prompt replacement within " - f"{modality=} is not supported") - - # If HF processor already inserts placeholder tokens, - # there is no need for us to insert them - if all(len(repls) == 0 for repls in mm_missing_repls.items()): - tokenizer = self._get_tokenizer() - prompt_text = decode_tokens(tokenizer, prompt_ids) - mm_placeholders = hf_mm_placeholders - else: - ( - prompt_ids, - prompt_text, - missing_mm_placeholders, - ) = self._apply_prompt_replacements( - prompt_ids, - mm_missing_repls, - mm_missing_repl_counts, - ) - - mm_placeholders = {**hf_mm_placeholders, **missing_mm_placeholders} - - self._validate_mm_placeholders(mm_placeholders, mm_item_counts) - - mm_placeholder_ranges = { - modality: [item.to_range() for item in placeholders] - for modality, placeholders in mm_placeholders.items() - } - - return MultiModalInputsV2( - type="multimodal", - prompt=prompt_text, - prompt_token_ids=prompt_ids, - mm_kwargs=mm_kwargs, - mm_hashes=mm_hashes, - mm_placeholders=mm_placeholder_ranges, - ) - - def _get_dummy_mm_inputs( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> MultiModalInputsV2: - profiling = self.profiling_info - processor_inputs = profiling.get_dummy_processor_inputs( - seq_len, mm_counts) - - return self.apply( - prompt_text=processor_inputs.prompt_text, - mm_data=processor_inputs.mm_data, - hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs, - ) - - def 
get_dummy_data(self, seq_len: int) -> DummyData: - # Avoid circular import - from vllm.sequence import SequenceData - - profiling = self.profiling_info - mm_counts = profiling.get_mm_limits() - mm_max_tokens_per_item = profiling.get_mm_max_tokens_per_item(seq_len) - if mm_counts.keys() != mm_max_tokens_per_item.keys(): - raise AssertionError( - "The keys returned by `get_supported_mm_limits`" - f"({set(mm_counts.keys())}) should be the same as those " - "returned by `get_mm_max_tokens_per_item` " - f"({set(mm_max_tokens_per_item.keys())})") - - mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts) - prompt_token_ids = mm_inputs["prompt_token_ids"] - placeholders_by_modality = mm_inputs["mm_placeholders"] - - total_placeholders_by_modality = { - modality: sum(item["length"] for item in placeholders) - for modality, placeholders in placeholders_by_modality.items() - } - expected_placeholders_by_modality = { - modality: mm_max_tokens_per_item[modality] * mm_counts[modality] - for modality in placeholders_by_modality - } - if total_placeholders_by_modality != expected_placeholders_by_modality: - raise AssertionError( - f"The processed dummy data has a total of " - f"{total_placeholders_by_modality} placeholder tokens, which " - f"is not the expected {expected_placeholders_by_modality} " - "tokens.") - - total_len = len(prompt_token_ids) - - # V0 does not support chunked prefill. - if total_len > seq_len and not envs.VLLM_USE_V1: - logger.warning( - "The context length (%d) of the model is too short " - "to hold the multi-modal embeddings in the worst case " - "(%d tokens in total, out of which %s are reserved for " - "multi-modal embeddings). This may cause certain multi-modal " - "inputs to fail during inference, even when the input text is " - "short. 
To avoid this, you should increase `max_model_len`, " - "reduce `max_num_seqs`, and/or reduce `mm_counts`.", seq_len, - total_len, total_placeholders_by_modality) - - return DummyData( - seq_data=SequenceData.from_prompt_token_counts((0, seq_len)), - multi_modal_data=None, - multi_modal_placeholders=None, - ) - - prompt_token_ids.extend([0] * (seq_len - len(prompt_token_ids))) - - return DummyData( - seq_data=SequenceData.from_seqs(prompt_token_ids), - multi_modal_data=mm_inputs["mm_kwargs"], - multi_modal_placeholders=placeholders_by_modality, - ) + raise NotImplementedError diff --git a/vllm/multimodal/processor.py b/vllm/multimodal/processor.py new file mode 100644 index 0000000000000..aa509eb347f47 --- /dev/null +++ b/vllm/multimodal/processor.py @@ -0,0 +1,562 @@ +from abc import ABC, abstractmethod +from collections.abc import Mapping, Sequence +from typing import Any, Generic, Optional, TypeVar + +from transformers import BatchFeature + +from vllm import envs +from vllm.transformers_utils.tokenizer import decode_tokens, encode_tokens + +from .hasher import MultiModalHasher +from .inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalInputsV2, MultiModalKwargs, + MultiModalKwargsItem) +from .parse import MultiModalDataItems, MultiModalDataParser +from .processing import (BaseProcessingInfo, BoundPromptReplacement, + PlaceholderInfo, ProcessingCache, PromptReplacement, + find_mm_placeholders, find_text_matches, + find_token_matches, full_groupby_modality, + replace_text_matches, replace_token_matches) +from .profiling import BaseDummyDataBuilder + +_I = TypeVar("_I", bound=BaseProcessingInfo) + + +class BaseMultiModalProcessor(ABC, Generic[_I]): + """ + Abstract base class to process multi-modal inputs to be used in vLLM. + + Not to be confused with :class:`transformers.ProcessorMixin`. + """ + + def __init__(self, + info: _I, + dummy_data_builder: BaseDummyDataBuilder[_I], + *, + cache: Optional[ProcessingCache] = None, + enable_sanity_checks: bool = True) -> None: + super().__init__() + + self.info = info + self.dummy_data_builder = dummy_data_builder + self.cache = cache + self.enable_sanity_checks = enable_sanity_checks + + self.data_parser = self._get_data_parser() + + def __call__( + self, + prompt: str, + mm_data: MultiModalDataDict, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> MultiModalInputsV2: + return self.apply(prompt, mm_data, hf_processor_mm_kwargs) + + def _get_data_parser(self) -> MultiModalDataParser: + """ + Construct a parser to preprocess multi-modal data items + before passing them to :meth:`_get_hf_mm_data`. + + You can support additional modalities by creating a subclass + of :class:`MultiModalDataParser` that has additional subparsers. + """ + return MultiModalDataParser() + + def _to_mm_items( + self, + mm_data: MultiModalDataDict, + ) -> MultiModalDataItems: + """ + Normalize :class:`MultiModalDataDict` to :class:`MultiModalDataItems` + before passing them to :meth:`_get_hf_mm_data`. 
+ """ + mm_items = self.data_parser.parse_mm_data(mm_data) + + mm_limits = self.info.ctx.get_mm_config().limit_per_prompt + for modality, items in mm_items.items(): + limit = mm_limits.get(modality, 1) + if len(items) > limit: + raise ValueError( + f"You set {modality}={limit} (or defaulted to 1) in " + f"`--limit-mm-per-prompt`, but passed {len(items)} " + f"{modality} items in the same prompt.") + + return mm_items + + @abstractmethod + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + """Given the HF-processed data, output the metadata of each field.""" + raise NotImplementedError + + @abstractmethod + def _get_prompt_replacements( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> list[PromptReplacement]: + """ + Given the original multi-modal items for this modality + and HF-processed data, output the replacements to perform. + + Notes: + - You should not assume that HF processor always performs prompt + replacement: in :meth:`_apply_hf_processor_missing`, this method + is called on text-only and multimodal-only inputs separately, + instead of passing them in the same call. + - The replacement information returned by this method is also used + to determine the placeholder token positions for each multi-modal + item. + """ + raise NotImplementedError + + def _find_mm_placeholders( + self, + mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]], + new_token_ids: list[int], + mm_item_counts: Mapping[str, int], + ) -> Mapping[str, list[PlaceholderInfo]]: + return find_mm_placeholders(mm_prompt_repls, new_token_ids, + mm_item_counts) + + def _get_hf_mm_data( + self, + mm_items: MultiModalDataItems, + ) -> tuple[dict[str, Any], dict[str, Any]]: + processor_data = dict[str, Any]() + passthrough_data = dict[str, Any]() + + for items in mm_items.values(): + processor_data.update(items.get_processor_data()) + passthrough_data.update(items.get_passthrough_data()) + + return processor_data, passthrough_data + + def _call_hf_processor( + self, + prompt: str, + # Not to be confused with `mm_data` in `self.apply`. + # This refers to the data to be passed to HF processor. + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + ) -> BatchFeature: + """ + Call the HF processor on the prompt text and + associated multi-modal data. + """ + return self.info.ctx.call_hf_processor( + self.info.get_hf_processor(**mm_kwargs), + dict(text=prompt, **mm_data), + mm_kwargs, + ) + + def _apply_hf_processor( + self, + prompt_text: str, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> tuple[list[int], MultiModalKwargs]: + """ + Wrapper of :meth:`_call_hf_processor` that applies + additional pre-processing and post-processing. 
+ """ + processor_data, passthrough_data = self._get_hf_mm_data(mm_items) + + processed_data = self._call_hf_processor( + prompt=prompt_text, + mm_data=processor_data, + mm_kwargs=hf_processor_mm_kwargs, + ) + processed_data.update(passthrough_data) + + prompt_ids, = processed_data.pop("input_ids").tolist() + + mm_kwargs = MultiModalKwargs.from_hf_inputs( + processed_data, + self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs), + ) + + return prompt_ids, mm_kwargs + + def _apply_hf_processor_missing( + self, + prompt_text: str, + mm_missing_data_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + ): + """ + Apply the HF processor on the full prompt text, but only on the + multi-modal data that are missing from the cache. + + Note: + We pass prompt text and multi-modal data into the HF processor + in separate calls to avoid HF prompt replacement being done for + cached items; instead, we rely on our own prompt replacement logic + (:meth:`_get_prompt_replacements`) for the full text. + """ + mm_missing_counts = mm_missing_data_items.get_all_counts() + + prompt_ids, _ = self._apply_hf_processor( + prompt_text=prompt_text, + mm_items=MultiModalDataItems({}), + hf_processor_mm_kwargs={}, + ) + + # Some HF processors (e.g. Qwen2-VL) expect corresponding + # multi-modal tokens to be in the prompt text + dummy_inputs = self.dummy_data_builder.get_dummy_processor_inputs( + self.info.ctx.model_config.max_model_len, + mm_missing_counts, + ) + + _, mm_missing_kwargs = self._apply_hf_processor( + prompt_text=dummy_inputs.prompt_text, + mm_items=mm_missing_data_items, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + ) + + return prompt_ids, mm_missing_kwargs + + def _cached_apply_hf_processor( + self, + prompt_text: str, + mm_data_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> tuple[list[int], MultiModalKwargs]: + """ + Apply the HF processor on the full prompt text, + caching the results and reusing cached results. 
+ """ + cache = self.cache + model_id = self.info.model_id + + _, passthrough_data = self._get_hf_mm_data(mm_data_items) + if cache is None or passthrough_data: + return self._apply_hf_processor( + prompt_text=prompt_text, + mm_items=mm_data_items, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + ) + + mm_maybe_cached_kw_items = { + modality: [ + cache.get(model_id, modality, item, hf_processor_mm_kwargs) + for item in items + ] + for modality, items in mm_data_items.items() + } + + mm_missing_idxs = { + modality: + [idx for idx, item in enumerate(kw_items) if item is None] + for modality, kw_items in mm_maybe_cached_kw_items.items() + } + mm_missing_data = { + modality: [mm_data_items[modality][idx] for idx in idxs] + for modality, idxs in mm_missing_idxs.items() + } + mm_missing_data_items = self._to_mm_items(mm_missing_data) + + prompt_ids, mm_missing_kwargs = self._apply_hf_processor_missing( + prompt_text=prompt_text, + mm_missing_data_items=mm_missing_data_items, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + ) + + mm_missing_next_idx = { + modality: 0 + for modality in mm_missing_data_items + } + + merged_kw_items = list[MultiModalKwargsItem]() + for modality, kw_items in mm_maybe_cached_kw_items.items(): + for idx, kw_item in enumerate(kw_items): + if kw_item is None: + kw_item = mm_missing_kwargs.get_item( + modality, + mm_missing_next_idx[modality], + ) + + cache.put( + model_id, + modality, + mm_data_items[modality][idx], + hf_processor_mm_kwargs, + kw_item, + ) + + mm_missing_next_idx[modality] += 1 + + merged_kw_items.append(kw_item) + + if self.enable_sanity_checks: + mm_missing_counts = mm_missing_data_items.get_all_counts() + assert all( + item_count == mm_missing_counts[modality] + for modality, item_count in mm_missing_next_idx.items()), dict( + mm_missing_next_idx=mm_missing_next_idx, + mm_missing_counts=mm_missing_counts) + + mm_kwargs = MultiModalKwargs.from_items(merged_kw_items) + + return prompt_ids, mm_kwargs + + def _bind_and_group_repls( + self, + prompt_repls: list[PromptReplacement], + ) -> dict[str, list[BoundPromptReplacement]]: + tokenizer = self.info.get_tokenizer() + + it = (prompt_repl.bind(tokenizer) for prompt_repl in prompt_repls) + return dict(full_groupby_modality(it)) + + def _always_apply_prompt_replacements(self) -> bool: + """ + A flag which can be overridden so that + :meth:`_apply_prompt_replacements` is always called even if we + detect that HF has performed processing via + :meth:`_find_placeholders_by_modality`. + + This is useful in cases where :meth:`_find_placeholders_by_modality` + cannot be reliably used to detect whether HF has performed processing. + """ + return False + + def _apply_prompt_replacements( + self, + token_ids: list[int], + mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]], + mm_item_counts: Mapping[str, int], + ) -> tuple[list[int], str, Mapping[str, list[PlaceholderInfo]]]: + tokenizer = self.info.get_tokenizer() + + mm_token_matches = { + modality: find_token_matches(token_ids, prompt_repls) + for modality, prompt_repls in mm_prompt_repls.items() + } + mm_match_counts = { + modality: len(matches) + for modality, matches in mm_token_matches.items() + } + + # If the search text does not represent a special token, + # it may have different token IDs in the prompt, because + # the tokens may go across the boundaries of the search text. + # ---- + # e.g. 
when searching for "foo" in "food", if "food" itself makes + # up a token, then the token ID of "foo" will not appear at all + # ---- + # Since it is inefficient to search for all possible tokenizations + # of the search text in the prompt, we instead perform string + # replacement on the decoded token IDs, then encode them back. + if all( + mm_match_counts.get(modality, 0) >= item_count + for modality, item_count in mm_item_counts.items() + ): # yapf: disable + token_ids = replace_token_matches( + token_ids, + mm_token_matches, + mm_item_counts, + ) + + text = decode_tokens(tokenizer, token_ids) + matched_repls = { + modality: [match.prompt_repl for match in token_matches] + for modality, token_matches in mm_token_matches.items() + } + else: + text = decode_tokens(tokenizer, token_ids) + + mm_text_matches = { + modality: find_text_matches(text, prompt_repls) + for modality, prompt_repls in mm_prompt_repls.items() + } + text = replace_text_matches( + text, + mm_text_matches, + mm_item_counts, + ) + + token_ids = encode_tokens(tokenizer, + text, + add_special_tokens=False) + matched_repls = { + modality: [match.prompt_repl for match in token_matches] + for modality, token_matches in mm_text_matches.items() + } + + placeholders = self._find_mm_placeholders( + matched_repls, + token_ids, + mm_item_counts, + ) + + return token_ids, text, placeholders + + def _validate_mm_kwargs( + self, + mm_kwargs: MultiModalKwargs, + mm_item_counts: Mapping[str, int], + ) -> None: + for modality, item_count in mm_item_counts.items(): + if modality in mm_kwargs.modalities: + items = mm_kwargs.get_items(modality) + else: + items = [] + + if len(items) != item_count: + raise RuntimeError( + f"Expected there to be {item_count} {modality} items in " + f"keyword arguments corresponding to {item_count} " + f"{modality} data items, but only found {len(items)}! " + "There is likely a problem with your " + "implementation of merged multi-modal processor for this " + "model (usually arising from an inconsistency between " + "`_call_hf_processor` and `_get_mm_fields_config`).") + + def _validate_mm_placeholders( + self, + mm_placeholders: Mapping[str, list[PlaceholderInfo]], + mm_item_counts: Mapping[str, int], + *, + allow_missing: bool = False, + ) -> Mapping[str, int]: + missing_repl_counts = dict[str, int]() + + for modality, item_count in mm_item_counts.items(): + placeholders = mm_placeholders.get(modality, []) + + if len(placeholders) != item_count and not allow_missing: + raise RuntimeError( + f"Expected there to be {item_count} prompt replacements " + f"corresponding to {item_count} {modality} items, but only " + f"found {len(placeholders)} prompt replacements! Either " + "the prompt text has missing/incorrect tokens for " + "multi-modal inputs, or there is a problem with your " + "implementation of merged multi-modal processor for this " + "model (usually arising from an inconsistency between " + "`_call_hf_processor` and `_get_prompt_replacements`).") + + missing_repl_counts[modality] = item_count - len(placeholders) + + return missing_repl_counts + + def apply( + self, + prompt_text: str, + mm_data: MultiModalDataDict, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> MultiModalInputsV2: + """ + Process multi-modal inputs to be used in vLLM. + + The main steps are: + + 1. Apply HF Processor on prompt text and multi-modal data together, + outputting token IDs and processed tensors. + 2. Find and replace sequences in the token IDs with placeholder tokens. 
+ The number of placeholder tokens equals the feature size of the + multi-modal data outputted by the multi-modal encoder. + 3. Extract information about the placeholder tokens from the + processed token IDs. + """ + mm_items = self._to_mm_items(mm_data) + + # Create MM hashes (only used in V1) + # TODO: Use these hash keys for caching operations in apply_hf_processor + # instead of rehashing. + + if envs.VLLM_USE_V1: + model_id = self.info.model_id + mm_hashes = { + modality: [ + MultiModalHasher.hash_kwargs(model_id=model_id, + **{modality: item}, + **hf_processor_mm_kwargs) + for item in items + ] + for modality, items in mm_items.items() + } + else: + mm_hashes = None + + prompt_ids, mm_kwargs = self._cached_apply_hf_processor( + prompt_text, + mm_items, + hf_processor_mm_kwargs, + ) + + unbound_prompt_repls = self._get_prompt_replacements( + mm_items, + hf_processor_mm_kwargs, + mm_kwargs, + ) + mm_prompt_repls = self._bind_and_group_repls(unbound_prompt_repls) + + mm_item_counts = mm_items.get_all_counts() + self._validate_mm_kwargs(mm_kwargs, mm_item_counts) + + hf_mm_placeholders = self._find_mm_placeholders( + mm_prompt_repls, + prompt_ids, + mm_item_counts, + ) + + if self._always_apply_prompt_replacements(): + mm_missing_repl_counts = mm_item_counts + mm_missing_repls = dict(mm_prompt_repls) + else: + mm_missing_repl_counts = self._validate_mm_placeholders( + hf_mm_placeholders, + mm_item_counts, + allow_missing=True, + ) + + mm_missing_repls = dict[str, list[BoundPromptReplacement]]() + for modality, missing_repl_count in mm_missing_repl_counts.items(): + if missing_repl_count == 0: + mm_missing_repls[modality] = [] + elif missing_repl_count == mm_item_counts.get(modality, 0): + mm_missing_repls[modality] = mm_prompt_repls[modality] + else: + raise ValueError("Partial prompt replacement within " + f"{modality=} is not supported") + + # If HF processor already inserts placeholder tokens, + # there is no need for us to insert them + if all(len(repls) == 0 for repls in mm_missing_repls.items()): + tokenizer = self.info.get_tokenizer() + prompt_text = decode_tokens(tokenizer, prompt_ids) + mm_placeholders = hf_mm_placeholders + else: + ( + prompt_ids, + prompt_text, + missing_mm_placeholders, + ) = self._apply_prompt_replacements( + prompt_ids, + mm_missing_repls, + mm_missing_repl_counts, + ) + + mm_placeholders = {**hf_mm_placeholders, **missing_mm_placeholders} + + self._validate_mm_placeholders(mm_placeholders, mm_item_counts) + + mm_placeholder_ranges = { + modality: [item.to_range() for item in placeholders] + for modality, placeholders in mm_placeholders.items() + } + + return MultiModalInputsV2( + type="multimodal", + prompt=prompt_text, + prompt_token_ids=prompt_ids, + mm_kwargs=mm_kwargs, + mm_hashes=mm_hashes, + mm_placeholders=mm_placeholder_ranges, + ) diff --git a/vllm/multimodal/profiler.py b/vllm/multimodal/profiler.py new file mode 100644 index 0000000000000..8a7d03c8d3c41 --- /dev/null +++ b/vllm/multimodal/profiler.py @@ -0,0 +1,133 @@ +from collections.abc import Mapping +from typing import Generic, TypeVar + +from vllm import envs +from vllm.inputs import DummyData +from vllm.logger import init_logger + +from .inputs import MultiModalInputsV2 +from .processing import BaseProcessingInfo +from .processor import BaseMultiModalProcessor +from .profiling import BaseDummyDataBuilder + +logger = init_logger(__name__) + +_I = TypeVar("_I", bound=BaseProcessingInfo) + + +class MultiModalProfiler(Generic[_I]): + + def __init__( + self, + processor: 
BaseMultiModalProcessor[_I], + ) -> None: + super().__init__() + + self.processor = processor + + @property + def processing(self) -> BaseProcessingInfo: + return self.processor.info + + @property + def dummy_data_builder(self) -> BaseDummyDataBuilder[_I]: + return self.processor.dummy_data_builder + + def _get_mm_limits(self) -> Mapping[str, int]: + mm_config = self.processing.ctx.get_mm_config() + mm_limit_per_prompt = mm_config.limit_per_prompt + + supported_mm_limits = self.processing.get_supported_mm_limits() + + mm_limits = { + modality: mm_limit_per_prompt.get(modality, 1) + for modality in supported_mm_limits + } + + for modality, supported_limit in supported_mm_limits.items(): + limit = mm_limits[modality] + if supported_limit is not None and supported_limit < limit: + raise ValueError( + f"You set {modality}={limit} (or defaulted to 1) in " + f"`--limit-mm-per-prompt`, but this model only supports " + f"at most {supported_limit} {modality} items.") + + return mm_limits + + def _get_dummy_mm_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalInputsV2: + factory = self.dummy_data_builder + processor_inputs = factory.get_dummy_processor_inputs( + seq_len, mm_counts) + + return self.processor.apply( + prompt_text=processor_inputs.prompt_text, + mm_data=processor_inputs.mm_data, + hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs, + ) + + def get_dummy_data(self, seq_len: int) -> DummyData: + # Avoid circular import + from vllm.sequence import SequenceData + + mm_counts = self._get_mm_limits() + + processing = self.processing + mm_max_tokens_per_item = processing.get_mm_max_tokens_per_item(seq_len) + + if mm_counts.keys() != mm_max_tokens_per_item.keys(): + raise AssertionError( + "The keys returned by `get_supported_mm_limits`" + f"({set(mm_counts.keys())}) should be the same as those " + "returned by `get_mm_max_tokens_per_item` " + f"({set(mm_max_tokens_per_item.keys())})") + + mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts) + prompt_token_ids = mm_inputs["prompt_token_ids"] + placeholders_by_modality = mm_inputs["mm_placeholders"] + + total_placeholders_by_modality = { + modality: sum(item["length"] for item in placeholders) + for modality, placeholders in placeholders_by_modality.items() + } + expected_placeholders_by_modality = { + modality: mm_max_tokens_per_item[modality] * mm_counts[modality] + for modality in placeholders_by_modality + } + if total_placeholders_by_modality != expected_placeholders_by_modality: + raise AssertionError( + f"The processed dummy data has a total of " + f"{total_placeholders_by_modality} placeholder tokens, which " + f"is not the expected {expected_placeholders_by_modality} " + "tokens.") + + total_len = len(prompt_token_ids) + + # V0 does not support chunked prefill. + if total_len > seq_len and not envs.VLLM_USE_V1: + logger.warning( + "The context length (%d) of the model is too short " + "to hold the multi-modal embeddings in the worst case " + "(%d tokens in total, out of which %s are reserved for " + "multi-modal embeddings). This may cause certain multi-modal " + "inputs to fail during inference, even when the input text is " + "short. 
To avoid this, you should increase `max_model_len`, " + "reduce `max_num_seqs`, and/or reduce `mm_counts`.", seq_len, + total_len, total_placeholders_by_modality) + + return DummyData( + seq_data=SequenceData.from_prompt_token_counts((0, seq_len)), + multi_modal_data=None, + multi_modal_placeholders=None, + ) + + prompt_token_ids.extend([0] * (seq_len - len(prompt_token_ids))) + + return DummyData( + seq_data=SequenceData.from_seqs(prompt_token_ids), + multi_modal_data=mm_inputs["mm_kwargs"], + multi_modal_placeholders=placeholders_by_modality, + ) diff --git a/vllm/multimodal/profiling.py b/vllm/multimodal/profiling.py index 2ecf0db1a485d..6ecf1d4f11061 100644 --- a/vllm/multimodal/profiling.py +++ b/vllm/multimodal/profiling.py @@ -1,18 +1,14 @@ from abc import ABC, abstractmethod from collections.abc import Mapping from dataclasses import dataclass, field -from typing import Optional +from typing import Generic, TypeVar import numpy as np import numpy.typing as npt from PIL import Image -from vllm.inputs import InputProcessingContext -from vllm.logger import init_logger - from .inputs import MultiModalDataDict - -logger = init_logger(__name__) +from .processing import BaseProcessingInfo @dataclass @@ -23,39 +19,19 @@ class ProcessorInputs: hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict) -class BaseProfilingInfo(ABC): +_I = TypeVar("_I", bound=BaseProcessingInfo) + + +class BaseDummyDataBuilder(ABC, Generic[_I]): """ - Abstract base class that provides the information necessary to profile + Abstract base class that constructs the dummy data to profile multi-modal models. """ - def __init__(self, ctx: InputProcessingContext) -> None: + def __init__(self, info: _I) -> None: super().__init__() - self.ctx = ctx - - @abstractmethod - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: - """ - Return the maximum supported number of items for each modality. - - A value of `None` means unlimited number of items. - - Omitting a modality from the returned dictionary means that - it is not supported at all. - """ - raise NotImplementedError - - @abstractmethod - def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: - """ - Get the maximum possible number of tokens per data item - for each modality. - - The dictionary returned by this method should have the same - keys as that returned by :meth:`get_supported_mm_limits`. - """ - raise NotImplementedError + self.info = info @abstractmethod def get_dummy_processor_inputs( @@ -64,8 +40,8 @@ def get_dummy_processor_inputs( mm_counts: Mapping[str, int], ) -> ProcessorInputs: """ - Build the multi-modal portion of the input which, after processing, - results in `mm_max_tokens` in :meth:`get_mm_max_tokens_per_item`. + Build the input which, after processing, results in + `self.info.get_mm_max_tokens_per_item()` placeholder tokens. 
""" raise NotImplementedError @@ -98,24 +74,3 @@ def _get_dummy_videos( ) -> list[npt.NDArray]: video = np.zeros((num_frames, width, height, 3)) return [video] * num_videos - - def get_mm_limits(self) -> Mapping[str, int]: - mm_config = self.ctx.get_mm_config() - mm_limit_per_prompt = mm_config.limit_per_prompt - - supported_mm_limits = self.get_supported_mm_limits() - - mm_limits = { - modality: mm_limit_per_prompt.get(modality, 1) - for modality in supported_mm_limits - } - - for modality, supported_limit in supported_mm_limits.items(): - limit = mm_limits[modality] - if supported_limit is not None and supported_limit < limit: - raise ValueError( - f"You set {modality}={limit} (or defaulted to 1) in " - f"`--limit-mm-per-prompt`, but this model only supports " - f"at most {supported_limit} {modality} items.") - - return mm_limits diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index f75a594a4c4e0..ac1bb8b33bff8 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -1,7 +1,8 @@ import functools from collections import UserDict -from typing import (TYPE_CHECKING, Any, Dict, Mapping, Optional, Protocol, - Sequence, Type, TypeVar) +from dataclasses import dataclass +from typing import (TYPE_CHECKING, Any, Dict, Generic, Mapping, Optional, + Protocol, Sequence, Type, TypeVar) import torch.nn as nn @@ -14,7 +15,9 @@ from .base import MultiModalInputMapper, MultiModalPlugin, MultiModalTokensCalc from .image import ImagePlugin from .inputs import MultiModalDataDict, MultiModalKwargs, NestedTensors -from .processing import BaseMultiModalProcessor, ProcessingCache +from .processing import ProcessingCache +from .processor import BaseMultiModalProcessor, BaseProcessingInfo +from .profiling import BaseDummyDataBuilder from .utils import cached_get_tokenizer from .video import VideoPlugin @@ -27,20 +30,57 @@ MM_CACHE_SIZE = 256 N = TypeVar("N", bound=Type[nn.Module]) +_I = TypeVar("_I", bound=BaseProcessingInfo) +_I_co = TypeVar("_I_co", bound=BaseProcessingInfo, covariant=True) -class MultiModalProcessorFactory(Protocol): +class ProcessingInfoFactory(Protocol[_I_co]): """Constructs a :class:`MultiModalProcessor` instance from the context.""" def __call__( self, ctx: InputProcessingContext, + ) -> _I_co: + ... + + +class DummyDataBuilderFactory(Protocol[_I]): + """Constructs a :class:`BaseDummyDataBuilder` instance from the context.""" + + def __call__(self, info: _I) -> BaseDummyDataBuilder[_I]: + ... + + +class MultiModalProcessorFactory(Protocol[_I]): + """Constructs a :class:`MultiModalProcessor` instance from the context.""" + + def __call__( + self, + info: _I, + dummy_data_builder: BaseDummyDataBuilder[_I], *, cache: Optional[ProcessingCache] = None, - ) -> BaseMultiModalProcessor: + ) -> BaseMultiModalProcessor[_I]: ... 
+@dataclass(frozen=True) +class _ProcessorFactories(Generic[_I]): + info: ProcessingInfoFactory[_I] + processor: MultiModalProcessorFactory[_I] + dummy_data: DummyDataBuilderFactory[_I] + + def build_processor( + self, + ctx: InputProcessingContext, + *, + cache: Optional[ProcessingCache] = None, + ): + info = self.info(ctx) + dummy_data_builder = self.dummy_data(info) + return self.processor(info, dummy_data_builder, cache=cache) + + class _MultiModalLimits(UserDict["ModelConfig", Dict[str, int]]): """ Wraps `_limits_by_model` for a more informative error message @@ -71,7 +111,7 @@ def __init__( self._plugins = {p.get_data_key(): p for p in plugins} self._processor_factories = ClassRegistry[nn.Module, - MultiModalProcessorFactory]() + _ProcessorFactories]() # This is used for non-multimodal models self._disabled_limits_per_plugin = {k: 0 for k in self._plugins} @@ -224,7 +264,7 @@ def get_max_tokens_per_item_by_modality( tokenizer = cached_get_tokenizer(model_config.tokenizer) processor = self.create_processor(model_config, tokenizer) seq_len = model_config.max_model_len - return processor.profiling_info.get_mm_max_tokens_per_item(seq_len) + return processor.info.get_mm_max_tokens_per_item(seq_len) return { key: plugin.get_max_multimodal_tokens(model_config) @@ -315,7 +355,10 @@ def get_mm_limits_per_prompt( def register_processor( self, - factory: MultiModalProcessorFactory, + processor: MultiModalProcessorFactory[_I], + *, + info: ProcessingInfoFactory[_I], + dummy_data: DummyDataBuilderFactory[_I], ): """ Register a multi-modal processor to a model class. The processor @@ -336,7 +379,11 @@ def wrapper(model_cls: N) -> N: "registered to %s. It is overwritten by the new one.", model_cls, self) - self._processor_factories[model_cls] = factory + self._processor_factories[model_cls] = _ProcessorFactories( + info=info, + dummy_data=dummy_data, + processor=processor, + ) return model_cls @@ -359,15 +406,15 @@ def create_processor( self, model_config: "ModelConfig", tokenizer: AnyTokenizer, - ) -> BaseMultiModalProcessor: + ) -> BaseMultiModalProcessor[BaseProcessingInfo]: """ Create a multi-modal processor for a specific model and tokenizer. 
""" model_cls = self._get_model_cls(model_config) - processor_factory = self._processor_factories[model_cls] + factories = self._processor_factories[model_cls] ctx = InputProcessingContext(model_config, tokenizer) cache = (None if model_config.disable_mm_preprocessor_cache else self._processing_cache) - return processor_factory(ctx, cache=cache) + return factories.build_processor(ctx, cache=cache) From 62942e3de526b3b9577df9be19322ebec3b7ddac Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Tue, 7 Jan 2025 17:06:01 +0000 Subject: [PATCH 2/7] Rename Signed-off-by: DarkLight1337 --- tests/multimodal/test_processing.py | 4 ++-- .../vllm_add_dummy_model/my_llava.py | 4 ++-- vllm/model_executor/models/aria.py | 6 ++--- vllm/model_executor/models/blip2.py | 6 ++--- vllm/model_executor/models/chameleon.py | 7 +++--- vllm/model_executor/models/fuyu.py | 6 ++--- vllm/model_executor/models/llava.py | 14 ++++++------ vllm/model_executor/models/llava_next.py | 4 ++-- .../model_executor/models/llava_next_video.py | 8 +++---- vllm/model_executor/models/llava_onevision.py | 13 +++++------ vllm/model_executor/models/phi3v.py | 6 ++--- vllm/model_executor/models/qwen2_audio.py | 8 +++---- vllm/model_executor/models/qwen2_vl.py | 6 ++--- vllm/model_executor/models/ultravox.py | 7 +++--- vllm/multimodal/processor.py | 8 +++---- vllm/multimodal/profiler.py | 8 +++---- vllm/multimodal/profiling.py | 2 +- vllm/multimodal/registry.py | 22 ++++++++++--------- 18 files changed, 71 insertions(+), 68 deletions(-) diff --git a/tests/multimodal/test_processing.py b/tests/multimodal/test_processing.py index 003cc177c57c4..808e82e05cfcf 100644 --- a/tests/multimodal/test_processing.py +++ b/tests/multimodal/test_processing.py @@ -690,7 +690,7 @@ def _test_processing_cache_correctness( baseline_processor = factories.build_processor(ctx, cache=None) cached_processor = factories.build_processor(ctx, cache=cache) - dummy_data_builder = baseline_processor.dummy_data_builder + dummy_inputs = baseline_processor.dummy_inputs rng = np.random.RandomState(0) @@ -722,7 +722,7 @@ def _test_processing_cache_correctness( } mm_counts = {k: len(vs) for k, vs in mm_data.items()} - prompt = dummy_data_builder.get_dummy_processor_inputs( + prompt = dummy_inputs.get_dummy_processor_inputs( model_config.max_model_len, mm_counts, ).prompt_text diff --git a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py index e273c4cbf2ea2..d07560c2a9b64 100644 --- a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py +++ b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py @@ -2,7 +2,7 @@ import torch -from vllm.model_executor.models.llava import (LlavaDummyDataBuilder, +from vllm.model_executor.models.llava import (LlavaDummyInputsBuilder, LlavaForConditionalGeneration, LlavaMultiModalProcessor, LlavaProcessingInfo) @@ -12,7 +12,7 @@ @MULTIMODAL_REGISTRY.register_processor(LlavaMultiModalProcessor, info=LlavaProcessingInfo, - dummy_data=LlavaDummyDataBuilder) + dummy=LlavaDummyInputsBuilder) class MyLlava(LlavaForConditionalGeneration): def compute_logits( diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py index 88cf73d109ee2..ede88d1a867cb 100644 --- a/vllm/model_executor/models/aria.py +++ b/vllm/model_executor/models/aria.py @@ -26,7 +26,7 @@ from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.processing import BaseProcessingInfo, PromptReplacement from 
vllm.multimodal.processor import BaseMultiModalProcessor -from vllm.multimodal.profiling import BaseDummyDataBuilder, ProcessorInputs +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.aria import (AriaMoELMConfig, AriaVisionConfig) @@ -464,7 +464,7 @@ def get_num_image_tokens(self) -> int: return max(hf_config.projector_patch_to_query_dict.values()) -class AriaDummyDataBuilder(BaseDummyDataBuilder[AriaProcessingInfo]): +class AriaDummyInputsBuilder(BaseDummyInputsBuilder[AriaProcessingInfo]): def get_dummy_processor_inputs( self, @@ -526,7 +526,7 @@ def _get_prompt_replacements( @MULTIMODAL_REGISTRY.register_processor(AriaMultiModalProcessor, info=AriaProcessingInfo, - dummy_data=AriaDummyDataBuilder) + dummy=AriaDummyInputsBuilder) class AriaForConditionalGeneration(nn.Module, SupportsMultiModal): """ Aria model for conditional generation tasks. diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index 5db1af556ce92..cbdcce6e07b9f 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -20,7 +20,7 @@ from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.processing import BaseProcessingInfo, PromptReplacement from vllm.multimodal.processor import BaseMultiModalProcessor -from vllm.multimodal.profiling import BaseDummyDataBuilder, ProcessorInputs +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors from .blip import BlipVisionModel @@ -413,7 +413,7 @@ def get_num_image_tokens(self) -> int: return hf_config.num_query_tokens -class Blip2DummyDataBuilder(BaseDummyDataBuilder[Blip2ProcessingInfo]): +class Blip2DummyInputsBuilder(BaseDummyInputsBuilder[Blip2ProcessingInfo]): def get_dummy_processor_inputs( self, @@ -490,7 +490,7 @@ def apply( @MULTIMODAL_REGISTRY.register_processor(Blip2MultiModalProcessor, info=Blip2ProcessingInfo, - dummy_data=Blip2DummyDataBuilder) + dummy=Blip2DummyInputsBuilder) class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index 29cb489d58a2e..bb01b372d17d5 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -33,7 +33,7 @@ from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.processing import BaseProcessingInfo, PromptReplacement from vllm.multimodal.processor import BaseMultiModalProcessor -from vllm.multimodal.profiling import BaseDummyDataBuilder, ProcessorInputs +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors from vllm.utils import print_warning_once @@ -68,7 +68,8 @@ def get_num_image_tokens(self) -> int: return processor.image_seq_length -class ChameleonDummyDataBuilder(BaseDummyDataBuilder[ChameleonProcessingInfo]): +class ChameleonDummyInputsBuilder( + BaseDummyInputsBuilder[ChameleonProcessingInfo]): def get_dummy_processor_inputs( self, @@ -915,7 +916,7 @@ def forward( @MULTIMODAL_REGISTRY.register_processor(ChameleonMultiModalProcessor, info=ChameleonProcessingInfo, - dummy_data=ChameleonDummyDataBuilder) + dummy=ChameleonDummyInputsBuilder) class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): diff --git 
a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index 972d47c1633c8..66e68e3f8f541 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -37,7 +37,7 @@ MultiModalDataItems) from vllm.multimodal.processing import BaseProcessingInfo, PromptReplacement from vllm.multimodal.processor import BaseMultiModalProcessor -from vllm.multimodal.profiling import BaseDummyDataBuilder, ProcessorInputs +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors from .interfaces import SupportsMultiModal, SupportsPP @@ -117,7 +117,7 @@ def get_image_size_with_most_features(self) -> ImageSize: height=image_processor.size["height"]) -class FuyuDummyDataBuilder(BaseDummyDataBuilder[FuyuProcessingInfo]): +class FuyuDummyInputsBuilder(BaseDummyInputsBuilder[FuyuProcessingInfo]): def get_dummy_processor_inputs( self, @@ -244,7 +244,7 @@ def apply( @MULTIMODAL_REGISTRY.register_processor(FuyuMultiModalProcessor, info=FuyuProcessingInfo, - dummy_data=FuyuDummyDataBuilder) + dummy=FuyuDummyInputsBuilder) class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index a2521c7e8514e..b0b0073280c12 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -29,7 +29,7 @@ from vllm.multimodal.processing import (BaseProcessingInfo, ProcessingCache, PromptReplacement) from vllm.multimodal.processor import BaseMultiModalProcessor -from vllm.multimodal.profiling import BaseDummyDataBuilder, ProcessorInputs +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors from .clip import CLIPVisionModel @@ -170,7 +170,7 @@ def get_max_image_tokens(self) -> int: _I = TypeVar("_I", bound=BaseLlavaProcessingInfo) -class LlavaDummyDataBuilder(BaseDummyDataBuilder[_I]): +class LlavaDummyInputsBuilder(BaseDummyInputsBuilder[_I]): def get_dummy_processor_inputs( self, @@ -360,7 +360,7 @@ def _build_llava_or_pixtral_hf_info( def _build_llava_or_pixtral_hf_processor( info: _I, - dummy_data_builder: BaseDummyDataBuilder[_I], + dummy_inputs: BaseDummyInputsBuilder[_I], *, cache: Optional[ProcessingCache] = None, enable_sanity_checks: bool = True, @@ -368,7 +368,7 @@ def _build_llava_or_pixtral_hf_processor( if isinstance(info, PixtralHFProcessingInfo): return PixtralHFMultiModalProcessor( info, - dummy_data_builder, # type: ignore + dummy_inputs, # type: ignore cache=cache, enable_sanity_checks=enable_sanity_checks, ) @@ -376,7 +376,7 @@ def _build_llava_or_pixtral_hf_processor( if isinstance(info, LlavaProcessingInfo): return LlavaMultiModalProcessor( info, - dummy_data_builder, # type: ignore + dummy_inputs, # type: ignore cache=cache, enable_sanity_checks=enable_sanity_checks, ) @@ -461,7 +461,7 @@ def init_vision_tower_for_llava( @MULTIMODAL_REGISTRY.register_processor(_build_llava_or_pixtral_hf_processor, info=_build_llava_or_pixtral_hf_info, - dummy_data=LlavaDummyDataBuilder) + dummy=LlavaDummyInputsBuilder) class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): # BitandBytes specific attributes bitsandbytes_stacked_params_mapping = { @@ -793,6 +793,6 @@ def get_replacement_mantis(item_idx: int): # `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` 
@MULTIMODAL_REGISTRY.register_processor(MantisMultiModalProcessor, info=LlavaProcessingInfo, - dummy_data=LlavaDummyDataBuilder) + dummy=LlavaDummyInputsBuilder) class MantisForConditionalGeneration(LlavaForConditionalGeneration): pass diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index 6af8acd392e5e..c76822f657661 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -21,7 +21,7 @@ from .clip import CLIPVisionModel from .interfaces import SupportsMultiModal, SupportsPP from .llava import (BaseLlavaMultiModalProcessor, BaseLlavaProcessingInfo, - LlavaDummyDataBuilder, LlavaLikeConfig, + LlavaDummyInputsBuilder, LlavaLikeConfig, LlavaMultiModalProjector, init_vision_tower_for_llava) from .siglip import SiglipVisionModel from .utils import (AutoWeightsLoader, embed_multimodal, flatten_bn, @@ -181,7 +181,7 @@ class LlavaNextMultiModalProcessor( @MULTIMODAL_REGISTRY.register_processor(LlavaNextMultiModalProcessor, info=LlavaNextProcessingInfo, - dummy_data=LlavaDummyDataBuilder) + dummy=LlavaDummyInputsBuilder) class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py index 881d71ed9ab5c..8e60aa8aaa719 100644 --- a/vllm/model_executor/models/llava_next_video.py +++ b/vllm/model_executor/models/llava_next_video.py @@ -21,7 +21,7 @@ VideoEmbeddingItems, VideoProcessorItems) from vllm.multimodal.processing import BaseProcessingInfo, PromptReplacement from vllm.multimodal.processor import BaseMultiModalProcessor -from vllm.multimodal.profiling import BaseDummyDataBuilder, ProcessorInputs +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of @@ -134,8 +134,8 @@ def get_max_num_frames(self, seq_len: int) -> int: return max(max_total_frames // max(max_videos, 1), 1) -class LlavaNextVideoDummyDataBuilder( - BaseDummyDataBuilder[LlavaNextVideoProcessingInfo]): +class LlavaNextVideoDummyInputsBuilder( + BaseDummyInputsBuilder[LlavaNextVideoProcessingInfo]): def get_dummy_processor_inputs( self, @@ -269,7 +269,7 @@ def forward(self, image_features: torch.Tensor) -> torch.Tensor: @MULTIMODAL_REGISTRY.register_processor( LlavaNextVideoMultiModalProcessor, info=LlavaNextVideoProcessingInfo, - dummy_data=LlavaNextVideoDummyDataBuilder, + dummy=LlavaNextVideoDummyInputsBuilder, ) class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py index 6622e3a150e64..21f589f4fbf1c 100644 --- a/vllm/model_executor/models/llava_onevision.py +++ b/vllm/model_executor/models/llava_onevision.py @@ -28,7 +28,7 @@ from .clip import CLIPVisionModel from .interfaces import SupportsMultiModal, SupportsPP -from .llava import LlavaDummyDataBuilder, init_vision_tower_for_llava +from .llava import LlavaDummyInputsBuilder, init_vision_tower_for_llava from .llava_next import (BaseLlavaNextMultiModalProcessor, LlavaNextLikeConfig, LlavaNextProcessingInfo) from .siglip import SiglipVisionModel @@ -233,8 +233,8 @@ def get_max_video_tokens(self, seq_len: int) -> int: ) -class LlavaOnevisionDummyDataBuilder( - LlavaDummyDataBuilder[LlavaOnevisionProcessingInfo]): +class LlavaOnevisionDummyInputsBuilder( + LlavaDummyInputsBuilder[LlavaOnevisionProcessingInfo]): def 
get_dummy_processor_inputs( self, @@ -392,10 +392,9 @@ def forward(self, image_features: torch.Tensor) -> torch.Tensor: return hidden_states -@MULTIMODAL_REGISTRY.register_processor( - LlavaOnevisionMultiModalProcessor, - info=LlavaOnevisionProcessingInfo, - dummy_data=LlavaOnevisionDummyDataBuilder) +@MULTIMODAL_REGISTRY.register_processor(LlavaOnevisionMultiModalProcessor, + info=LlavaOnevisionProcessingInfo, + dummy=LlavaOnevisionDummyInputsBuilder) class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 387fccecbf848..dab80c581eccb 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -39,7 +39,7 @@ BoundPromptReplacement, PlaceholderInfo, PromptReplacement) from vllm.multimodal.processor import BaseMultiModalProcessor -from vllm.multimodal.profiling import BaseDummyDataBuilder, ProcessorInputs +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of @@ -344,7 +344,7 @@ def get_image_size_with_most_features(self) -> ImageSize: return ImageSize(height=8000, width=50) -class Phi3VDummyDataBuilder(BaseDummyDataBuilder[Phi3VProcessingInfo]): +class Phi3VDummyInputsBuilder(BaseDummyInputsBuilder[Phi3VProcessingInfo]): def get_dummy_processor_inputs( self, @@ -498,7 +498,7 @@ def apply( @MULTIMODAL_REGISTRY.register_processor(Phi3VMultiModalProcessor, info=Phi3VProcessingInfo, - dummy_data=Phi3VDummyDataBuilder) + dummy=Phi3VDummyInputsBuilder) class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index aea7dc8dd8fea..bc4426eff5866 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -42,7 +42,7 @@ MultiModalDataParser) from vllm.multimodal.processing import BaseProcessingInfo, PromptReplacement from vllm.multimodal.processor import BaseMultiModalProcessor -from vllm.multimodal.profiling import BaseDummyDataBuilder, ProcessorInputs +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors from .interfaces import SupportsMultiModal, SupportsPP @@ -115,8 +115,8 @@ def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: return {"audio": max_output_lengths} -class Qwen2AudioDummyDataBuilder(BaseDummyDataBuilder[Qwen2AudioProcessingInfo] - ): +class Qwen2AudioDummyInputsBuilder( + BaseDummyInputsBuilder[Qwen2AudioProcessingInfo]): def get_dummy_processor_inputs( self, @@ -237,7 +237,7 @@ def _always_apply_prompt_replacements(self) -> bool: @MULTIMODAL_REGISTRY.register_processor(Qwen2AudioMultiModalProcessor, info=Qwen2AudioProcessingInfo, - dummy_data=Qwen2AudioDummyDataBuilder) + dummy=Qwen2AudioDummyInputsBuilder) class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index d60656d140bf8..ba3b7127a0ee4 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -60,7 +60,7 @@ MultiModalDataItems, MultiModalDataParser) from vllm.multimodal.processing import BaseProcessingInfo, PromptReplacement from vllm.multimodal.processor import BaseMultiModalProcessor -from 
vllm.multimodal.profiling import BaseDummyDataBuilder, ProcessorInputs +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.platforms import _Backend from vllm.sequence import IntermediateTensors from vllm.transformers_utils.config import uses_mrope @@ -883,7 +883,7 @@ def get_max_video_tokens(self, seq_len: int) -> int: ) -class Qwen2VLDummyDataBuilder(BaseDummyDataBuilder[Qwen2VLProcessingInfo]): +class Qwen2VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2VLProcessingInfo]): def get_dummy_processor_inputs( self, @@ -991,7 +991,7 @@ def _get_mm_fields_config( @MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor, info=Qwen2VLProcessingInfo, - dummy_data=Qwen2VLDummyDataBuilder) + dummy=Qwen2VLDummyInputsBuilder) class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP): packed_modules_mapping = { diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 1a403941f803f..477ed17905589 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -27,7 +27,7 @@ from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser from vllm.multimodal.processing import BaseProcessingInfo, PromptReplacement from vllm.multimodal.processor import BaseMultiModalProcessor -from vllm.multimodal.profiling import BaseDummyDataBuilder, ProcessorInputs +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.ultravox import UltravoxConfig @@ -98,7 +98,8 @@ def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: return {"audio": max_audio_tokens} -class UltravoxDummyDataBuilder(BaseDummyDataBuilder[UltravoxProcessingInfo]): +class UltravoxDummyInputsBuilder(BaseDummyInputsBuilder[UltravoxProcessingInfo] + ): def get_dummy_processor_inputs( self, @@ -340,7 +341,7 @@ def forward( @MULTIMODAL_REGISTRY.register_processor(UltravoxMultiModalProcessor, info=UltravoxProcessingInfo, - dummy_data=UltravoxDummyDataBuilder) + dummy=UltravoxDummyInputsBuilder) class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP): hf_to_vllm_mapper = WeightsMapper( diff --git a/vllm/multimodal/processor.py b/vllm/multimodal/processor.py index aa509eb347f47..08e6b7337e62d 100644 --- a/vllm/multimodal/processor.py +++ b/vllm/multimodal/processor.py @@ -17,7 +17,7 @@ find_mm_placeholders, find_text_matches, find_token_matches, full_groupby_modality, replace_text_matches, replace_token_matches) -from .profiling import BaseDummyDataBuilder +from .profiling import BaseDummyInputsBuilder _I = TypeVar("_I", bound=BaseProcessingInfo) @@ -31,14 +31,14 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): def __init__(self, info: _I, - dummy_data_builder: BaseDummyDataBuilder[_I], + dummy_inputs: BaseDummyInputsBuilder[_I], *, cache: Optional[ProcessingCache] = None, enable_sanity_checks: bool = True) -> None: super().__init__() self.info = info - self.dummy_data_builder = dummy_data_builder + self.dummy_inputs = dummy_inputs self.cache = cache self.enable_sanity_checks = enable_sanity_checks @@ -208,7 +208,7 @@ def _apply_hf_processor_missing( # Some HF processors (e.g. 
Qwen2-VL) expect corresponding # multi-modal tokens to be in the prompt text - dummy_inputs = self.dummy_data_builder.get_dummy_processor_inputs( + dummy_inputs = self.dummy_inputs.get_dummy_processor_inputs( self.info.ctx.model_config.max_model_len, mm_missing_counts, ) diff --git a/vllm/multimodal/profiler.py b/vllm/multimodal/profiler.py index 8a7d03c8d3c41..ed3c4edcf092d 100644 --- a/vllm/multimodal/profiler.py +++ b/vllm/multimodal/profiler.py @@ -8,7 +8,7 @@ from .inputs import MultiModalInputsV2 from .processing import BaseProcessingInfo from .processor import BaseMultiModalProcessor -from .profiling import BaseDummyDataBuilder +from .profiling import BaseDummyInputsBuilder logger = init_logger(__name__) @@ -30,8 +30,8 @@ def processing(self) -> BaseProcessingInfo: return self.processor.info @property - def dummy_data_builder(self) -> BaseDummyDataBuilder[_I]: - return self.processor.dummy_data_builder + def dummy_inputs(self) -> BaseDummyInputsBuilder[_I]: + return self.processor.dummy_inputs def _get_mm_limits(self) -> Mapping[str, int]: mm_config = self.processing.ctx.get_mm_config() @@ -59,7 +59,7 @@ def _get_dummy_mm_inputs( seq_len: int, mm_counts: Mapping[str, int], ) -> MultiModalInputsV2: - factory = self.dummy_data_builder + factory = self.dummy_inputs processor_inputs = factory.get_dummy_processor_inputs( seq_len, mm_counts) diff --git a/vllm/multimodal/profiling.py b/vllm/multimodal/profiling.py index 6ecf1d4f11061..de68b335cebd8 100644 --- a/vllm/multimodal/profiling.py +++ b/vllm/multimodal/profiling.py @@ -22,7 +22,7 @@ class ProcessorInputs: _I = TypeVar("_I", bound=BaseProcessingInfo) -class BaseDummyDataBuilder(ABC, Generic[_I]): +class BaseDummyInputsBuilder(ABC, Generic[_I]): """ Abstract base class that constructs the dummy data to profile multi-modal models. diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index ac1bb8b33bff8..3639384d505f7 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -17,7 +17,7 @@ from .inputs import MultiModalDataDict, MultiModalKwargs, NestedTensors from .processing import ProcessingCache from .processor import BaseMultiModalProcessor, BaseProcessingInfo -from .profiling import BaseDummyDataBuilder +from .profiling import BaseDummyInputsBuilder from .utils import cached_get_tokenizer from .video import VideoPlugin @@ -44,10 +44,12 @@ def __call__( ... -class DummyDataBuilderFactory(Protocol[_I]): - """Constructs a :class:`BaseDummyDataBuilder` instance from the context.""" +class DummyInputsBuilderFactory(Protocol[_I]): + """ + Constructs a :class:`BaseDummyInputsBuilder` instance from the context. + """ - def __call__(self, info: _I) -> BaseDummyDataBuilder[_I]: + def __call__(self, info: _I) -> BaseDummyInputsBuilder[_I]: ... 
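For orientation (illustrative sketch, not part of the diff): after this rename the profiler reaches the builder through `processor.dummy_inputs`, and end-to-end dummy-data generation looks roughly as follows, assuming `model_config` is an already-constructed `ModelConfig`:

from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.profiler import MultiModalProfiler
from vllm.multimodal.utils import cached_get_tokenizer

# Build the merged processor from the registry, then let the profiler
# generate worst-case dummy data (prompt tokens plus multi-modal kwargs).
processor = MULTIMODAL_REGISTRY.create_processor(
    model_config,
    tokenizer=cached_get_tokenizer(model_config.tokenizer),
)
profiler = MultiModalProfiler(processor)
dummy_data = profiler.get_dummy_data(seq_len=model_config.max_model_len)

This mirrors what the registry itself does in `get_max_tokens_per_item_by_modality`, which now reads the per-item token limits from `processor.info`.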
@@ -57,7 +59,7 @@ class MultiModalProcessorFactory(Protocol[_I]): def __call__( self, info: _I, - dummy_data_builder: BaseDummyDataBuilder[_I], + dummy_inputs: BaseDummyInputsBuilder[_I], *, cache: Optional[ProcessingCache] = None, ) -> BaseMultiModalProcessor[_I]: @@ -68,7 +70,7 @@ def __call__( class _ProcessorFactories(Generic[_I]): info: ProcessingInfoFactory[_I] processor: MultiModalProcessorFactory[_I] - dummy_data: DummyDataBuilderFactory[_I] + dummy: DummyInputsBuilderFactory[_I] def build_processor( self, @@ -77,8 +79,8 @@ def build_processor( cache: Optional[ProcessingCache] = None, ): info = self.info(ctx) - dummy_data_builder = self.dummy_data(info) - return self.processor(info, dummy_data_builder, cache=cache) + dummy_inputs_builder = self.dummy(info) + return self.processor(info, dummy_inputs_builder, cache=cache) class _MultiModalLimits(UserDict["ModelConfig", Dict[str, int]]): @@ -358,7 +360,7 @@ def register_processor( processor: MultiModalProcessorFactory[_I], *, info: ProcessingInfoFactory[_I], - dummy_data: DummyDataBuilderFactory[_I], + dummy: DummyInputsBuilderFactory[_I], ): """ Register a multi-modal processor to a model class. The processor @@ -381,7 +383,7 @@ def wrapper(model_cls: N) -> N: self._processor_factories[model_cls] = _ProcessorFactories( info=info, - dummy_data=dummy_data, + dummy=dummy, processor=processor, ) From b7e5324815aa48a00d2769bb37e0f695111a239e Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Tue, 7 Jan 2025 17:18:34 +0000 Subject: [PATCH 3/7] Cleanup Signed-off-by: DarkLight1337 --- vllm/model_executor/models/llava.py | 4 +- vllm/model_executor/models/llava_next.py | 25 +++++++---- .../model_executor/models/llava_next_video.py | 17 ++++---- vllm/model_executor/models/llava_onevision.py | 41 +++++++------------ vllm/model_executor/models/qwen2_vl.py | 15 ++++--- 5 files changed, 52 insertions(+), 50 deletions(-) diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index b0b0073280c12..b42df56d7e589 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -123,7 +123,7 @@ def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: return {"image": self.get_max_image_tokens()} - def apply_feature_select_strategy( + def _apply_feature_select_strategy( self, strategy: str, encoder_num_image_tokens: int, @@ -145,7 +145,7 @@ def get_num_image_tokens( hf_config = self.get_hf_config() vision_encoder_info = self.get_vision_encoder_info() - return self.apply_feature_select_strategy( + return self._apply_feature_select_strategy( hf_config.vision_feature_select_strategy, vision_encoder_info.get_num_image_tokens( image_width=image_width, diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index c76822f657661..52560a19a95bb 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -1,3 +1,4 @@ +from abc import abstractmethod from functools import cached_property from typing import (Final, Iterable, List, Literal, Mapping, Optional, Protocol, Set, Tuple, TypedDict, TypeVar, Union) @@ -82,7 +83,7 @@ def get_num_image_tokens( hf_config = self.get_hf_config() vision_encoder_info = self.get_vision_encoder_info() - base_feature_size = self.apply_feature_select_strategy( + base_feature_size = self._apply_feature_select_strategy( hf_config.vision_feature_select_strategy, vision_encoder_info.get_num_image_tokens( 
image_width=image_width, @@ -99,7 +100,7 @@ def get_num_image_tokens( ( unpadded_feature_size, newline_feature_size, - ) = self.get_num_unpadded_features( + ) = self._get_num_unpadded_features( original_height=image_height, original_width=image_width, npatches=vision_encoder_info.get_patch_grid_length(), @@ -110,7 +111,7 @@ def get_num_image_tokens( return unpadded_feature_size + newline_feature_size + base_feature_size # Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L86 - def get_num_unpadded_features( + def _get_num_unpadded_features( self, *, original_height: int, @@ -162,6 +163,19 @@ def get_image_size_with_most_features(self) -> ImageSize: class BaseLlavaNextMultiModalProcessor(BaseLlavaMultiModalProcessor[_I]): + # Copied from BaseMultiModalProcessor + @abstractmethod + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + raise NotImplementedError + + +class LlavaNextMultiModalProcessor( + BaseLlavaNextMultiModalProcessor[LlavaNextProcessingInfo]): + def _get_mm_fields_config( self, hf_inputs: BatchFeature, @@ -174,11 +188,6 @@ def _get_mm_fields_config( ) -class LlavaNextMultiModalProcessor( - BaseLlavaNextMultiModalProcessor[LlavaNextProcessingInfo]): - pass - - @MULTIMODAL_REGISTRY.register_processor(LlavaNextMultiModalProcessor, info=LlavaNextProcessingInfo, dummy=LlavaDummyInputsBuilder) diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py index 8e60aa8aaa719..cdb5b49994005 100644 --- a/vllm/model_executor/models/llava_next_video.py +++ b/vllm/model_executor/models/llava_next_video.py @@ -66,7 +66,7 @@ def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: max_video_tokens = self.get_num_video_tokens( image_width=target_width, image_height=target_height, - num_frames=self.get_max_num_frames(seq_len), + num_frames=self.get_num_frames_with_most_features(seq_len), ) return {"video": max_video_tokens} @@ -76,7 +76,7 @@ def get_image_size_with_most_features(self) -> ImageSize: width = height = vision_encoder_info.get_image_size() return ImageSize(width=width, height=height) - def get_num_frame_tokens( + def _get_num_frame_tokens( self, *, image_width: int, @@ -98,14 +98,14 @@ def get_num_video_tokens( image_height: int, num_frames: int, ) -> int: - num_frame_tokens = self.get_num_frame_tokens( + num_frame_tokens = self._get_num_frame_tokens( image_width=image_width, image_height=image_height, ) return num_frame_tokens * num_frames - def get_max_video_frames(self, max_tokens: int) -> int: + def _get_max_video_frames(self, max_tokens: int) -> int: target_width, target_height = self.get_image_size_with_most_features() num_frames = 0 @@ -125,11 +125,11 @@ def get_max_video_frames(self, max_tokens: int) -> int: return num_frames - def get_max_num_frames(self, seq_len: int) -> int: + def get_num_frames_with_most_features(self, seq_len: int) -> int: mm_config = self.ctx.get_mm_config() max_videos = mm_config.limit_per_prompt.get("video", 1) - max_total_frames = self.get_max_video_frames(seq_len) + max_total_frames = self._get_max_video_frames(seq_len) return max(max_total_frames // max(max_videos, 1), 1) @@ -146,15 +146,18 @@ def get_dummy_processor_inputs( processor = self.info.get_hf_processor() video_token = processor.video_token + target_width, target_height = \ self.info.get_image_size_with_most_features() + target_num_frames 
= \ + self.info.get_num_frames_with_most_features(seq_len) mm_data = { "video": self._get_dummy_videos( width=target_width, height=target_height, - num_frames=self.info.get_max_num_frames(seq_len), + num_frames=target_num_frames, num_videos=num_videos, ) } diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py index 21f589f4fbf1c..f3c7843089856 100644 --- a/vllm/model_executor/models/llava_onevision.py +++ b/vllm/model_executor/models/llava_onevision.py @@ -19,8 +19,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, NestedTensors) -from vllm.multimodal.parse import (ImageSize, MultiModalDataItems, - VideoEmbeddingItems, VideoProcessorItems) +from vllm.multimodal.parse import (MultiModalDataItems, VideoEmbeddingItems, + VideoProcessorItems) from vllm.multimodal.processing import PromptReplacement from vllm.multimodal.profiling import ProcessorInputs from vllm.sequence import IntermediateTensors @@ -109,7 +109,7 @@ def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: # Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L86 # with additional logic afterwards taken from LlavaOnevisionProcessor - def get_num_unpadded_features( + def _get_num_unpadded_features( self, *, original_height: int, @@ -145,23 +145,7 @@ def get_num_unpadded_features( return (unpadded_features, newline_features) - def get_image_size_with_most_features(self) -> ImageSize: - hf_config = self.get_hf_config() - largest_feature_size, largest_feature_pinpoint = 0, None - for (height, width) in hf_config.image_grid_pinpoints: - feat_size = self.get_num_image_tokens(image_width=width, - image_height=height) - if feat_size > largest_feature_size: - largest_feature_size = feat_size - largest_feature_pinpoint = ImageSize(width=width, - height=height) - - if largest_feature_size == 0 or largest_feature_pinpoint is None: - raise ValueError("Cannot have a largest feature size of 0!") - - return largest_feature_pinpoint - - def get_num_frame_tokens( + def _get_num_frame_tokens( self, *, image_width: int, @@ -183,14 +167,14 @@ def get_num_video_tokens( image_height: int, num_frames: int, ) -> int: - num_frame_tokens = self.get_num_frame_tokens( + num_frame_tokens = self._get_num_frame_tokens( image_width=image_width, image_height=image_height, ) return num_frame_tokens * num_frames + 1 # Newline token - def get_max_video_frames(self, max_tokens: int) -> int: + def _get_max_video_frames(self, max_tokens: int) -> int: target_width, target_height = self.get_image_size_with_most_features() num_frames = 0 @@ -210,14 +194,14 @@ def get_max_video_frames(self, max_tokens: int) -> int: return num_frames - def get_max_num_frames(self, seq_len: int) -> int: + def get_num_frames_with_most_features(self, seq_len: int) -> int: mm_config = self.ctx.get_mm_config() max_images = mm_config.limit_per_prompt.get("image", 1) max_videos = mm_config.limit_per_prompt.get("video", 1) max_image_tokens = self.get_max_image_tokens() * max_images - max_total_frames = self.get_max_video_frames(seq_len - - max_image_tokens) + max_total_frames = self._get_max_video_frames(seq_len - + max_image_tokens) max_frames_per_video = min(max_total_frames // max(max_videos, 1), _MAX_FRAMES_PER_VIDEO) @@ -229,7 +213,7 @@ def get_max_video_tokens(self, seq_len: int) -> int: return self.get_num_video_tokens( image_width=target_width, 
image_height=target_height, - num_frames=self.get_max_num_frames(seq_len), + num_frames=self.get_num_frames_with_most_features(seq_len), ) @@ -247,8 +231,11 @@ def get_dummy_processor_inputs( processor = self.info.get_hf_processor() image_token = processor.image_token video_token = processor.video_token + target_width, target_height = \ self.info.get_image_size_with_most_features() + target_num_frames = \ + self.info.get_num_frames_with_most_features(seq_len) mm_data = { "image": @@ -259,7 +246,7 @@ def get_dummy_processor_inputs( self._get_dummy_videos( width=target_width, height=target_height, - num_frames=self.info.get_max_num_frames(seq_len), + num_frames=target_num_frames, num_videos=num_videos, ) } diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index ba3b7127a0ee4..0701dfe6d7c0a 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -836,7 +836,7 @@ def get_max_image_tokens(self) -> int: image_height=target_height, ) - def get_max_video_frames(self, max_tokens: int) -> int: + def _get_max_video_frames(self, max_tokens: int) -> int: target_width, target_height = self.get_image_size_with_most_features() num_frames = 0 @@ -856,14 +856,14 @@ def get_max_video_frames(self, max_tokens: int) -> int: return num_frames - def get_max_num_frames(self, seq_len: int) -> int: + def get_num_frames_with_most_features(self, seq_len: int) -> int: mm_config = self.ctx.get_mm_config() max_images = mm_config.limit_per_prompt.get("image", 1) max_videos = mm_config.limit_per_prompt.get("video", 1) max_image_tokens = self.get_max_image_tokens() * max_images - max_total_frames = self.get_max_video_frames(seq_len - - max_image_tokens) + max_total_frames = self._get_max_video_frames(seq_len - + max_image_tokens) num_frames = max(max_total_frames // max(max_videos, 1), 1) @@ -879,7 +879,7 @@ def get_max_video_tokens(self, seq_len: int) -> int: return self.get_num_video_tokens( image_width=target_width, image_height=target_height, - num_frames=self.get_max_num_frames(seq_len), + num_frames=self.get_num_frames_with_most_features(seq_len), ) @@ -896,8 +896,11 @@ def get_dummy_processor_inputs( hf_processor = self.info.get_hf_processor() image_token: str = hf_processor.image_token video_token: str = hf_processor.video_token + target_width, target_height = \ self.info.get_image_size_with_most_features() + target_num_frames = \ + self.info.get_num_frames_with_most_features(seq_len) mm_data = { "image": @@ -908,7 +911,7 @@ def get_dummy_processor_inputs( self._get_dummy_videos( width=target_width, height=target_height, - num_frames=self.info.get_max_num_frames(seq_len), + num_frames=target_num_frames, num_videos=num_videos, ) } From 23b637b83f19ea295913ec4d6dad326d6b74e59d Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 8 Jan 2025 02:43:59 +0000 Subject: [PATCH 4/7] Rename `dummy -> dummy_inputs` Signed-off-by: DarkLight1337 --- .../vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py | 2 +- vllm/model_executor/models/aria.py | 2 +- vllm/model_executor/models/blip2.py | 2 +- vllm/model_executor/models/chameleon.py | 7 ++++--- vllm/model_executor/models/fuyu.py | 2 +- vllm/model_executor/models/llava.py | 4 ++-- vllm/model_executor/models/llava_next.py | 2 +- vllm/model_executor/models/llava_next_video.py | 2 +- vllm/model_executor/models/llava_onevision.py | 7 ++++--- vllm/model_executor/models/phi3v.py | 2 +- vllm/model_executor/models/qwen2_audio.py | 7 ++++--- vllm/model_executor/models/qwen2_vl.py | 2 +- 
vllm/model_executor/models/ultravox.py | 3 ++- vllm/multimodal/registry.py | 8 ++++---- 14 files changed, 28 insertions(+), 24 deletions(-) diff --git a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py index d07560c2a9b64..ac64edfd4ec9d 100644 --- a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py +++ b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py @@ -12,7 +12,7 @@ @MULTIMODAL_REGISTRY.register_processor(LlavaMultiModalProcessor, info=LlavaProcessingInfo, - dummy=LlavaDummyInputsBuilder) + dummy_inputs=LlavaDummyInputsBuilder) class MyLlava(LlavaForConditionalGeneration): def compute_logits( diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py index ede88d1a867cb..22ee24e78b4a2 100644 --- a/vllm/model_executor/models/aria.py +++ b/vllm/model_executor/models/aria.py @@ -526,7 +526,7 @@ def _get_prompt_replacements( @MULTIMODAL_REGISTRY.register_processor(AriaMultiModalProcessor, info=AriaProcessingInfo, - dummy=AriaDummyInputsBuilder) + dummy_inputs=AriaDummyInputsBuilder) class AriaForConditionalGeneration(nn.Module, SupportsMultiModal): """ Aria model for conditional generation tasks. diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index cbdcce6e07b9f..0843fdf4853e4 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -490,7 +490,7 @@ def apply( @MULTIMODAL_REGISTRY.register_processor(Blip2MultiModalProcessor, info=Blip2ProcessingInfo, - dummy=Blip2DummyInputsBuilder) + dummy_inputs=Blip2DummyInputsBuilder) class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index bb01b372d17d5..55712cdeb3417 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -914,9 +914,10 @@ def forward( return hidden_states -@MULTIMODAL_REGISTRY.register_processor(ChameleonMultiModalProcessor, - info=ChameleonProcessingInfo, - dummy=ChameleonDummyInputsBuilder) +@MULTIMODAL_REGISTRY.register_processor( + ChameleonMultiModalProcessor, + info=ChameleonProcessingInfo, + dummy_inputs=ChameleonDummyInputsBuilder) class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index 66e68e3f8f541..27b2abb9e3745 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -244,7 +244,7 @@ def apply( @MULTIMODAL_REGISTRY.register_processor(FuyuMultiModalProcessor, info=FuyuProcessingInfo, - dummy=FuyuDummyInputsBuilder) + dummy_inputs=FuyuDummyInputsBuilder) class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index b42df56d7e589..d6af6a42de1d5 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -461,7 +461,7 @@ def init_vision_tower_for_llava( @MULTIMODAL_REGISTRY.register_processor(_build_llava_or_pixtral_hf_processor, info=_build_llava_or_pixtral_hf_info, - dummy=LlavaDummyInputsBuilder) + dummy_inputs=LlavaDummyInputsBuilder) class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): # 
BitandBytes specific attributes bitsandbytes_stacked_params_mapping = { @@ -793,6 +793,6 @@ def get_replacement_mantis(item_idx: int): # `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` @MULTIMODAL_REGISTRY.register_processor(MantisMultiModalProcessor, info=LlavaProcessingInfo, - dummy=LlavaDummyInputsBuilder) + dummy_inputs=LlavaDummyInputsBuilder) class MantisForConditionalGeneration(LlavaForConditionalGeneration): pass diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index 52560a19a95bb..fda4f22d366b1 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -190,7 +190,7 @@ def _get_mm_fields_config( @MULTIMODAL_REGISTRY.register_processor(LlavaNextMultiModalProcessor, info=LlavaNextProcessingInfo, - dummy=LlavaDummyInputsBuilder) + dummy_inputs=LlavaDummyInputsBuilder) class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py index cdb5b49994005..78ff774a792f6 100644 --- a/vllm/model_executor/models/llava_next_video.py +++ b/vllm/model_executor/models/llava_next_video.py @@ -272,7 +272,7 @@ def forward(self, image_features: torch.Tensor) -> torch.Tensor: @MULTIMODAL_REGISTRY.register_processor( LlavaNextVideoMultiModalProcessor, info=LlavaNextVideoProcessingInfo, - dummy=LlavaNextVideoDummyInputsBuilder, + dummy_inputs=LlavaNextVideoDummyInputsBuilder, ) class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py index f3c7843089856..78a47e64d9afc 100644 --- a/vllm/model_executor/models/llava_onevision.py +++ b/vllm/model_executor/models/llava_onevision.py @@ -379,9 +379,10 @@ def forward(self, image_features: torch.Tensor) -> torch.Tensor: return hidden_states -@MULTIMODAL_REGISTRY.register_processor(LlavaOnevisionMultiModalProcessor, - info=LlavaOnevisionProcessingInfo, - dummy=LlavaOnevisionDummyInputsBuilder) +@MULTIMODAL_REGISTRY.register_processor( + LlavaOnevisionMultiModalProcessor, + info=LlavaOnevisionProcessingInfo, + dummy_inputs=LlavaOnevisionDummyInputsBuilder) class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index dab80c581eccb..bd1f920ea8a81 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -498,7 +498,7 @@ def apply( @MULTIMODAL_REGISTRY.register_processor(Phi3VMultiModalProcessor, info=Phi3VProcessingInfo, - dummy=Phi3VDummyInputsBuilder) + dummy_inputs=Phi3VDummyInputsBuilder) class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index bc4426eff5866..3271eddfebd23 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -235,9 +235,10 @@ def _always_apply_prompt_replacements(self) -> bool: return not hasattr(self.info.get_hf_processor(), "audio_token") -@MULTIMODAL_REGISTRY.register_processor(Qwen2AudioMultiModalProcessor, - info=Qwen2AudioProcessingInfo, - dummy=Qwen2AudioDummyInputsBuilder) +@MULTIMODAL_REGISTRY.register_processor( + Qwen2AudioMultiModalProcessor, + info=Qwen2AudioProcessingInfo, + 
dummy_inputs=Qwen2AudioDummyInputsBuilder) class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 0701dfe6d7c0a..cb7a893c7cc90 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -994,7 +994,7 @@ def _get_mm_fields_config( @MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor, info=Qwen2VLProcessingInfo, - dummy=Qwen2VLDummyInputsBuilder) + dummy_inputs=Qwen2VLDummyInputsBuilder) class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP): packed_modules_mapping = { diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 477ed17905589..1d82f552cec68 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -341,7 +341,8 @@ def forward( @MULTIMODAL_REGISTRY.register_processor(UltravoxMultiModalProcessor, info=UltravoxProcessingInfo, - dummy=UltravoxDummyInputsBuilder) + dummy_inputs=UltravoxDummyInputsBuilder + ) class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP): hf_to_vllm_mapper = WeightsMapper( diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index 3639384d505f7..074182ec1173c 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -70,7 +70,7 @@ def __call__( class _ProcessorFactories(Generic[_I]): info: ProcessingInfoFactory[_I] processor: MultiModalProcessorFactory[_I] - dummy: DummyInputsBuilderFactory[_I] + dummy_inputs: DummyInputsBuilderFactory[_I] def build_processor( self, @@ -79,7 +79,7 @@ def build_processor( cache: Optional[ProcessingCache] = None, ): info = self.info(ctx) - dummy_inputs_builder = self.dummy(info) + dummy_inputs_builder = self.dummy_inputs(info) return self.processor(info, dummy_inputs_builder, cache=cache) @@ -360,7 +360,7 @@ def register_processor( processor: MultiModalProcessorFactory[_I], *, info: ProcessingInfoFactory[_I], - dummy: DummyInputsBuilderFactory[_I], + dummy_inputs: DummyInputsBuilderFactory[_I], ): """ Register a multi-modal processor to a model class. The processor @@ -383,7 +383,7 @@ def wrapper(model_cls: N) -> N: self._processor_factories[model_cls] = _ProcessorFactories( info=info, - dummy=dummy, + dummy_inputs=dummy_inputs, processor=processor, ) From bad5b080af4dcd2b9d3baaedfc7d1b9f33f0beb2 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 8 Jan 2025 02:45:18 +0000 Subject: [PATCH 5/7] Update profiler Signed-off-by: DarkLight1337 --- vllm/multimodal/profiler.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/vllm/multimodal/profiler.py b/vllm/multimodal/profiler.py index ed3c4edcf092d..d1984c6b0546c 100644 --- a/vllm/multimodal/profiler.py +++ b/vllm/multimodal/profiler.py @@ -16,6 +16,9 @@ class MultiModalProfiler(Generic[_I]): + """ + Contains code for running memory profiling for multi-modal models. 
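+
+    A minimal usage sketch (illustrative only; it assumes a ``processor``
+    has already been built for the target model, e.g. via the multi-modal
+    registry):
+
+        profiler = MultiModalProfiler(processor)
+        dummy_data = profiler.get_dummy_data(seq_len)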
+ """ def __init__( self, @@ -26,7 +29,7 @@ def __init__( self.processor = processor @property - def processing(self) -> BaseProcessingInfo: + def processing_info(self) -> BaseProcessingInfo: return self.processor.info @property @@ -34,10 +37,10 @@ def dummy_inputs(self) -> BaseDummyInputsBuilder[_I]: return self.processor.dummy_inputs def _get_mm_limits(self) -> Mapping[str, int]: - mm_config = self.processing.ctx.get_mm_config() + mm_config = self.processing_info.ctx.get_mm_config() mm_limit_per_prompt = mm_config.limit_per_prompt - supported_mm_limits = self.processing.get_supported_mm_limits() + supported_mm_limits = self.processing_info.get_supported_mm_limits() mm_limits = { modality: mm_limit_per_prompt.get(modality, 1) @@ -75,8 +78,8 @@ def get_dummy_data(self, seq_len: int) -> DummyData: mm_counts = self._get_mm_limits() - processing = self.processing - mm_max_tokens_per_item = processing.get_mm_max_tokens_per_item(seq_len) + info = self.processing_info + mm_max_tokens_per_item = info.get_mm_max_tokens_per_item(seq_len) if mm_counts.keys() != mm_max_tokens_per_item.keys(): raise AssertionError( From 213ce333aa9901eac5736870f28d0fad8e118cbc Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 8 Jan 2025 02:54:21 +0000 Subject: [PATCH 6/7] Combine files Signed-off-by: DarkLight1337 --- .../processing/test_llava_next.py | 2 +- .../processing/test_llava_onevision.py | 2 +- tests/multimodal/test_processing.py | 2 +- vllm/inputs/registry.py | 2 +- vllm/model_executor/models/aria.py | 4 +- vllm/model_executor/models/blip2.py | 4 +- vllm/model_executor/models/chameleon.py | 4 +- vllm/model_executor/models/fuyu.py | 4 +- vllm/model_executor/models/llava.py | 4 +- .../model_executor/models/llava_next_video.py | 4 +- vllm/model_executor/models/phi3v.py | 4 +- vllm/model_executor/models/qwen2_audio.py | 4 +- vllm/model_executor/models/qwen2_vl.py | 4 +- vllm/model_executor/models/ultravox.py | 4 +- vllm/multimodal/processing.py | 557 ++++++++++++++++- vllm/multimodal/processor.py | 562 ------------------ vllm/multimodal/profiler.py | 136 ----- vllm/multimodal/profiling.py | 131 +++- vllm/multimodal/registry.py | 4 +- 19 files changed, 709 insertions(+), 729 deletions(-) delete mode 100644 vllm/multimodal/processor.py delete mode 100644 vllm/multimodal/profiler.py diff --git a/tests/models/decoder_only/vision_language/processing/test_llava_next.py b/tests/models/decoder_only/vision_language/processing/test_llava_next.py index 737a8c8c78d76..689d17be81889 100644 --- a/tests/models/decoder_only/vision_language/processing/test_llava_next.py +++ b/tests/models/decoder_only/vision_language/processing/test_llava_next.py @@ -7,7 +7,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.parse import ImageSize -from vllm.multimodal.processor import BaseMultiModalProcessor +from vllm.multimodal.processing import BaseMultiModalProcessor from vllm.multimodal.utils import cached_get_tokenizer from ....utils import build_model_context diff --git a/tests/models/decoder_only/vision_language/processing/test_llava_onevision.py b/tests/models/decoder_only/vision_language/processing/test_llava_onevision.py index 21765932be2a6..a033354f0e9b8 100644 --- a/tests/models/decoder_only/vision_language/processing/test_llava_onevision.py +++ b/tests/models/decoder_only/vision_language/processing/test_llava_onevision.py @@ -7,7 +7,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.parse import ImageSize -from vllm.multimodal.processor import BaseMultiModalProcessor +from 
vllm.multimodal.processing import BaseMultiModalProcessor from vllm.multimodal.utils import cached_get_tokenizer from ....utils import build_model_context diff --git a/tests/multimodal/test_processing.py b/tests/multimodal/test_processing.py index 808e82e05cfcf..d98bd9736b65f 100644 --- a/tests/multimodal/test_processing.py +++ b/tests/multimodal/test_processing.py @@ -20,7 +20,7 @@ replace_text_matches, replace_token_matches) # yapf: enable -from vllm.multimodal.profiler import MultiModalProfiler +from vllm.multimodal.profiling import MultiModalProfiler from vllm.multimodal.utils import cached_get_tokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import full_groupby diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 3fa9a8c14842d..b22b3f1594f24 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -323,7 +323,7 @@ def dummy_data_for_profiling( # Avoid circular import from vllm.model_executor.model_loader import get_model_architecture from vllm.multimodal import MultiModalKwargs - from vllm.multimodal.profiler import MultiModalProfiler + from vllm.multimodal.profiling import MultiModalProfiler from vllm.multimodal.utils import cached_get_tokenizer if mm_registry.has_processor(model_config): diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py index 22ee24e78b4a2..089062ab53fc3 100644 --- a/vllm/model_executor/models/aria.py +++ b/vllm/model_executor/models/aria.py @@ -24,8 +24,8 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, NestedTensors) from vllm.multimodal.parse import MultiModalDataItems -from vllm.multimodal.processing import BaseProcessingInfo, PromptReplacement -from vllm.multimodal.processor import BaseMultiModalProcessor +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptReplacement) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.aria import (AriaMoELMConfig, diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index 0843fdf4853e4..7dfc0b687c6e3 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -18,8 +18,8 @@ MultiModalInputsV2, MultiModalKwargs, NestedTensors, PlaceholderRange) from vllm.multimodal.parse import MultiModalDataItems -from vllm.multimodal.processing import BaseProcessingInfo, PromptReplacement -from vllm.multimodal.processor import BaseMultiModalProcessor +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptReplacement) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index 55712cdeb3417..acff926891bbe 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -31,8 +31,8 @@ MultiModalInputsV2, MultiModalKwargs, NestedTensors, PlaceholderRange) from vllm.multimodal.parse import MultiModalDataItems -from vllm.multimodal.processing import BaseProcessingInfo, PromptReplacement -from vllm.multimodal.processor import BaseMultiModalProcessor +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptReplacement) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import 
IntermediateTensors from vllm.utils import print_warning_once diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index 27b2abb9e3745..59af5f0b3ae98 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -35,8 +35,8 @@ NestedTensors, PlaceholderRange) from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, MultiModalDataItems) -from vllm.multimodal.processing import BaseProcessingInfo, PromptReplacement -from vllm.multimodal.processor import BaseMultiModalProcessor +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptReplacement) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index d6af6a42de1d5..66e75a16c91f5 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -26,9 +26,9 @@ NestedTensors) from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, ImageSize, MultiModalDataItems) -from vllm.multimodal.processing import (BaseProcessingInfo, ProcessingCache, +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, ProcessingCache, PromptReplacement) -from vllm.multimodal.processor import BaseMultiModalProcessor from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py index 78ff774a792f6..5be85d7c0f033 100644 --- a/vllm/model_executor/models/llava_next_video.py +++ b/vllm/model_executor/models/llava_next_video.py @@ -19,8 +19,8 @@ NestedTensors) from vllm.multimodal.parse import (ImageSize, MultiModalDataItems, VideoEmbeddingItems, VideoProcessorItems) -from vllm.multimodal.processing import BaseProcessingInfo, PromptReplacement -from vllm.multimodal.processor import BaseMultiModalProcessor +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptReplacement) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index bd1f920ea8a81..624f805bb2844 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -35,10 +35,10 @@ NestedTensors, PlaceholderRange) from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, ImageSize, MultiModalDataItems) -from vllm.multimodal.processing import (BaseProcessingInfo, +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, BoundPromptReplacement, PlaceholderInfo, PromptReplacement) -from vllm.multimodal.processor import BaseMultiModalProcessor from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index 3271eddfebd23..0dff9595c6c08 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -40,8 +40,8 @@ NestedTensors) from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems, MultiModalDataParser) -from vllm.multimodal.processing import 
BaseProcessingInfo, PromptReplacement -from vllm.multimodal.processor import BaseMultiModalProcessor +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptReplacement) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index cb7a893c7cc90..ba7e8f838d0fe 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -58,8 +58,8 @@ NestedTensors, VideoItem) from vllm.multimodal.parse import (ImageSize, ModalityDataItems, MultiModalDataItems, MultiModalDataParser) -from vllm.multimodal.processing import BaseProcessingInfo, PromptReplacement -from vllm.multimodal.processor import BaseMultiModalProcessor +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptReplacement) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.platforms import _Backend from vllm.sequence import IntermediateTensors diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 1d82f552cec68..fada22d685dd6 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -25,8 +25,8 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, NestedTensors) from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser -from vllm.multimodal.processing import BaseProcessingInfo, PromptReplacement -from vllm.multimodal.processor import BaseMultiModalProcessor +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptReplacement) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.ultravox import UltravoxConfig diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 5571f9fbc61b7..c6a30cacebdd1 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -4,10 +4,12 @@ from collections.abc import Callable, ItemsView, Iterable, Mapping, Sequence from dataclasses import dataclass, field from functools import lru_cache -from typing import NamedTuple, Optional, Protocol, TypeVar, Union +from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol, + TypeVar, Union) -from transformers import PretrainedConfig, ProcessorMixin +from transformers import BatchFeature, PretrainedConfig, ProcessorMixin +import vllm.envs as envs from vllm.inputs import InputProcessingContext from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens, @@ -15,7 +17,13 @@ from vllm.utils import LRUCache, flatten_2d_lists, full_groupby from .hasher import MultiModalHasher -from .inputs import MultiModalKwargsItem, PlaceholderRange +from .inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalInputsV2, MultiModalKwargs, + MultiModalKwargsItem, PlaceholderRange) +from .parse import MultiModalDataItems, MultiModalDataParser + +if TYPE_CHECKING: + from .profiling import BaseDummyInputsBuilder logger = init_logger(__name__) @@ -566,3 +574,546 @@ def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: keys as that returned by :meth:`get_supported_mm_limits`. 
""" raise NotImplementedError + + +_I = TypeVar("_I", bound=BaseProcessingInfo) + + +class BaseMultiModalProcessor(ABC, Generic[_I]): + """ + Abstract base class to process multi-modal inputs to be used in vLLM. + + Not to be confused with :class:`transformers.ProcessorMixin`. + """ + + def __init__(self, + info: _I, + dummy_inputs: "BaseDummyInputsBuilder[_I]", + *, + cache: Optional[ProcessingCache] = None, + enable_sanity_checks: bool = True) -> None: + super().__init__() + + self.info = info + self.dummy_inputs = dummy_inputs + self.cache = cache + self.enable_sanity_checks = enable_sanity_checks + + self.data_parser = self._get_data_parser() + + def __call__( + self, + prompt: str, + mm_data: MultiModalDataDict, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> MultiModalInputsV2: + return self.apply(prompt, mm_data, hf_processor_mm_kwargs) + + def _get_data_parser(self) -> MultiModalDataParser: + """ + Construct a parser to preprocess multi-modal data items + before passing them to :meth:`_get_hf_mm_data`. + + You can support additional modalities by creating a subclass + of :class:`MultiModalDataParser` that has additional subparsers. + """ + return MultiModalDataParser() + + def _to_mm_items( + self, + mm_data: MultiModalDataDict, + ) -> MultiModalDataItems: + """ + Normalize :class:`MultiModalDataDict` to :class:`MultiModalDataItems` + before passing them to :meth:`_get_hf_mm_data`. + """ + mm_items = self.data_parser.parse_mm_data(mm_data) + + mm_limits = self.info.ctx.get_mm_config().limit_per_prompt + for modality, items in mm_items.items(): + limit = mm_limits.get(modality, 1) + if len(items) > limit: + raise ValueError( + f"You set {modality}={limit} (or defaulted to 1) in " + f"`--limit-mm-per-prompt`, but passed {len(items)} " + f"{modality} items in the same prompt.") + + return mm_items + + @abstractmethod + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + """Given the HF-processed data, output the metadata of each field.""" + raise NotImplementedError + + @abstractmethod + def _get_prompt_replacements( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> list[PromptReplacement]: + """ + Given the original multi-modal items for this modality + and HF-processed data, output the replacements to perform. + + Notes: + - You should not assume that HF processor always performs prompt + replacement: in :meth:`_apply_hf_processor_missing`, this method + is called on text-only and multimodal-only inputs separately, + instead of passing them in the same call. + - The replacement information returned by this method is also used + to determine the placeholder token positions for each multi-modal + item. 
+ """ + raise NotImplementedError + + def _find_mm_placeholders( + self, + mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]], + new_token_ids: list[int], + mm_item_counts: Mapping[str, int], + ) -> Mapping[str, list[PlaceholderInfo]]: + return find_mm_placeholders(mm_prompt_repls, new_token_ids, + mm_item_counts) + + def _get_hf_mm_data( + self, + mm_items: MultiModalDataItems, + ) -> tuple[Mapping[str, object], Mapping[str, object]]: + processor_data = dict[str, object]() + passthrough_data = dict[str, object]() + + for items in mm_items.values(): + processor_data.update(items.get_processor_data()) + passthrough_data.update(items.get_passthrough_data()) + + return processor_data, passthrough_data + + def _call_hf_processor( + self, + prompt: str, + # Not to be confused with `mm_data` in `self.apply`. + # This refers to the data to be passed to HF processor. + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + ) -> BatchFeature: + """ + Call the HF processor on the prompt text and + associated multi-modal data. + """ + return self.info.ctx.call_hf_processor( + self.info.get_hf_processor(**mm_kwargs), + dict(text=prompt, **mm_data), + mm_kwargs, + ) + + def _apply_hf_processor( + self, + prompt_text: str, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> tuple[list[int], MultiModalKwargs]: + """ + Wrapper of :meth:`_call_hf_processor` that applies + additional pre-processing and post-processing. + """ + processor_data, passthrough_data = self._get_hf_mm_data(mm_items) + + processed_data = self._call_hf_processor( + prompt=prompt_text, + mm_data=processor_data, + mm_kwargs=hf_processor_mm_kwargs, + ) + processed_data.update(passthrough_data) + + prompt_ids, = processed_data.pop("input_ids").tolist() + + mm_kwargs = MultiModalKwargs.from_hf_inputs( + processed_data, + self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs), + ) + + return prompt_ids, mm_kwargs + + def _apply_hf_processor_missing( + self, + prompt_text: str, + mm_missing_data_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + ): + """ + Apply the HF processor on the full prompt text, but only on the + multi-modal data that are missing from the cache. + + Note: + We pass prompt text and multi-modal data into the HF processor + in separate calls to avoid HF prompt replacement being done for + cached items; instead, we rely on our own prompt replacement logic + (:meth:`_get_prompt_replacements`) for the full text. + """ + mm_missing_counts = mm_missing_data_items.get_all_counts() + + prompt_ids, _ = self._apply_hf_processor( + prompt_text=prompt_text, + mm_items=MultiModalDataItems({}), + hf_processor_mm_kwargs={}, + ) + + # Some HF processors (e.g. Qwen2-VL) expect corresponding + # multi-modal tokens to be in the prompt text + dummy_inputs = self.dummy_inputs.get_dummy_processor_inputs( + self.info.ctx.model_config.max_model_len, + mm_missing_counts, + ) + + _, mm_missing_kwargs = self._apply_hf_processor( + prompt_text=dummy_inputs.prompt_text, + mm_items=mm_missing_data_items, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + ) + + return prompt_ids, mm_missing_kwargs + + def _cached_apply_hf_processor( + self, + prompt_text: str, + mm_data_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> tuple[list[int], MultiModalKwargs]: + """ + Apply the HF processor on the full prompt text, + caching the results and reusing cached results. 
+ """ + cache = self.cache + model_id = self.info.model_id + + _, passthrough_data = self._get_hf_mm_data(mm_data_items) + if cache is None or passthrough_data: + return self._apply_hf_processor( + prompt_text=prompt_text, + mm_items=mm_data_items, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + ) + + mm_maybe_cached_kw_items = { + modality: [ + cache.get(model_id, modality, item, hf_processor_mm_kwargs) + for item in items + ] + for modality, items in mm_data_items.items() + } + + mm_missing_idxs = { + modality: + [idx for idx, item in enumerate(kw_items) if item is None] + for modality, kw_items in mm_maybe_cached_kw_items.items() + } + mm_missing_data = { + modality: [mm_data_items[modality][idx] for idx in idxs] + for modality, idxs in mm_missing_idxs.items() + } + mm_missing_data_items = self._to_mm_items(mm_missing_data) + + prompt_ids, mm_missing_kwargs = self._apply_hf_processor_missing( + prompt_text=prompt_text, + mm_missing_data_items=mm_missing_data_items, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + ) + + mm_missing_next_idx = { + modality: 0 + for modality in mm_missing_data_items + } + + merged_kw_items = list[MultiModalKwargsItem]() + for modality, kw_items in mm_maybe_cached_kw_items.items(): + for idx, kw_item in enumerate(kw_items): + if kw_item is None: + kw_item = mm_missing_kwargs.get_item( + modality, + mm_missing_next_idx[modality], + ) + + cache.put( + model_id, + modality, + mm_data_items[modality][idx], + hf_processor_mm_kwargs, + kw_item, + ) + + mm_missing_next_idx[modality] += 1 + + merged_kw_items.append(kw_item) + + if self.enable_sanity_checks: + mm_missing_counts = mm_missing_data_items.get_all_counts() + assert all( + item_count == mm_missing_counts[modality] + for modality, item_count in mm_missing_next_idx.items()), dict( + mm_missing_next_idx=mm_missing_next_idx, + mm_missing_counts=mm_missing_counts) + + mm_kwargs = MultiModalKwargs.from_items(merged_kw_items) + + return prompt_ids, mm_kwargs + + def _bind_and_group_repls( + self, + prompt_repls: list[PromptReplacement], + ) -> dict[str, list[BoundPromptReplacement]]: + tokenizer = self.info.get_tokenizer() + + it = (prompt_repl.bind(tokenizer) for prompt_repl in prompt_repls) + return dict(full_groupby_modality(it)) + + def _always_apply_prompt_replacements(self) -> bool: + """ + A flag which can be overridden so that + :meth:`_apply_prompt_replacements` is always called even if we + detect that HF has performed processing via + :meth:`_find_placeholders_by_modality`. + + This is useful in cases where :meth:`_find_placeholders_by_modality` + cannot be reliably used to detect whether HF has performed processing. + """ + return False + + def _apply_prompt_replacements( + self, + token_ids: list[int], + mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]], + mm_item_counts: Mapping[str, int], + ) -> tuple[list[int], str, Mapping[str, list[PlaceholderInfo]]]: + tokenizer = self.info.get_tokenizer() + + mm_token_matches = { + modality: find_token_matches(token_ids, prompt_repls) + for modality, prompt_repls in mm_prompt_repls.items() + } + mm_match_counts = { + modality: len(matches) + for modality, matches in mm_token_matches.items() + } + + # If the search text does not represent a special token, + # it may have different token IDs in the prompt, because + # the tokens may go across the boundaries of the search text. + # ---- + # e.g. 
when searching for "foo" in "food", if "food" itself makes + # up a token, then the token ID of "foo" will not appear at all + # ---- + # Since it is inefficient to search for all possible tokenizations + # of the search text in the prompt, we instead perform string + # replacement on the decoded token IDs, then encode them back. + if all( + mm_match_counts.get(modality, 0) >= item_count + for modality, item_count in mm_item_counts.items() + ): # yapf: disable + token_ids = replace_token_matches( + token_ids, + mm_token_matches, + mm_item_counts, + ) + + text = decode_tokens(tokenizer, token_ids) + matched_repls = { + modality: [match.prompt_repl for match in token_matches] + for modality, token_matches in mm_token_matches.items() + } + else: + text = decode_tokens(tokenizer, token_ids) + + mm_text_matches = { + modality: find_text_matches(text, prompt_repls) + for modality, prompt_repls in mm_prompt_repls.items() + } + text = replace_text_matches( + text, + mm_text_matches, + mm_item_counts, + ) + + token_ids = encode_tokens(tokenizer, + text, + add_special_tokens=False) + matched_repls = { + modality: [match.prompt_repl for match in token_matches] + for modality, token_matches in mm_text_matches.items() + } + + placeholders = self._find_mm_placeholders( + matched_repls, + token_ids, + mm_item_counts, + ) + + return token_ids, text, placeholders + + def _validate_mm_kwargs( + self, + mm_kwargs: MultiModalKwargs, + mm_item_counts: Mapping[str, int], + ) -> None: + for modality, item_count in mm_item_counts.items(): + if modality in mm_kwargs.modalities: + items = mm_kwargs.get_items(modality) + else: + items = [] + + if len(items) != item_count: + raise RuntimeError( + f"Expected there to be {item_count} {modality} items in " + f"keyword arguments corresponding to {item_count} " + f"{modality} data items, but only found {len(items)}! " + "There is likely a problem with your " + "implementation of merged multi-modal processor for this " + "model (usually arising from an inconsistency between " + "`_call_hf_processor` and `_get_mm_fields_config`).") + + def _validate_mm_placeholders( + self, + mm_placeholders: Mapping[str, list[PlaceholderInfo]], + mm_item_counts: Mapping[str, int], + *, + allow_missing: bool = False, + ) -> Mapping[str, int]: + missing_repl_counts = dict[str, int]() + + for modality, item_count in mm_item_counts.items(): + placeholders = mm_placeholders.get(modality, []) + + if len(placeholders) != item_count and not allow_missing: + raise RuntimeError( + f"Expected there to be {item_count} prompt replacements " + f"corresponding to {item_count} {modality} items, but only " + f"found {len(placeholders)} prompt replacements! Either " + "the prompt text has missing/incorrect tokens for " + "multi-modal inputs, or there is a problem with your " + "implementation of merged multi-modal processor for this " + "model (usually arising from an inconsistency between " + "`_call_hf_processor` and `_get_prompt_replacements`).") + + missing_repl_counts[modality] = item_count - len(placeholders) + + return missing_repl_counts + + def apply( + self, + prompt_text: str, + mm_data: MultiModalDataDict, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> MultiModalInputsV2: + """ + Process multi-modal inputs to be used in vLLM. + + The main steps are: + + 1. Apply HF Processor on prompt text and multi-modal data together, + outputting token IDs and processed tensors. + 2. Find and replace sequences in the token IDs with placeholder tokens. 
+ The number of placeholder tokens equals the feature size of the + multi-modal data outputted by the multi-modal encoder. + 3. Extract information about the placeholder tokens from the + processed token IDs. + """ + mm_items = self._to_mm_items(mm_data) + + # Create MM hashes (only used in V1) + # TODO: Use these hash keys for caching operations in apply_hf_processor + # instead of rehashing. + + if envs.VLLM_USE_V1: + model_id = self.info.model_id + mm_hashes = { + modality: [ + MultiModalHasher.hash_kwargs(model_id=model_id, + **{modality: item}, + **hf_processor_mm_kwargs) + for item in items + ] + for modality, items in mm_items.items() + } + else: + mm_hashes = None + + prompt_ids, mm_kwargs = self._cached_apply_hf_processor( + prompt_text, + mm_items, + hf_processor_mm_kwargs, + ) + + unbound_prompt_repls = self._get_prompt_replacements( + mm_items, + hf_processor_mm_kwargs, + mm_kwargs, + ) + mm_prompt_repls = self._bind_and_group_repls(unbound_prompt_repls) + + mm_item_counts = mm_items.get_all_counts() + self._validate_mm_kwargs(mm_kwargs, mm_item_counts) + + hf_mm_placeholders = self._find_mm_placeholders( + mm_prompt_repls, + prompt_ids, + mm_item_counts, + ) + + if self._always_apply_prompt_replacements(): + mm_missing_repl_counts = mm_item_counts + mm_missing_repls = dict(mm_prompt_repls) + else: + mm_missing_repl_counts = self._validate_mm_placeholders( + hf_mm_placeholders, + mm_item_counts, + allow_missing=True, + ) + + mm_missing_repls = dict[str, list[BoundPromptReplacement]]() + for modality, missing_repl_count in mm_missing_repl_counts.items(): + if missing_repl_count == 0: + mm_missing_repls[modality] = [] + elif missing_repl_count == mm_item_counts.get(modality, 0): + mm_missing_repls[modality] = mm_prompt_repls[modality] + else: + raise ValueError("Partial prompt replacement within " + f"{modality=} is not supported") + + # If HF processor already inserts placeholder tokens, + # there is no need for us to insert them + if all(len(repls) == 0 for repls in mm_missing_repls.items()): + tokenizer = self.info.get_tokenizer() + prompt_text = decode_tokens(tokenizer, prompt_ids) + mm_placeholders = hf_mm_placeholders + else: + ( + prompt_ids, + prompt_text, + missing_mm_placeholders, + ) = self._apply_prompt_replacements( + prompt_ids, + mm_missing_repls, + mm_missing_repl_counts, + ) + + mm_placeholders = {**hf_mm_placeholders, **missing_mm_placeholders} + + self._validate_mm_placeholders(mm_placeholders, mm_item_counts) + + mm_placeholder_ranges = { + modality: [item.to_range() for item in placeholders] + for modality, placeholders in mm_placeholders.items() + } + + return MultiModalInputsV2( + type="multimodal", + prompt=prompt_text, + prompt_token_ids=prompt_ids, + mm_kwargs=mm_kwargs, + mm_hashes=mm_hashes, + mm_placeholders=mm_placeholder_ranges, + ) diff --git a/vllm/multimodal/processor.py b/vllm/multimodal/processor.py deleted file mode 100644 index 08e6b7337e62d..0000000000000 --- a/vllm/multimodal/processor.py +++ /dev/null @@ -1,562 +0,0 @@ -from abc import ABC, abstractmethod -from collections.abc import Mapping, Sequence -from typing import Any, Generic, Optional, TypeVar - -from transformers import BatchFeature - -from vllm import envs -from vllm.transformers_utils.tokenizer import decode_tokens, encode_tokens - -from .hasher import MultiModalHasher -from .inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalInputsV2, MultiModalKwargs, - MultiModalKwargsItem) -from .parse import MultiModalDataItems, MultiModalDataParser -from .processing 
import (BaseProcessingInfo, BoundPromptReplacement, - PlaceholderInfo, ProcessingCache, PromptReplacement, - find_mm_placeholders, find_text_matches, - find_token_matches, full_groupby_modality, - replace_text_matches, replace_token_matches) -from .profiling import BaseDummyInputsBuilder - -_I = TypeVar("_I", bound=BaseProcessingInfo) - - -class BaseMultiModalProcessor(ABC, Generic[_I]): - """ - Abstract base class to process multi-modal inputs to be used in vLLM. - - Not to be confused with :class:`transformers.ProcessorMixin`. - """ - - def __init__(self, - info: _I, - dummy_inputs: BaseDummyInputsBuilder[_I], - *, - cache: Optional[ProcessingCache] = None, - enable_sanity_checks: bool = True) -> None: - super().__init__() - - self.info = info - self.dummy_inputs = dummy_inputs - self.cache = cache - self.enable_sanity_checks = enable_sanity_checks - - self.data_parser = self._get_data_parser() - - def __call__( - self, - prompt: str, - mm_data: MultiModalDataDict, - hf_processor_mm_kwargs: Mapping[str, object], - ) -> MultiModalInputsV2: - return self.apply(prompt, mm_data, hf_processor_mm_kwargs) - - def _get_data_parser(self) -> MultiModalDataParser: - """ - Construct a parser to preprocess multi-modal data items - before passing them to :meth:`_get_hf_mm_data`. - - You can support additional modalities by creating a subclass - of :class:`MultiModalDataParser` that has additional subparsers. - """ - return MultiModalDataParser() - - def _to_mm_items( - self, - mm_data: MultiModalDataDict, - ) -> MultiModalDataItems: - """ - Normalize :class:`MultiModalDataDict` to :class:`MultiModalDataItems` - before passing them to :meth:`_get_hf_mm_data`. - """ - mm_items = self.data_parser.parse_mm_data(mm_data) - - mm_limits = self.info.ctx.get_mm_config().limit_per_prompt - for modality, items in mm_items.items(): - limit = mm_limits.get(modality, 1) - if len(items) > limit: - raise ValueError( - f"You set {modality}={limit} (or defaulted to 1) in " - f"`--limit-mm-per-prompt`, but passed {len(items)} " - f"{modality} items in the same prompt.") - - return mm_items - - @abstractmethod - def _get_mm_fields_config( - self, - hf_inputs: BatchFeature, - hf_processor_mm_kwargs: Mapping[str, object], - ) -> Mapping[str, MultiModalFieldConfig]: - """Given the HF-processed data, output the metadata of each field.""" - raise NotImplementedError - - @abstractmethod - def _get_prompt_replacements( - self, - mm_items: MultiModalDataItems, - hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, - ) -> list[PromptReplacement]: - """ - Given the original multi-modal items for this modality - and HF-processed data, output the replacements to perform. - - Notes: - - You should not assume that HF processor always performs prompt - replacement: in :meth:`_apply_hf_processor_missing`, this method - is called on text-only and multimodal-only inputs separately, - instead of passing them in the same call. - - The replacement information returned by this method is also used - to determine the placeholder token positions for each multi-modal - item. 
- """ - raise NotImplementedError - - def _find_mm_placeholders( - self, - mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]], - new_token_ids: list[int], - mm_item_counts: Mapping[str, int], - ) -> Mapping[str, list[PlaceholderInfo]]: - return find_mm_placeholders(mm_prompt_repls, new_token_ids, - mm_item_counts) - - def _get_hf_mm_data( - self, - mm_items: MultiModalDataItems, - ) -> tuple[dict[str, Any], dict[str, Any]]: - processor_data = dict[str, Any]() - passthrough_data = dict[str, Any]() - - for items in mm_items.values(): - processor_data.update(items.get_processor_data()) - passthrough_data.update(items.get_passthrough_data()) - - return processor_data, passthrough_data - - def _call_hf_processor( - self, - prompt: str, - # Not to be confused with `mm_data` in `self.apply`. - # This refers to the data to be passed to HF processor. - mm_data: Mapping[str, object], - mm_kwargs: Mapping[str, object], - ) -> BatchFeature: - """ - Call the HF processor on the prompt text and - associated multi-modal data. - """ - return self.info.ctx.call_hf_processor( - self.info.get_hf_processor(**mm_kwargs), - dict(text=prompt, **mm_data), - mm_kwargs, - ) - - def _apply_hf_processor( - self, - prompt_text: str, - mm_items: MultiModalDataItems, - hf_processor_mm_kwargs: Mapping[str, object], - ) -> tuple[list[int], MultiModalKwargs]: - """ - Wrapper of :meth:`_call_hf_processor` that applies - additional pre-processing and post-processing. - """ - processor_data, passthrough_data = self._get_hf_mm_data(mm_items) - - processed_data = self._call_hf_processor( - prompt=prompt_text, - mm_data=processor_data, - mm_kwargs=hf_processor_mm_kwargs, - ) - processed_data.update(passthrough_data) - - prompt_ids, = processed_data.pop("input_ids").tolist() - - mm_kwargs = MultiModalKwargs.from_hf_inputs( - processed_data, - self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs), - ) - - return prompt_ids, mm_kwargs - - def _apply_hf_processor_missing( - self, - prompt_text: str, - mm_missing_data_items: MultiModalDataItems, - hf_processor_mm_kwargs: Mapping[str, object], - ): - """ - Apply the HF processor on the full prompt text, but only on the - multi-modal data that are missing from the cache. - - Note: - We pass prompt text and multi-modal data into the HF processor - in separate calls to avoid HF prompt replacement being done for - cached items; instead, we rely on our own prompt replacement logic - (:meth:`_get_prompt_replacements`) for the full text. - """ - mm_missing_counts = mm_missing_data_items.get_all_counts() - - prompt_ids, _ = self._apply_hf_processor( - prompt_text=prompt_text, - mm_items=MultiModalDataItems({}), - hf_processor_mm_kwargs={}, - ) - - # Some HF processors (e.g. Qwen2-VL) expect corresponding - # multi-modal tokens to be in the prompt text - dummy_inputs = self.dummy_inputs.get_dummy_processor_inputs( - self.info.ctx.model_config.max_model_len, - mm_missing_counts, - ) - - _, mm_missing_kwargs = self._apply_hf_processor( - prompt_text=dummy_inputs.prompt_text, - mm_items=mm_missing_data_items, - hf_processor_mm_kwargs=hf_processor_mm_kwargs, - ) - - return prompt_ids, mm_missing_kwargs - - def _cached_apply_hf_processor( - self, - prompt_text: str, - mm_data_items: MultiModalDataItems, - hf_processor_mm_kwargs: Mapping[str, object], - ) -> tuple[list[int], MultiModalKwargs]: - """ - Apply the HF processor on the full prompt text, - caching the results and reusing cached results. 
- """ - cache = self.cache - model_id = self.info.model_id - - _, passthrough_data = self._get_hf_mm_data(mm_data_items) - if cache is None or passthrough_data: - return self._apply_hf_processor( - prompt_text=prompt_text, - mm_items=mm_data_items, - hf_processor_mm_kwargs=hf_processor_mm_kwargs, - ) - - mm_maybe_cached_kw_items = { - modality: [ - cache.get(model_id, modality, item, hf_processor_mm_kwargs) - for item in items - ] - for modality, items in mm_data_items.items() - } - - mm_missing_idxs = { - modality: - [idx for idx, item in enumerate(kw_items) if item is None] - for modality, kw_items in mm_maybe_cached_kw_items.items() - } - mm_missing_data = { - modality: [mm_data_items[modality][idx] for idx in idxs] - for modality, idxs in mm_missing_idxs.items() - } - mm_missing_data_items = self._to_mm_items(mm_missing_data) - - prompt_ids, mm_missing_kwargs = self._apply_hf_processor_missing( - prompt_text=prompt_text, - mm_missing_data_items=mm_missing_data_items, - hf_processor_mm_kwargs=hf_processor_mm_kwargs, - ) - - mm_missing_next_idx = { - modality: 0 - for modality in mm_missing_data_items - } - - merged_kw_items = list[MultiModalKwargsItem]() - for modality, kw_items in mm_maybe_cached_kw_items.items(): - for idx, kw_item in enumerate(kw_items): - if kw_item is None: - kw_item = mm_missing_kwargs.get_item( - modality, - mm_missing_next_idx[modality], - ) - - cache.put( - model_id, - modality, - mm_data_items[modality][idx], - hf_processor_mm_kwargs, - kw_item, - ) - - mm_missing_next_idx[modality] += 1 - - merged_kw_items.append(kw_item) - - if self.enable_sanity_checks: - mm_missing_counts = mm_missing_data_items.get_all_counts() - assert all( - item_count == mm_missing_counts[modality] - for modality, item_count in mm_missing_next_idx.items()), dict( - mm_missing_next_idx=mm_missing_next_idx, - mm_missing_counts=mm_missing_counts) - - mm_kwargs = MultiModalKwargs.from_items(merged_kw_items) - - return prompt_ids, mm_kwargs - - def _bind_and_group_repls( - self, - prompt_repls: list[PromptReplacement], - ) -> dict[str, list[BoundPromptReplacement]]: - tokenizer = self.info.get_tokenizer() - - it = (prompt_repl.bind(tokenizer) for prompt_repl in prompt_repls) - return dict(full_groupby_modality(it)) - - def _always_apply_prompt_replacements(self) -> bool: - """ - A flag which can be overridden so that - :meth:`_apply_prompt_replacements` is always called even if we - detect that HF has performed processing via - :meth:`_find_placeholders_by_modality`. - - This is useful in cases where :meth:`_find_placeholders_by_modality` - cannot be reliably used to detect whether HF has performed processing. - """ - return False - - def _apply_prompt_replacements( - self, - token_ids: list[int], - mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]], - mm_item_counts: Mapping[str, int], - ) -> tuple[list[int], str, Mapping[str, list[PlaceholderInfo]]]: - tokenizer = self.info.get_tokenizer() - - mm_token_matches = { - modality: find_token_matches(token_ids, prompt_repls) - for modality, prompt_repls in mm_prompt_repls.items() - } - mm_match_counts = { - modality: len(matches) - for modality, matches in mm_token_matches.items() - } - - # If the search text does not represent a special token, - # it may have different token IDs in the prompt, because - # the tokens may go across the boundaries of the search text. - # ---- - # e.g. 
when searching for "foo" in "food", if "food" itself makes - # up a token, then the token ID of "foo" will not appear at all - # ---- - # Since it is inefficient to search for all possible tokenizations - # of the search text in the prompt, we instead perform string - # replacement on the decoded token IDs, then encode them back. - if all( - mm_match_counts.get(modality, 0) >= item_count - for modality, item_count in mm_item_counts.items() - ): # yapf: disable - token_ids = replace_token_matches( - token_ids, - mm_token_matches, - mm_item_counts, - ) - - text = decode_tokens(tokenizer, token_ids) - matched_repls = { - modality: [match.prompt_repl for match in token_matches] - for modality, token_matches in mm_token_matches.items() - } - else: - text = decode_tokens(tokenizer, token_ids) - - mm_text_matches = { - modality: find_text_matches(text, prompt_repls) - for modality, prompt_repls in mm_prompt_repls.items() - } - text = replace_text_matches( - text, - mm_text_matches, - mm_item_counts, - ) - - token_ids = encode_tokens(tokenizer, - text, - add_special_tokens=False) - matched_repls = { - modality: [match.prompt_repl for match in token_matches] - for modality, token_matches in mm_text_matches.items() - } - - placeholders = self._find_mm_placeholders( - matched_repls, - token_ids, - mm_item_counts, - ) - - return token_ids, text, placeholders - - def _validate_mm_kwargs( - self, - mm_kwargs: MultiModalKwargs, - mm_item_counts: Mapping[str, int], - ) -> None: - for modality, item_count in mm_item_counts.items(): - if modality in mm_kwargs.modalities: - items = mm_kwargs.get_items(modality) - else: - items = [] - - if len(items) != item_count: - raise RuntimeError( - f"Expected there to be {item_count} {modality} items in " - f"keyword arguments corresponding to {item_count} " - f"{modality} data items, but only found {len(items)}! " - "There is likely a problem with your " - "implementation of merged multi-modal processor for this " - "model (usually arising from an inconsistency between " - "`_call_hf_processor` and `_get_mm_fields_config`).") - - def _validate_mm_placeholders( - self, - mm_placeholders: Mapping[str, list[PlaceholderInfo]], - mm_item_counts: Mapping[str, int], - *, - allow_missing: bool = False, - ) -> Mapping[str, int]: - missing_repl_counts = dict[str, int]() - - for modality, item_count in mm_item_counts.items(): - placeholders = mm_placeholders.get(modality, []) - - if len(placeholders) != item_count and not allow_missing: - raise RuntimeError( - f"Expected there to be {item_count} prompt replacements " - f"corresponding to {item_count} {modality} items, but only " - f"found {len(placeholders)} prompt replacements! Either " - "the prompt text has missing/incorrect tokens for " - "multi-modal inputs, or there is a problem with your " - "implementation of merged multi-modal processor for this " - "model (usually arising from an inconsistency between " - "`_call_hf_processor` and `_get_prompt_replacements`).") - - missing_repl_counts[modality] = item_count - len(placeholders) - - return missing_repl_counts - - def apply( - self, - prompt_text: str, - mm_data: MultiModalDataDict, - hf_processor_mm_kwargs: Mapping[str, object], - ) -> MultiModalInputsV2: - """ - Process multi-modal inputs to be used in vLLM. - - The main steps are: - - 1. Apply HF Processor on prompt text and multi-modal data together, - outputting token IDs and processed tensors. - 2. Find and replace sequences in the token IDs with placeholder tokens. 
- The number of placeholder tokens equals the feature size of the - multi-modal data outputted by the multi-modal encoder. - 3. Extract information about the placeholder tokens from the - processed token IDs. - """ - mm_items = self._to_mm_items(mm_data) - - # Create MM hashes (only used in V1) - # TODO: Use these hash keys for caching operations in apply_hf_processor - # instead of rehashing. - - if envs.VLLM_USE_V1: - model_id = self.info.model_id - mm_hashes = { - modality: [ - MultiModalHasher.hash_kwargs(model_id=model_id, - **{modality: item}, - **hf_processor_mm_kwargs) - for item in items - ] - for modality, items in mm_items.items() - } - else: - mm_hashes = None - - prompt_ids, mm_kwargs = self._cached_apply_hf_processor( - prompt_text, - mm_items, - hf_processor_mm_kwargs, - ) - - unbound_prompt_repls = self._get_prompt_replacements( - mm_items, - hf_processor_mm_kwargs, - mm_kwargs, - ) - mm_prompt_repls = self._bind_and_group_repls(unbound_prompt_repls) - - mm_item_counts = mm_items.get_all_counts() - self._validate_mm_kwargs(mm_kwargs, mm_item_counts) - - hf_mm_placeholders = self._find_mm_placeholders( - mm_prompt_repls, - prompt_ids, - mm_item_counts, - ) - - if self._always_apply_prompt_replacements(): - mm_missing_repl_counts = mm_item_counts - mm_missing_repls = dict(mm_prompt_repls) - else: - mm_missing_repl_counts = self._validate_mm_placeholders( - hf_mm_placeholders, - mm_item_counts, - allow_missing=True, - ) - - mm_missing_repls = dict[str, list[BoundPromptReplacement]]() - for modality, missing_repl_count in mm_missing_repl_counts.items(): - if missing_repl_count == 0: - mm_missing_repls[modality] = [] - elif missing_repl_count == mm_item_counts.get(modality, 0): - mm_missing_repls[modality] = mm_prompt_repls[modality] - else: - raise ValueError("Partial prompt replacement within " - f"{modality=} is not supported") - - # If HF processor already inserts placeholder tokens, - # there is no need for us to insert them - if all(len(repls) == 0 for repls in mm_missing_repls.items()): - tokenizer = self.info.get_tokenizer() - prompt_text = decode_tokens(tokenizer, prompt_ids) - mm_placeholders = hf_mm_placeholders - else: - ( - prompt_ids, - prompt_text, - missing_mm_placeholders, - ) = self._apply_prompt_replacements( - prompt_ids, - mm_missing_repls, - mm_missing_repl_counts, - ) - - mm_placeholders = {**hf_mm_placeholders, **missing_mm_placeholders} - - self._validate_mm_placeholders(mm_placeholders, mm_item_counts) - - mm_placeholder_ranges = { - modality: [item.to_range() for item in placeholders] - for modality, placeholders in mm_placeholders.items() - } - - return MultiModalInputsV2( - type="multimodal", - prompt=prompt_text, - prompt_token_ids=prompt_ids, - mm_kwargs=mm_kwargs, - mm_hashes=mm_hashes, - mm_placeholders=mm_placeholder_ranges, - ) diff --git a/vllm/multimodal/profiler.py b/vllm/multimodal/profiler.py deleted file mode 100644 index d1984c6b0546c..0000000000000 --- a/vllm/multimodal/profiler.py +++ /dev/null @@ -1,136 +0,0 @@ -from collections.abc import Mapping -from typing import Generic, TypeVar - -from vllm import envs -from vllm.inputs import DummyData -from vllm.logger import init_logger - -from .inputs import MultiModalInputsV2 -from .processing import BaseProcessingInfo -from .processor import BaseMultiModalProcessor -from .profiling import BaseDummyInputsBuilder - -logger = init_logger(__name__) - -_I = TypeVar("_I", bound=BaseProcessingInfo) - - -class MultiModalProfiler(Generic[_I]): - """ - Contains code for running memory profiling for 
multi-modal models. - """ - - def __init__( - self, - processor: BaseMultiModalProcessor[_I], - ) -> None: - super().__init__() - - self.processor = processor - - @property - def processing_info(self) -> BaseProcessingInfo: - return self.processor.info - - @property - def dummy_inputs(self) -> BaseDummyInputsBuilder[_I]: - return self.processor.dummy_inputs - - def _get_mm_limits(self) -> Mapping[str, int]: - mm_config = self.processing_info.ctx.get_mm_config() - mm_limit_per_prompt = mm_config.limit_per_prompt - - supported_mm_limits = self.processing_info.get_supported_mm_limits() - - mm_limits = { - modality: mm_limit_per_prompt.get(modality, 1) - for modality in supported_mm_limits - } - - for modality, supported_limit in supported_mm_limits.items(): - limit = mm_limits[modality] - if supported_limit is not None and supported_limit < limit: - raise ValueError( - f"You set {modality}={limit} (or defaulted to 1) in " - f"`--limit-mm-per-prompt`, but this model only supports " - f"at most {supported_limit} {modality} items.") - - return mm_limits - - def _get_dummy_mm_inputs( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> MultiModalInputsV2: - factory = self.dummy_inputs - processor_inputs = factory.get_dummy_processor_inputs( - seq_len, mm_counts) - - return self.processor.apply( - prompt_text=processor_inputs.prompt_text, - mm_data=processor_inputs.mm_data, - hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs, - ) - - def get_dummy_data(self, seq_len: int) -> DummyData: - # Avoid circular import - from vllm.sequence import SequenceData - - mm_counts = self._get_mm_limits() - - info = self.processing_info - mm_max_tokens_per_item = info.get_mm_max_tokens_per_item(seq_len) - - if mm_counts.keys() != mm_max_tokens_per_item.keys(): - raise AssertionError( - "The keys returned by `get_supported_mm_limits`" - f"({set(mm_counts.keys())}) should be the same as those " - "returned by `get_mm_max_tokens_per_item` " - f"({set(mm_max_tokens_per_item.keys())})") - - mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts) - prompt_token_ids = mm_inputs["prompt_token_ids"] - placeholders_by_modality = mm_inputs["mm_placeholders"] - - total_placeholders_by_modality = { - modality: sum(item["length"] for item in placeholders) - for modality, placeholders in placeholders_by_modality.items() - } - expected_placeholders_by_modality = { - modality: mm_max_tokens_per_item[modality] * mm_counts[modality] - for modality in placeholders_by_modality - } - if total_placeholders_by_modality != expected_placeholders_by_modality: - raise AssertionError( - f"The processed dummy data has a total of " - f"{total_placeholders_by_modality} placeholder tokens, which " - f"is not the expected {expected_placeholders_by_modality} " - "tokens.") - - total_len = len(prompt_token_ids) - - # V0 does not support chunked prefill. - if total_len > seq_len and not envs.VLLM_USE_V1: - logger.warning( - "The context length (%d) of the model is too short " - "to hold the multi-modal embeddings in the worst case " - "(%d tokens in total, out of which %s are reserved for " - "multi-modal embeddings). This may cause certain multi-modal " - "inputs to fail during inference, even when the input text is " - "short. 
To avoid this, you should increase `max_model_len`, " - "reduce `max_num_seqs`, and/or reduce `mm_counts`.", seq_len, - total_len, total_placeholders_by_modality) - - return DummyData( - seq_data=SequenceData.from_prompt_token_counts((0, seq_len)), - multi_modal_data=None, - multi_modal_placeholders=None, - ) - - prompt_token_ids.extend([0] * (seq_len - len(prompt_token_ids))) - - return DummyData( - seq_data=SequenceData.from_seqs(prompt_token_ids), - multi_modal_data=mm_inputs["mm_kwargs"], - multi_modal_placeholders=placeholders_by_modality, - ) diff --git a/vllm/multimodal/profiling.py b/vllm/multimodal/profiling.py index de68b335cebd8..2ac3a6bcf3ddd 100644 --- a/vllm/multimodal/profiling.py +++ b/vllm/multimodal/profiling.py @@ -7,8 +7,14 @@ import numpy.typing as npt from PIL import Image -from .inputs import MultiModalDataDict -from .processing import BaseProcessingInfo +import vllm.envs as envs +from vllm.inputs import DummyData +from vllm.logger import init_logger + +from .inputs import MultiModalDataDict, MultiModalInputsV2 +from .processing import BaseMultiModalProcessor, BaseProcessingInfo + +logger = init_logger(__name__) @dataclass @@ -74,3 +80,124 @@ def _get_dummy_videos( ) -> list[npt.NDArray]: video = np.zeros((num_frames, width, height, 3)) return [video] * num_videos + + +class MultiModalProfiler(Generic[_I]): + """ + Contains code for running memory profiling for multi-modal models. + """ + + def __init__( + self, + processor: BaseMultiModalProcessor[_I], + ) -> None: + super().__init__() + + self.processor = processor + + @property + def processing_info(self) -> BaseProcessingInfo: + return self.processor.info + + @property + def dummy_inputs(self) -> BaseDummyInputsBuilder[_I]: + return self.processor.dummy_inputs + + def _get_mm_limits(self) -> Mapping[str, int]: + mm_config = self.processing_info.ctx.get_mm_config() + mm_limit_per_prompt = mm_config.limit_per_prompt + + supported_mm_limits = self.processing_info.get_supported_mm_limits() + + mm_limits = { + modality: mm_limit_per_prompt.get(modality, 1) + for modality in supported_mm_limits + } + + for modality, supported_limit in supported_mm_limits.items(): + limit = mm_limits[modality] + if supported_limit is not None and supported_limit < limit: + raise ValueError( + f"You set {modality}={limit} (or defaulted to 1) in " + f"`--limit-mm-per-prompt`, but this model only supports " + f"at most {supported_limit} {modality} items.") + + return mm_limits + + def _get_dummy_mm_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalInputsV2: + factory = self.dummy_inputs + processor_inputs = factory.get_dummy_processor_inputs( + seq_len, mm_counts) + + return self.processor.apply( + prompt_text=processor_inputs.prompt_text, + mm_data=processor_inputs.mm_data, + hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs, + ) + + def get_dummy_data(self, seq_len: int) -> DummyData: + # Avoid circular import + from vllm.sequence import SequenceData + + mm_counts = self._get_mm_limits() + + info = self.processing_info + mm_max_tokens_per_item = info.get_mm_max_tokens_per_item(seq_len) + + if mm_counts.keys() != mm_max_tokens_per_item.keys(): + raise AssertionError( + "The keys returned by `get_supported_mm_limits`" + f"({set(mm_counts.keys())}) should be the same as those " + "returned by `get_mm_max_tokens_per_item` " + f"({set(mm_max_tokens_per_item.keys())})") + + mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts) + prompt_token_ids = mm_inputs["prompt_token_ids"] + 
placeholders_by_modality = mm_inputs["mm_placeholders"] + + total_placeholders_by_modality = { + modality: sum(item["length"] for item in placeholders) + for modality, placeholders in placeholders_by_modality.items() + } + expected_placeholders_by_modality = { + modality: mm_max_tokens_per_item[modality] * mm_counts[modality] + for modality in placeholders_by_modality + } + if total_placeholders_by_modality != expected_placeholders_by_modality: + raise AssertionError( + f"The processed dummy data has a total of " + f"{total_placeholders_by_modality} placeholder tokens, which " + f"is not the expected {expected_placeholders_by_modality} " + "tokens.") + + total_len = len(prompt_token_ids) + + # V0 does not support chunked prefill. + if total_len > seq_len and not envs.VLLM_USE_V1: + logger.warning( + "The context length (%d) of the model is too short " + "to hold the multi-modal embeddings in the worst case " + "(%d tokens in total, out of which %s are reserved for " + "multi-modal embeddings). This may cause certain multi-modal " + "inputs to fail during inference, even when the input text is " + "short. To avoid this, you should increase `max_model_len`, " + "reduce `max_num_seqs`, and/or reduce `mm_counts`.", seq_len, + total_len, total_placeholders_by_modality) + + return DummyData( + seq_data=SequenceData.from_prompt_token_counts((0, seq_len)), + multi_modal_data=None, + multi_modal_placeholders=None, + ) + + prompt_token_ids.extend([0] * (seq_len - len(prompt_token_ids))) + + return DummyData( + seq_data=SequenceData.from_seqs(prompt_token_ids), + multi_modal_data=mm_inputs["mm_kwargs"], + multi_modal_placeholders=placeholders_by_modality, + ) diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index 074182ec1173c..5f01eac4edade 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -15,8 +15,8 @@ from .base import MultiModalInputMapper, MultiModalPlugin, MultiModalTokensCalc from .image import ImagePlugin from .inputs import MultiModalDataDict, MultiModalKwargs, NestedTensors -from .processing import ProcessingCache -from .processor import BaseMultiModalProcessor, BaseProcessingInfo +from .processing import (BaseMultiModalProcessor, BaseProcessingInfo, + ProcessingCache) from .profiling import BaseDummyInputsBuilder from .utils import cached_get_tokenizer from .video import VideoPlugin From 90c2547f1044695db18bfcfdb95647f27baf91e7 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 8 Jan 2025 09:01:09 +0000 Subject: [PATCH 7/7] Fix Phi3V and Qwen2-VL tests Signed-off-by: DarkLight1337 --- .../vision_language/processing/test_phi3v.py | 24 ++++++++----------- .../processing/test_qwen2_vl.py | 22 +++++++---------- vllm/model_executor/models/phi3v.py | 6 ++++- vllm/model_executor/models/qwen2_vl.py | 14 +++++++++-- 4 files changed, 35 insertions(+), 31 deletions(-) diff --git a/tests/models/decoder_only/vision_language/processing/test_phi3v.py b/tests/models/decoder_only/vision_language/processing/test_phi3v.py index 249045b3c04ce..c5b77260c6544 100644 --- a/tests/models/decoder_only/vision_language/processing/test_phi3v.py +++ b/tests/models/decoder_only/vision_language/processing/test_phi3v.py @@ -1,21 +1,13 @@ """Tests for phi3v's multimodal preprocessing kwargs.""" import pytest -from transformers import AutoTokenizer -from vllm.inputs import InputProcessingContext -from vllm.model_executor.models.phi3v import _IMAGE_TOKEN_ID +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.utils import cached_get_tokenizer from 
.....conftest import _ImageAssets from ....utils import build_model_context -# Wrap lazy imports to avoid initializing CUDA during test collection -@pytest.fixture() -def processor_for_phi3v(): - from vllm.model_executor.models.phi3v import Phi3VMultiModalProcessor - return Phi3VMultiModalProcessor - - @pytest.mark.parametrize("model_id", ["microsoft/Phi-3.5-vision-instruct"]) # yapf: disable @pytest.mark.parametrize( @@ -29,7 +21,6 @@ def processor_for_phi3v(): # yapf: enable @pytest.mark.parametrize("num_imgs", [1, 2]) def test_processor_override( - processor_for_phi3v, image_assets: _ImageAssets, model_id: str, mm_processor_kwargs: dict[str, int], @@ -37,21 +28,26 @@ def test_processor_override( num_imgs: int, ): """Ensure input_processor_for_phi3v handles num_crops properly.""" + # Avoid initializing CUDA early + from vllm.model_executor.models.phi3v import _IMAGE_TOKEN_ID + ctx = build_model_context( model_name=model_id, tokenizer_name=model_id, trust_remote_code=True, limit_mm_per_prompt={"image": num_imgs}, ) - tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) - ctx = InputProcessingContext(ctx.model_config, tokenizer) + tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer) + processor = MULTIMODAL_REGISTRY.create_processor( + ctx.model_config, + tokenizer=tokenizer, + ) # Build the image str / prompt based on the number of images we pass img_str = "".join([f"<|image_{idx}|>\n" for idx in range(1, num_imgs + 1)]) prompt = f"<|user|>\n{img_str}<|end|>\n<|assistant|>\n" mm_data = {"image": [image_assets[0].pil_image] * num_imgs} - processor = processor_for_phi3v(ctx) processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs) # Ensure we have the right number of placeholders per num_crops size diff --git a/tests/models/decoder_only/vision_language/processing/test_qwen2_vl.py b/tests/models/decoder_only/vision_language/processing/test_qwen2_vl.py index b9ac887edf90f..0d54802f2b733 100644 --- a/tests/models/decoder_only/vision_language/processing/test_qwen2_vl.py +++ b/tests/models/decoder_only/vision_language/processing/test_qwen2_vl.py @@ -1,19 +1,12 @@ import pytest -from transformers import AutoTokenizer -from vllm.inputs import InputProcessingContext +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.utils import cached_get_tokenizer from .....conftest import _ImageAssets from ....utils import build_model_context -# Fixtures lazy import to avoid initializing CUDA during test collection -@pytest.fixture() -def processor_for_qwen2_vl(): - from vllm.model_executor.models.qwen2_vl import Qwen2VLMultiModalProcessor - return Qwen2VLMultiModalProcessor - - @pytest.mark.parametrize("model_id", ["Qwen/Qwen2-VL-2B-Instruct"]) # yapf: disable @pytest.mark.parametrize( @@ -24,7 +17,6 @@ def processor_for_qwen2_vl(): # yapf: enable @pytest.mark.parametrize("num_imgs", [1, 2]) def test_processor_override( - processor_for_qwen2_vl, image_assets: _ImageAssets, model_id: str, mm_processor_kwargs: dict[str, object], @@ -39,18 +31,20 @@ def test_processor_override( mm_processor_kwargs=None, limit_mm_per_prompt={"image": num_imgs}, ) - tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) - ctx = InputProcessingContext(ctx.model_config, tokenizer) + tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer) + processor = MULTIMODAL_REGISTRY.create_processor( + ctx.model_config, + tokenizer=tokenizer, + ) # Build the image str / prompt based on the number of images we pass prompt = 
"<|vision_start|><|image_pad|><|vision_end|>" * num_imgs mm_data = {"image": [image_assets[0].pil_image] * num_imgs} - processor = processor_for_qwen2_vl(ctx) processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs) # Ensure we have the right number of placeholders per num_crops size - hf_processor = processor._get_hf_processor(**mm_processor_kwargs) + hf_processor = processor.info.get_hf_processor(**mm_processor_kwargs) image_token_id = tokenizer.convert_tokens_to_ids(hf_processor.image_token) img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id) pixel_shape = processed_inputs["mm_kwargs"]["pixel_values"].shape diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 624f805bb2844..a1b1af35604db 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -322,6 +322,7 @@ def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: max_image_tokens = self.get_num_image_tokens( image_width=target_width, image_height=target_height, + processor=None, ) return {"image": max_image_tokens} @@ -331,8 +332,10 @@ def get_num_image_tokens( *, image_width: int, image_height: int, + processor: Optional[ProcessorMixin], ) -> int: - processor = self.get_hf_processor() + if processor is None: + processor = self.get_hf_processor() return processor.calc_num_image_tokens_from_image_size( # type: ignore width=image_width, @@ -431,6 +434,7 @@ def get_replacement_phi3v(item_idx: int): num_image_tokens = self.info.get_num_image_tokens( image_width=image_size.width, image_height=image_size.height, + processor=hf_processor, ) return [_IMAGE_TOKEN_ID] * num_image_tokens + [bos_token_id] diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index ba7e8f838d0fe..8537fec854b6d 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -763,15 +763,17 @@ def _get_vision_info( image_height: int, num_frames: int = 1, do_resize: bool = True, + image_processor: Optional[Qwen2VLImageProcessor], ) -> tuple[ImageSize, int]: + if image_processor is None: + image_processor = self.get_image_processor() + hf_config = self.get_hf_config() vision_config = hf_config.vision_config patch_size = vision_config.patch_size merge_size = vision_config.spatial_merge_size temporal_patch_size = vision_config.temporal_patch_size - image_processor = self.get_image_processor() - if do_resize: resized_height, resized_width = smart_resize( height=image_height, @@ -800,10 +802,12 @@ def get_num_image_tokens( *, image_width: int, image_height: int, + image_processor: Optional[Qwen2VLImageProcessor], ) -> int: _, num_image_tokens = self._get_vision_info( image_width=image_width, image_height=image_height, + image_processor=image_processor, ) return num_image_tokens @@ -813,11 +817,13 @@ def get_num_video_tokens( image_width: int, image_height: int, num_frames: int, + image_processor: Optional[Qwen2VLImageProcessor], ) -> int: _, num_video_tokens = self._get_vision_info( image_width=image_width, image_height=image_height, num_frames=num_frames, + image_processor=image_processor, ) return num_video_tokens @@ -825,6 +831,7 @@ def get_image_size_with_most_features(self) -> ImageSize: max_image_size, _ = self._get_vision_info( image_width=9999999, image_height=9999999, + image_processor=None, ) return max_image_size @@ -834,6 +841,7 @@ def get_max_image_tokens(self) -> int: return self.get_num_image_tokens( image_width=target_width, image_height=target_height, 
+ image_processor=None, ) def _get_max_video_frames(self, max_tokens: int) -> int: @@ -847,6 +855,7 @@ def _get_max_video_frames(self, max_tokens: int) -> int: image_width=target_width, image_height=target_height, num_frames=next_num_frames, + image_processor=None, ) if next_max_tokens > max_tokens: @@ -880,6 +889,7 @@ def get_max_video_tokens(self, seq_len: int) -> int: image_width=target_width, image_height=target_height, num_frames=self.get_num_frames_with_most_features(seq_len), + image_processor=None, )
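
Usage sketch (illustrative only, not part of the patch): after this series, a
multi-modal processor is obtained from the registry rather than instantiated
directly from the model module, and memory profiling goes through
MultiModalProfiler in vllm/multimodal/profiling.py, following the call pattern
used by the updated tests above. The snippet below assumes `model_config` is an
already-built vLLM ModelConfig for a registered multi-modal model (e.g. built
via the test helper build_model_context); the prompt text and image are
placeholders.

    from PIL import Image

    from vllm.multimodal import MULTIMODAL_REGISTRY
    from vllm.multimodal.profiling import MultiModalProfiler
    from vllm.multimodal.utils import cached_get_tokenizer

    # `model_config` is assumed to exist already; it is not constructed here.
    processor = MULTIMODAL_REGISTRY.create_processor(
        model_config,
        tokenizer=cached_get_tokenizer(model_config.tokenizer),
    )

    # Processing path: run the HF processor and prompt replacements via
    # BaseMultiModalProcessor.apply(); the prompt and image are placeholders.
    inputs = processor.apply(
        prompt_text="placeholder prompt containing the model's image token",
        mm_data={"image": [Image.new("RGB", (336, 336))]},
        hf_processor_mm_kwargs={},
    )
    token_ids = inputs["prompt_token_ids"]
    placeholders = inputs["mm_placeholders"]

    # Profiling path: build worst-case dummy data for memory profiling.
    profiler = MultiModalProfiler(processor)
    dummy_data = profiler.get_dummy_data(seq_len=4096)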