diff --git a/tests/multimodal/test_processing.py b/tests/multimodal/test_processing.py
index ab10b479cb58c..b2367060c6c1b 100644
--- a/tests/multimodal/test_processing.py
+++ b/tests/multimodal/test_processing.py
@@ -41,10 +41,10 @@ def test_iter_token_runs(token_ids, expected):
     print("result:", result)
 
     # Manually constructed results
-    assert result == expected
+    assert [item._asdict() for item in result] == expected
 
     # Invariants
-    assert sum(run_info["length"] for run_info in result) == len(token_ids)
+    assert sum(run_info.length for run_info in result) == len(token_ids)
 
 
 # yapf: disable
diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py
index 389c751bb9788..037603bca3d19 100644
--- a/vllm/multimodal/processing.py
+++ b/vllm/multimodal/processing.py
@@ -1,10 +1,10 @@
 import re
 from abc import ABC, abstractmethod
+from collections.abc import Callable, ItemsView, Iterable, Mapping, Sequence
 from dataclasses import dataclass
 from functools import lru_cache
 from itertools import groupby
-from typing import (Any, Callable, Generic, Iterable, Mapping, NamedTuple,
-                    Optional, Sequence, TypeVar, Union)
+from typing import Any, Generic, NamedTuple, Optional, Protocol, TypeVar, Union
 
 import numpy as np
 from transformers import BatchFeature
@@ -18,72 +18,21 @@
                          MultiModalInputsV2, MultiModalKwargs,
                          PlaceholderRange, VideoItem)
 
-
-def _encode(
-    tokenizer: AnyTokenizer,
-    text: str,
-    *,
-    add_special_tokens: bool = False,
-) -> list[int]:
-    """
-    Backend-agnostic equivalent of HF's
-    :code:`tokenizer.encode(text, add_special_tokens=...)`.
-    """
-    if isinstance(tokenizer, MistralTokenizer):
-        return tokenizer.tokenizer.encode(text,
-                                          bos=add_special_tokens,
-                                          eos=add_special_tokens)
-
-    return tokenizer.encode(text, add_special_tokens=add_special_tokens)
-
-
-@lru_cache(maxsize=2048)
-def _cached_encode(
-    tokenizer: AnyTokenizer,
-    text: str,
-    *,
-    add_special_tokens: bool = False,
-) -> list[int]:
-    return _encode(tokenizer, text, add_special_tokens=add_special_tokens)
-
-
-def _decode(
-    tokenizer: AnyTokenizer,
-    token_ids: list[int],
-    *,
-    skip_special_tokens: bool = False,
-) -> str:
-    """
-    Backend-agnostic equivalent of HF's
-    :code:`tokenizer.decode(token_ids, skip_special_tokens=...)`.
-    """
-    return tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
-
-
-@lru_cache(maxsize=2048)
-def _cached_decode(
-    tokenizer: AnyTokenizer,
-    token_ids: tuple[int, ...],
-    *,
-    skip_special_tokens: bool = False,
-) -> str:
-    return _decode(tokenizer,
-                   list(token_ids),
-                   skip_special_tokens=skip_special_tokens)
-
-
 PromptSegment: TypeAlias = Union[str, list[int]]
 
 
 def bind_segment(
-    prompt_segment: PromptSegment,
+    segment: PromptSegment,
     tokenizer: AnyTokenizer,
 ) -> "_BoundPromptSegment":
+    """
+    Bind a text or token prompt to a tokenizer so that it can be
+    lazily converted into the other format on demand.
+    """
     return _BoundPromptSegment(
         tokenizer=tokenizer,
-        _text=prompt_segment if isinstance(prompt_segment, str) else None,
-        _token_ids=prompt_segment
-        if isinstance(prompt_segment, list) else None,
+        _text=segment if isinstance(segment, str) else None,
+        _token_ids=segment if isinstance(segment, list) else None,
     )
 
 
@@ -163,6 +112,78 @@ class MultiModalProcessingMetadataBuiltins(TypedDict, total=False):
     """
 
 
+def _encode(
+    tokenizer: AnyTokenizer,
+    text: str,
+    *,
+    add_special_tokens: bool = False,
+) -> list[int]:
+    """
+    Backend-agnostic equivalent of HF's
+    :code:`tokenizer.encode(text, add_special_tokens=...)`.
+    """
+    if isinstance(tokenizer, MistralTokenizer):
+        return tokenizer.tokenizer.encode(text,
+                                          bos=add_special_tokens,
+                                          eos=add_special_tokens)
+
+    return tokenizer.encode(text, add_special_tokens=add_special_tokens)
+
+
+@lru_cache(maxsize=2048)
+def _cached_encode(
+    tokenizer: AnyTokenizer,
+    text: str,
+    *,
+    add_special_tokens: bool = False,
+) -> list[int]:
+    return _encode(tokenizer, text, add_special_tokens=add_special_tokens)
+
+
+def _decode(
+    tokenizer: AnyTokenizer,
+    token_ids: list[int],
+    *,
+    skip_special_tokens: bool = False,
+) -> str:
+    """
+    Backend-agnostic equivalent of HF's
+    :code:`tokenizer.decode(token_ids, skip_special_tokens=...)`.
+    """
+    return tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
+
+
+@lru_cache(maxsize=2048)
+def _cached_decode(
+    tokenizer: AnyTokenizer,
+    token_ids: tuple[int, ...],
+    *,
+    skip_special_tokens: bool = False,
+) -> str:
+    return _decode(tokenizer,
+                   list(token_ids),
+                   skip_special_tokens=skip_special_tokens)
+
+
+class _HasModalityAttr(Protocol):
+    modality: str
+
+
+class _HasModalityProp(Protocol):
+
+    @property
+    def modality(self) -> str:
+        ...
+
+
+_M = TypeVar("_M", bound=Union[_HasModalityAttr, _HasModalityProp])
+
+
+def full_groupby_modality(values: Iterable[_M]) -> ItemsView[str, list[_M]]:
+    """Convenience function to apply :func:`full_groupby` based on modality."""
+    return full_groupby(values, key=lambda x: x.modality)
+
+
 @dataclass
 class _BoundPromptSegment:
     tokenizer: AnyTokenizer
@@ -237,7 +258,7 @@ def to_multi_format(data: MultiModalDataDict) -> dict[str, list[Any]]:
     return multi_data
 
 
-class _TokenRun(TypedDict):
+class _TokenRun(NamedTuple):
     token_id: int
     start_idx: int
     length: int
@@ -257,7 +278,7 @@ def iter_token_runs(token_ids: list[int]) -> Iterable[_TokenRun]:
         start_idx += length
 
 
-class _BoundPlaceholderRange(TypedDict):
+class _BoundPlaceholderRange(NamedTuple):
     modality: str
     offset: int
     length: int
@@ -265,11 +286,12 @@
 
 def iter_placeholders(
     prompt_repls: Sequence[_BoundPromptReplacement[Any]],
-    new_token_ids: list[int],
+    token_ids: list[int],
     *,
     min_placeholder_count: int,
 ) -> Iterable[_BoundPlaceholderRange]:
-    repls_by_modality = full_groupby(prompt_repls, key=lambda x: x.modality)
+    """Yield each set of placeholder tokens found in :code:`token_ids`."""
+    repls_by_modality = full_groupby_modality(prompt_repls)
 
     placeholder_ids_by_modality = {
         modality: {
@@ -280,15 +302,15 @@
         for modality, repls in repls_by_modality
     }
 
-    for run_info in iter_token_runs(new_token_ids):
-        if run_info["length"] > min_placeholder_count:
+    for run_info in iter_token_runs(token_ids):
+        if run_info.length > min_placeholder_count:
             for (modality,
                  placeholder_ids) in placeholder_ids_by_modality.items():
-                if run_info["token_id"] in placeholder_ids:
+                if run_info.token_id in placeholder_ids:
                     yield _BoundPlaceholderRange(
                         modality=modality,
-                        offset=run_info["start_idx"],
-                        length=run_info["length"],
+                        offset=run_info.start_idx,
+                        length=run_info.length,
                     )
 
 
@@ -398,6 +420,7 @@ def find_token_matches(
     prompt: list[int],
     prompt_repls: Sequence[_BoundPromptReplacement[_T]],
 ) -> list[_PromptReplacementTokenMatch[_T]]:
+    """Return each target of :code:`prompt_repls` found in :code:`prompt`."""
     return [
         _PromptReplacementTokenMatch(prompt_repl, match)
         for prompt_repl in prompt_repls
@@ -409,6 +432,7 @@ def find_text_matches(
     prompt: str,
     prompt_repls: Sequence[_BoundPromptReplacement[_T]],
 ) -> list[_PromptReplacementTextMatch[_T]]:
+    """Return each target of :code:`prompt_repls` found in :code:`prompt`."""
     return [
         _PromptReplacementTextMatch(prompt_repl, match)
         for prompt_repl in prompt_repls
@@ -416,10 +440,14 @@
     ]
 
 
-def unique_sort_matches(
+def resolve_matches(
     prompt: _S,
     matches: Sequence[_PromptReplacementMatch[_T, _S]],
 ) -> list[_PromptReplacementMatch[_T, _S]]:
+    """
+    Resolve :code:`matches` to ensure that there are no overlapping matches,
+    and sort them such that earlier matches take priority over later ones.
+    """
     num_matches_by_idx = np.zeros(len(prompt), dtype=int)
     for match in matches:
         num_matches_by_idx[match.start_idx:match.end_idx] += 1
@@ -443,8 +471,7 @@ def _replace_matches(
     prev_end_idx = 0
     next_idx_by_modality = {modality: 0 for modality in mm_items_by_modality}
 
-    # Earlier matches take priority over later ones
-    for match in unique_sort_matches(prompt, matches):
+    for match in resolve_matches(prompt, matches):
         modality = match.modality
         mm_items = mm_items_by_modality[modality]
 
@@ -471,6 +498,7 @@ def replace_token_matches(
     mm_items_by_modality: Mapping[str, list[_T]],
     hf_inputs: BatchFeature,
 ) -> list[int]:
+    """Apply :code:`prompt_repls` to :code:`prompt`."""
     if not matches:
         return prompt
 
@@ -490,6 +518,7 @@ def replace_text_matches(
     mm_items_by_modality: Mapping[str, list[_T]],
     hf_inputs: BatchFeature,
 ) -> str:
+    """Apply :code:`prompt_repls` to :code:`prompt`."""
    if not matches:
         return prompt
 
@@ -581,8 +610,7 @@ def _apply_prompt_replacements(
 
         if all(
             len(matches) >= len(mm_data[modality])
-            for modality, matches in full_groupby(token_matches,
-                                                  key=lambda x: x.modality)
+            for modality, matches in full_groupby_modality(token_matches)
         ):  # yapf: disable
             token_ids = replace_token_matches(
                 token_ids,
@@ -647,11 +675,10 @@ def apply(
 
         mm_placeholders = {
             modality: [
-                PlaceholderRange(offset=item["offset"], length=item["length"])
+                PlaceholderRange(offset=item.offset, length=item.length)
                 for item in items
            ]
-            for modality, items in full_groupby(all_placeholders,
-                                                key=lambda x: x["modality"])
+            for modality, items in full_groupby_modality(all_placeholders)
        }
 
         return MultiModalInputsV2(
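
Reviewer note: the two typing changes in this patch (TypedDict → NamedTuple for `_TokenRun`/`_BoundPlaceholderRange`, and the `Protocol`-bound `full_groupby_modality` helper) span several hunks, so here is a self-contained sketch of how they fit together. Nothing below is vLLM code: `full_groupby` is a simplified stand-in for the helper this module imports from `vllm.utils`, and the token IDs are invented.

```python
# Sketch only: `full_groupby` here is a simplified stand-in, not vLLM's.
from collections import defaultdict
from collections.abc import Callable, ItemsView, Iterable
from typing import NamedTuple, TypeVar

_K = TypeVar("_K")
_V = TypeVar("_V")


def full_groupby(values: Iterable[_V], *,
                 key: Callable[[_V], _K]) -> ItemsView[_K, list[_V]]:
    """Group ALL equal keys together, unlike itertools.groupby,
    which only groups consecutive runs."""
    groups: dict[_K, list[_V]] = defaultdict(list)
    for value in values:
        groups[key(value)].append(value)
    return groups.items()


class _TokenRun(NamedTuple):
    """Same shape as the NamedTuple this diff introduces."""
    token_id: int
    start_idx: int
    length: int


runs = [
    _TokenRun(token_id=32000, start_idx=0, length=4),
    _TokenRun(token_id=9, start_idx=4, length=1),
    _TokenRun(token_id=32000, start_idx=5, length=2),
]

# Attribute access replaces the old TypedDict subscripting (run["length"])...
assert sum(run.length for run in runs) == 7
# ...while `_asdict()` recovers a plain dict, which is how the updated test
# still compares results against manually constructed dicts.
assert runs[0]._asdict() == {"token_id": 32000, "start_idx": 0, "length": 4}

# `full_groupby` collects non-adjacent values under one key, which is what
# lets `full_groupby_modality` gather placeholders scattered across a prompt.
grouped = dict(full_groupby(runs, key=lambda r: r.token_id))
assert [r.start_idx for r in grouped[32000]] == [0, 5]
```

On the `_HasModalityAttr`/`_HasModalityProp` pair: type checkers treat a plain `modality: str` attribute and a read-only `@property` as distinct protocol members, so binding `_M` to their union lets `full_groupby_modality` accept both kinds of objects (e.g. classes that declare `modality` as a field and ones that expose it as a property) without casts at the call sites.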