diff --git a/tests/multimodal/test_processing.py b/tests/multimodal/test_processing.py
index 4a57dc9a94d61..e565b65f71d2c 100644
--- a/tests/multimodal/test_processing.py
+++ b/tests/multimodal/test_processing.py
@@ -1,9 +1,6 @@
 import pytest
-from transformers import PreTrainedTokenizerBase
-from vllm.multimodal.processing import (find_token_match_by_text,
-                                        iter_token_runs, replace_by_text)
-from vllm.multimodal.utils import cached_get_tokenizer
+from vllm.multimodal.processing import iter_token_matches, iter_token_runs
 
 # yapf: disable
@@ -11,17 +8,20 @@
     ("token_ids", "expected"),
     [
         ([], []),
-        ([32000, 32000, 32000], [(32000, { "offset": 0, "length": 3 })]),
+        (
+            [32000, 32000, 32000],
+            [{ "token_id": 32000, "start_idx": 0, "length": 3 }],
+        ),
         (
             [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
             [
-                (9833, { "offset": 0, "length": 1 }),
-                (28747, { "offset": 1, "length": 1 }),
-                (32000, { "offset": 2, "length": 3 }),
-                (9833, { "offset": 5, "length": 1 }),
-                (28747, { "offset": 6, "length": 1 }),
-                (32000, { "offset": 7, "length": 2 }),
-                (918, { "offset": 9, "length": 1 }),
+                { "token_id": 9833, "start_idx": 0, "length": 1 },
+                { "token_id": 28747, "start_idx": 1, "length": 1 },
+                { "token_id": 32000, "start_idx": 2, "length": 3 },
+                { "token_id": 9833, "start_idx": 5, "length": 1 },
+                { "token_id": 28747, "start_idx": 6, "length": 1 },
+                { "token_id": 32000, "start_idx": 7, "length": 2 },
+                { "token_id": 918, "start_idx": 9, "length": 1 },
             ],
         ),
     ],
 )
@@ -30,155 +30,71 @@
 def test_iter_token_runs(token_ids, expected):
     result = list(iter_token_runs(token_ids))
 
-    # Invariants
-    assert sum(run_info["length"] for _, run_info in result) == len(token_ids)
-
     # Manually constructed results
     assert result == expected
-
-
-@pytest.mark.parametrize("tokenizer_id", [
-    "llava-hf/llava-1.5-7b-hf",
-    "meta-llama/Llama-3.2-11B-Vision-Instruct",
-    "microsoft/Phi-3.5-vision-instruct",
-    "Qwen/Qwen2-VL-2B-Instruct",
-])
-@pytest.mark.parametrize(
-    "text",
-    [
-        "What is in this image?",
-        # LLaVA
-        "<image>What is in this image?",
-        "What is<image>in this image?",
-        "What is in this image?<image>",
-        # LLama-3.2
-        "<|image|>What is in this image?",
-        "What is<|image|>in this image?",
-        "What is in this image?<|image|>",
-        # Phi-3-vision
-        "<|image_1|>What is in this image?",
-        "What is<|image_1|>in this image?",
-        "What is in this image?<|image_1|>",
-        # Qwen2-VL
-        "<|vision_start|><|image_pad|><|vision_end|>What is in this image?",
-        "What is<|vision_start|><|image_pad|><|vision_end|>in this image?",
-        "What is in this image?<|vision_start|><|image_pad|><|vision_end|>",
-    ])
-@pytest.mark.parametrize(
-    "match_str",
-    [
-        # No match
-        "No",
-        # Has match
-        "i",
-        "What",
-        "What is",
-        "image",
-        "image?",
-        "<image>",
-        "<|image|>",
-        "<|image_1|>",
-        "<|vision_start|><|image_pad|><|vision_end|>",
-        "<s>",
-        "</s>",
-    ])
-@pytest.mark.parametrize("add_special_tokens", [True, False])
-def test_token_match_by_text(
-    tokenizer_id,
-    text,
-    match_str,
-    add_special_tokens,
-):
-    tokenizer = cached_get_tokenizer(tokenizer_id)
-    assert isinstance(tokenizer, PreTrainedTokenizerBase)
-
-    token_ids = tokenizer.encode(text, add_special_tokens=add_special_tokens)
-    match = find_token_match_by_text(tokenizer, token_ids, text, match_str)
-
-    # These are only shown in the output if the test fails
-    print("token_ids:", token_ids)
-    print("match:", match)
-
-    # Invariants
-    if (match_str in text
-            or match_str in tokenizer.decode(token_ids,
-                                             skip_special_tokens=False)):
-        assert match is not None
-        match_start_idx, match_end_idx, *_ = match
-
-        assert match_str in tokenizer.decode(
-            token_ids[match_start_idx:match_end_idx],
-            skip_special_tokens=False,
-        )
-        assert match_str not in tokenizer.decode(
-            token_ids[match_start_idx + 1:match_end_idx],
-            skip_special_tokens=False,
-        )
-        assert match_str not in tokenizer.decode(
-            token_ids[match_start_idx:match_end_idx - 1],
-            skip_special_tokens=False,
-        )
-    else:
-        assert match is None
+    assert sum(run_info["length"] for run_info in result) == len(token_ids)
 
 
-@pytest.mark.parametrize("tokenizer_id", ["llava-hf/llava-1.5-7b-hf"])
-@pytest.mark.parametrize(("input_text", "replacement_count", "expected_text"),
-                         [
-                             ("foo", 0, ""),
-                             ("bar", 0, "bar"),
-                             ("food", 0, "d"),
-                             ("foo", 1, "bar"),
-                             ("bar", 1, "bar"),
-                             ("food", 1, "bard"),
-                             ("foo", 2, "barbar"),
-                             ("bar", 2, "bar"),
-                             ("food", 2, "barbard"),
-                         ])
-@pytest.mark.parametrize("add_special_tokens", [True, False])
-def test_replace_by_text(
-    tokenizer_id,
-    input_text,
-    replacement_count,
-    expected_text,
-    add_special_tokens,
-):
-    tokenizer = cached_get_tokenizer(tokenizer_id)
-    assert isinstance(tokenizer, PreTrainedTokenizerBase)
-
-    vocab = tokenizer.get_vocab()
-    missing_tokens = {"▁foo", "▁bar", "▁food"} - vocab.keys()
-    assert not missing_tokens, missing_tokens
-    assert "▁bard" not in vocab
-
-    input_ids = tokenizer.encode(input_text,
-                                 add_special_tokens=add_special_tokens)
-    bar_id = vocab["bar"]
-
-    output_ids, output_text, replacement = replace_by_text(
-        tokenizer,
-        input_ids[:],  # Copy
-        input_text,
-        "foo",
-        bar_id,
-        replacement_count,
-    )
+# yapf: disable
+@pytest.mark.parametrize(
+    ("token_ids", "match_ids", "expected"),
+    [
+        ([], [], [{ "start_idx": 0, "end_idx": 0 }]),
+        ([], [32000], []),
+        (
+            [32000, 32000, 32000],
+            [32000],
+            [
+                { "start_idx": 0, "end_idx": 1 },
+                { "start_idx": 1, "end_idx": 2 },
+                { "start_idx": 2, "end_idx": 3 },
+            ],
+        ),
+        (
+            [32000, 32000, 32000],
+            [32000, 32000],
+            [
+                { "start_idx": 0, "end_idx": 2 },
+                { "start_idx": 1, "end_idx": 3 },
+            ],
+        ),
+        (
+            [32000, 32000, 32000],
+            [32000, 32000, 32000],
+            [{ "start_idx": 0, "end_idx": 3 }],
+        ),
+        (
+            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
+            [28747, 32000],
+            [
+                { "start_idx": 1, "end_idx": 3 },
+                { "start_idx": 6, "end_idx": 8 },
+            ],
+        ),
+        (
+            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
+            [28747, 32000, 32000, 32000],
+            [
+                { "start_idx": 1, "end_idx": 5 },
+            ],
+        ),
+        (
+            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
+            [28747, 0, 32000],
+            [],
+        ),
+    ],
+)
+# yapf: enable
+def test_iter_token_matches(token_ids, match_ids, expected):
+    result = list(iter_token_matches(token_ids, match_ids))
 
-    # These are only shown in the output if the test fails
-    print("input_ids:", input_ids)
-    print("output_ids:", output_ids)
-    print("output_text:", output_text)
-    print("replacement:", replacement)
+    # Manually constructed results
+    assert [item._asdict() for item in result] == expected
 
     # Invariants
-    if replacement is None:
-        assert output_ids == input_ids
-    else:
-        offset = replacement["offset"]
-        repl_len = replacement["length"]
-
-        assert output_ids[offset:offset + repl_len] == [bar_id] * repl_len
-        assert repl_len == replacement_count
-
-    # Manually constructed results
-    assert output_text == expected_text
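Reviewer note on the semantics pinned down by `test_iter_token_matches`: the window slides one token at a time, so overlapping occurrences are all reported, and empty `match_ids` matches once at index 0. The following standalone sketch reproduces the matcher from this diff (the private `_TokenMatch` is renamed to `TokenMatch` only so the snippet runs outside the module) and checks the overlap case from the table above:

```python
from typing import Iterable, NamedTuple


class TokenMatch(NamedTuple):
    start_idx: int
    end_idx: int


def iter_token_matches(
    token_ids: list[int],
    match_ids: list[int],
) -> Iterable[TokenMatch]:
    # Scan every window of len(match_ids) tokens; yield each exact match.
    match_len = len(match_ids)
    for start_idx in range(len(token_ids) - match_len + 1):
        end_idx = start_idx + match_len
        if token_ids[start_idx:end_idx] == match_ids:
            yield TokenMatch(start_idx, end_idx)


# Overlapping occurrences are reported individually:
assert list(iter_token_matches([32000, 32000, 32000], [32000, 32000])) \
    == [TokenMatch(0, 2), TokenMatch(1, 3)]
```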
diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py
index 47f1a6bfc4bf9..edc8064ba8725 100644
--- a/vllm/multimodal/processing.py
+++ b/vllm/multimodal/processing.py
@@ -1,44 +1,182 @@
+import re
 from dataclasses import dataclass
 from functools import lru_cache
 from itertools import groupby
-from typing import (Any, Callable, Generic, List, Mapping, NamedTuple,
-                    Optional, TypeVar, Union, final)
+from typing import (Any, Callable, Generic, Iterable, Mapping, NamedTuple,
+                    Optional, Sequence, TypeVar, Union)
 
 from transformers import BatchFeature
 from typing_extensions import TypeAlias, TypedDict
 
 from vllm.inputs import InputProcessingContext
 from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
-from vllm.utils import is_list_of
+from vllm.utils import full_groupby, is_list_of
 
 from .inputs import (AudioItem, ImageItem, MultiModalDataDict,
                      MultiModalInputsV2, MultiModalKwargs, PlaceholderRange,
                      VideoItem)
 
+
+def _encode(
+    tokenizer: AnyTokenizer,
+    text: str,
+    *,
+    add_special_tokens: bool = False,
+) -> list[int]:
+    """
+    Backend-agnostic equivalent of HF's
+    :code:`tokenizer.encode(text, add_special_tokens=...)`.
+    """
+    if isinstance(tokenizer, MistralTokenizer):
+        return tokenizer.tokenizer.encode(text,
+                                          bos=add_special_tokens,
+                                          eos=add_special_tokens)
+
+    return tokenizer.encode(text, add_special_tokens=add_special_tokens)
+
+
+@lru_cache(maxsize=2048)
+def _cached_encode(
+    tokenizer: AnyTokenizer,
+    text: str,
+    *,
+    add_special_tokens: bool = False,
+) -> list[int]:
+    return _encode(tokenizer, text, add_special_tokens=add_special_tokens)
+
+
+def _decode(
+    tokenizer: AnyTokenizer,
+    token_ids: list[int],
+    *,
+    skip_special_tokens: bool = False,
+) -> str:
+    """
+    Backend-agnostic equivalent of HF's
+    :code:`tokenizer.decode(token_ids, skip_special_tokens=...)`.
+    """
+    return tokenizer.decode(token_ids,
+                            skip_special_tokens=skip_special_tokens)
+
+
+@lru_cache(maxsize=2048)
+def _cached_decode(
+    tokenizer: AnyTokenizer,
+    token_ids: tuple[int, ...],
+    *,
+    skip_special_tokens: bool = False,
+) -> str:
+    return _decode(tokenizer,
+                   list(token_ids),
+                   skip_special_tokens=skip_special_tokens)
+
+
+PromptSegment: TypeAlias = Union[str, list[int]]
+
+
+def bind_segment(
+    prompt_segment: PromptSegment,
+    tokenizer: AnyTokenizer,
+) -> "_BoundPromptSegment":
+    return _BoundPromptSegment(
+        tokenizer=tokenizer,
+        _text=prompt_segment if isinstance(prompt_segment, str) else None,
+        _token_ids=prompt_segment
+        if isinstance(prompt_segment, list) else None,
+    )
+
+
+_S_co = TypeVar("_S_co", bound=PromptSegment, covariant=True)
 _T = TypeVar("_T")
 
 
-class PlaceholderReplacement(TypedDict, Generic[_T]):
-    token_id: int
-    """The ID of the placeholder token."""
+@dataclass
+class PromptReplacement(Generic[_S_co, _T]):
+    target: _S_co
+    """The prompt segment to find and replace."""
 
-    count: Union[Callable[[_T, BatchFeature, int], int], int]
+    repl_unit: _S_co
+    """
+    The unit making up the replacement prompt segment.
+
+    See :code:`repl_count` for more details.
     """
-    Given the original data item, HF-processed data, and index of the processed
-    item, output the number of replacement tokens to be allocated in vLLM.
-
-    For convenience, you can pass in an integer if this number is a constant.
+    repl_count: Callable[[_T, BatchFeature, int], int]
     """
+    Given the original data item, HF-processed data, and index of the processed
+    item, output the number of repetitions of :code:`repl_unit` to build up the
+    replacement prompt segment.
+    """
+
+    def bind(
+        self,
+        modality: str,
+        tokenizer: AnyTokenizer,
+    ) -> "_BoundPromptReplacement[_T]":
+        return _BoundPromptReplacement(
+            modality=modality,
+            target=bind_segment(self.target, tokenizer),
+            repl_unit=bind_segment(self.repl_unit, tokenizer),
+            repl_count=self.repl_count,
+        )
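To make the intent of `PromptReplacement` concrete, here is a hypothetical LLaVA-style instantiation. The token ID 32000 and the count 576 are illustrative assumptions, not values fixed by this diff; only `PromptReplacement` itself (importable from `vllm.multimodal.processing` after this change) is real:

```python
# Replace every "<image>" in the prompt with 576 repetitions of a single
# placeholder token (id 32000 is assumed here for illustration). The
# repl_count callback receives the original data item, the HF-processed
# BatchFeature, and the index of the processed item.
image_repl = PromptReplacement(
    target="<image>",
    repl_unit=[32000],
    repl_count=lambda item, hf_inputs, item_idx: 576,
)

# Binding attaches a tokenizer so that both segments lazily expose .text
# and .token_ids through the cached helpers above, e.g.:
#
#   bound_repl = image_repl.bind("image", tokenizer)
#   assert bound_repl.repl_unit.token_ids == [32000]
```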
+
+
+@dataclass
+class _BoundPromptSegment:
+    tokenizer: AnyTokenizer
+    _text: Optional[str]
+    _token_ids: Optional[list[int]]
+
+    def __post_init__(self) -> None:
+        if self._text is None and self._token_ids is None:
+            raise ValueError("At least one of 'text' and 'token_ids' must be "
+                             "specified")
+
+    @property
+    def text(self) -> str:
+        if self._text is None:
+            assert self._token_ids is not None
+            self._text = _cached_decode(self.tokenizer,
+                                        tuple(self._token_ids))
+
+        return self._text
+
+    @property
+    def token_ids(self) -> list[int]:
+        if self._token_ids is None:
+            assert self._text is not None
+            self._token_ids = _cached_encode(self.tokenizer, self._text)
+
+        return self._token_ids
+
+
+@dataclass
+class _BoundPromptReplacement(Generic[_T]):
+    modality: str
+    target: _BoundPromptSegment
+    repl_unit: _BoundPromptSegment
+    repl_count: Callable[[_T, BatchFeature, int], int]
 
 
 @dataclass
 class ModalityProcessingMetadata(Generic[_T]):
-    placeholder_replacements: Mapping[str, PlaceholderReplacement[_T]]
+    prompt_repls: Sequence[PromptReplacement[PromptSegment, _T]]
     """
-    A dictionary that maps each substring to search in the original prompt text
-    to the corresponding replacement.
+    Defines each segment to replace in the HF-processed prompt.
+
+    This is skipped if the HF-processed prompt is found to already contain
+    the replacement prompts.
     """
 
+    def bind_prompt_repls(
+        self,
+        modality: str,
+        tokenizer: AnyTokenizer,
+    ) -> list[_BoundPromptReplacement[_T]]:
+        return [
+            prompt_repl.bind(modality, tokenizer)
+            for prompt_repl in self.prompt_repls
+        ]
+
 
 class MultiModalProcessingMetadataBuiltins(TypedDict, total=False):
     """Type annotations for modality types predefined by vLLM."""
@@ -60,46 +198,14 @@ class MultiModalProcessingMetadataBuiltins(TypedDict, total=False):
     Read more on that :ref:`here <adding_multimodal_plugin>`.
 """
 
-MultiModalMultiData: TypeAlias = List[_T]
-"""
-A list of data items, where the number of data items allowed
-per modality is restricted by :code:`--limit-mm-per-prompt`.
-"""
-
-
-@final
-class MultiModalMultiDataBuiltins(TypedDict, total=False):
-    """Type annotations for modality types predefined by vLLM."""
-
-    image: MultiModalMultiData[ImageItem]
-    """The input images."""
-
-    video: MultiModalMultiData[VideoItem]
-    """The input videos."""
-
-    audio: MultiModalMultiData[AudioItem]
-    """The input audios."""
-
-
-MultiModalMultiDataDict: TypeAlias = Mapping[str, MultiModalMultiData[Any]]
-"""
-A dictionary containing an entry for each modality type to input.
-
-Note:
-    This dictionary also accepts modality keys defined outside
-    :class:`MultiModalMultiDataBuiltins` as long as a customized plugin
-    is registered through the :class:`~vllm.multimodal.MULTIMODAL_REGISTRY`.
-    Read more on that :ref:`here <adding_multimodal_plugin>`.
-"""
-
-
-def to_multi_format(data: MultiModalDataDict) -> MultiModalMultiDataDict:
+def to_multi_format(data: MultiModalDataDict) -> dict[str, list[Any]]:
     """
-    Convert a :class:`MultiModalDataDict` containing single data items
-    to a :class:`MultiModalMultiDataDict` containing multiple data items
-    per entry.
+    Convert a :class:`MultiModalDataDict` containing single data items
+    to a dictionary containing multiple data items per entry.
     """
-    multi_data: Mapping[str, MultiModalMultiData[Any]] = {}
+    multi_data = dict[str, list[Any]]()
 
     for k, v in data.items():
         # yapf: disable
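The normalization performed by `to_multi_format` can be summarized with a minimal sketch; plain strings stand in for media items, and the real function dispatches per modality in the `isinstance` chain elided between the two hunks:

```python
def to_multi_format_sketch(data: dict) -> dict:
    # Single items become one-element lists; lists pass through unchanged.
    return {k: v if isinstance(v, list) else [v] for k, v in data.items()}


assert to_multi_format_sketch({"image": "img", "audio": ["a1", "a2"]}) \
    == {"image": ["img"], "audio": ["a1", "a2"]}
```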
""" - multi_data: Mapping[str, MultiModalMultiData[Any]] = {} + multi_data = dict[str, list[Any]]() for k, v in data.items(): # yapf: disable @@ -115,7 +221,14 @@ def to_multi_format(data: MultiModalDataDict) -> MultiModalMultiDataDict: return multi_data -def iter_token_runs(token_ids: List[int]): +class _TokenRun(TypedDict): + token_id: int + + start_idx: int + length: int + + +def iter_token_runs(token_ids: list[int]) -> Iterable[_TokenRun]: """ Yield the starting index and length of each run of tokens that are the same. """ @@ -123,221 +236,39 @@ def iter_token_runs(token_ids: List[int]): for token_id, it in groupby(token_ids): length = sum(1 for _ in it) - yield token_id, PlaceholderRange(offset=start_idx, length=length) + yield _TokenRun(token_id=token_id, start_idx=start_idx, length=length) start_idx += length -def _encode( - tokenizer: AnyTokenizer, - text: str, - *, - add_special_tokens: bool = False, -) -> List[int]: - """ - Backend-agnostic equivalent of HF's - :code:`tokenizer.encode(text, add_special_tokens=...)`. - """ - if isinstance(tokenizer, MistralTokenizer): - return tokenizer.tokenizer.encode(text, - bos=add_special_tokens, - eos=add_special_tokens) - - return tokenizer.encode(text, add_special_tokens=add_special_tokens) - - -_cached_encode = lru_cache(_encode) - - -@lru_cache -def _max_vocab_token_len(tokenizer: AnyTokenizer) -> int: - return max(len(token_text) for token_text in tokenizer.get_vocab()) - - class _TokenMatch(NamedTuple): start_idx: int end_idx: int + @property + def length(self) -> int: + return self.end_idx - self.start_idx -def find_token_match( - token_ids: List[int], - match_ids: List[int], -) -> Optional[_TokenMatch]: + +def iter_token_matches( + token_ids: list[int], + match_ids: list[int], +) -> Iterable[_TokenMatch]: """ - Find the first occurrence of :code:`match_ids` in :code:`token_ids`. + Yield each occurrence of :code:`match_ids` in :code:`token_ids`. """ match_len = len(match_ids) for start_idx in range(len(token_ids) - match_len + 1): end_idx = start_idx + match_len if token_ids[start_idx:end_idx] == match_ids: - return _TokenMatch(start_idx, end_idx) + yield _TokenMatch(start_idx, end_idx) - return None - -class _TokenMatchFromTextCandidate(NamedTuple): - start_idx: int - end_idx: int - - match_text_prefix: str - match_text_suffix: str - - @property - def distance(self) -> int: - return len(self.match_text_prefix) + len(self.match_text_suffix) - - -class _TokenMatchFromText(NamedTuple): - start_idx: int - end_idx: int - - match_prefix: List[int] - match_suffix: List[int] - - match_text_prefix: str - match_text_suffix: str - - -def find_token_match_by_text( - tokenizer: AnyTokenizer, - token_ids: List[int], - token_text: str, - match_text: str, -) -> Optional[_TokenMatchFromText]: - """ - Find the first occurrence of the tokenized :code:`match_text` in - :code:`token_ids`. - """ - match_ids = _cached_encode(tokenizer, match_text, add_special_tokens=False) - if (match := find_token_match(token_ids, match_ids)): - return _TokenMatchFromText( - match.start_idx, - match.end_idx, - match_prefix=[], - match_suffix=[], - match_text_prefix="", - match_text_suffix="", - ) - - # When `match_text` is not mapped to a special token ID, - # it may be tokenized differently based on the surrounding tokens - # as well as whether it is at the start/end of the string. - # Therefore, we need to use `token_text` as a reference. 
-    text_start_idx = token_text.find(match_text)
-    if text_start_idx == -1:
-        return None
-
-    text_end_idx = text_start_idx + len(match_text)
-
-    # In case the left/right side of `match_text` is fused with the
-    # string immediately before/after it as a single token
-    text_buffer = _max_vocab_token_len(tokenizer) - 1
-    left_text = token_text[:max(0, text_start_idx - text_buffer)]
-    right_text = token_text[:text_end_idx + text_buffer]
-
-    left_idx = len(_encode(tokenizer, left_text, add_special_tokens=False))
-    right_idx = len(_encode(tokenizer, right_text, add_special_tokens=True))
-    window_size = len(match_ids)
-
-    best_distance = len(token_text)
-    best_candidate = None
-
-    for start_idx in range(left_idx, right_idx - window_size + 1):
-        end_idx = start_idx + window_size
-        candidate_text = tokenizer.decode(
-            token_ids[start_idx:end_idx],
-            # In case match_text is a special token
-            skip_special_tokens=False,
-        )
-
-        if match_text in candidate_text:
-            candidate = _TokenMatchFromTextCandidate(
-                start_idx,
-                end_idx,
-                *candidate_text.split(match_text, 1),
-            )
-
-            if candidate.distance < best_distance:
-                best_candidate = candidate
-                best_distance = candidate.distance
-
-            if best_distance == 0:
-                break
-
-    assert best_candidate is not None, dict(
-        # To facilitate debugging
-        token_ids=token_ids,
-        match_ids=match_ids,
-        left_text=left_text,
-        right_text=right_text,
-        left_idx=left_idx,
-        right_idx=right_idx,
-    )
-
-    match_token_prefix = _cached_encode(
-        tokenizer,
-        best_candidate.match_text_prefix,
-        add_special_tokens=False,
-    )
-    match_token_suffix = _cached_encode(
-        tokenizer,
-        best_candidate.match_text_suffix,
-        add_special_tokens=False,
-    )
-
-    return _TokenMatchFromText(
-        start_idx=best_candidate.start_idx,
-        end_idx=best_candidate.end_idx,
-        match_prefix=match_token_prefix,
-        match_suffix=match_token_suffix,
-        match_text_prefix=best_candidate.match_text_prefix,
-        match_text_suffix=best_candidate.match_text_suffix,
-    )
-
-
-def replace_by_text(
-    tokenizer: AnyTokenizer,
-    token_ids: List[int],
-    token_text: str,
-    match_text: str,
-    replacement_id: int,
-    replacement_count: int,
-) -> tuple[List[int], str, Optional[PlaceholderRange]]:
-    """
-    Find the first occurrence of the tokenized :code:`match_text` in
-    :code:`token_ids`, and replace it with
-    :code:`[replacement_id] * replacement_count`.
-
-    This function updates :code:`token_ids` in place.
-    """
-    match = find_token_match_by_text(
-        tokenizer,
-        token_ids,
-        token_text,
-        match_text,
-    )
-
-    if match is None:
-        return token_ids, token_text, None
-
-    start_idx, end_idx, prefix_ids, suffix_ids, prefix_str, suffix_str = match
-
-    replacement_ids = (prefix_ids + [replacement_id] * replacement_count +
-                       suffix_ids)
-    replacement_text = tokenizer.decode(
-        replacement_ids,
-        # In case match_text is a special token
-        skip_special_tokens=False,
-    )
-
-    token_ids[start_idx:end_idx] = replacement_ids
-    token_text = token_text.replace(prefix_str + match_text + suffix_str,
-                                    replacement_text, 1)
-
-    return (token_ids, token_text,
-            PlaceholderRange(offset=start_idx + len(prefix_ids),
-                             length=replacement_count))
+class _BoundPlaceholderRange(TypedDict):
+    modality: str
+    offset: int
+    length: int
 
 
 class MultiModalProcessor:
@@ -363,6 +294,128 @@ def __call__(
     ) -> MultiModalInputsV2:
         return self.apply(prompt, mm_data, mm_processor_kwargs)
 
+    def _extract_placeholder_ranges(
+        self,
+        all_prompt_repls: Sequence[_BoundPromptReplacement[Any]],
+        new_token_ids: list[int],
+        *,
+        # To avoid false positives from multi-input when detecting
+        # whether HF processor already inserts placeholder tokens
+        min_placeholder_count: int = 16,
+    ) -> Iterable[_BoundPlaceholderRange]:
+        placeholder_ids_by_modality = {
+            modality: {
+                token_id
+                for prompt_repl in prompt_repls
+                for token_id in prompt_repl.repl_unit.token_ids
+            }
+            for modality, prompt_repls in full_groupby(
+                all_prompt_repls, key=lambda x: x.modality)
+        }
+
+        # In case HF processor already inserts placeholder tokens
+        for run_info in iter_token_runs(new_token_ids):
+            if run_info["length"] > min_placeholder_count:
+                for (modality,
+                     placeholder_ids) in placeholder_ids_by_modality.items():
+                    if run_info["token_id"] in placeholder_ids:
+                        yield _BoundPlaceholderRange(
+                            modality=modality,
+                            offset=run_info["start_idx"],
+                            length=run_info["length"],
+                        )
+
+    def _find_token_id_matches(
+        self,
+        token_ids: list[int],
+        prompt_repls: Sequence[_BoundPromptReplacement[_T]],
+    ) -> list[tuple[str, _TokenMatch]]:
+        return [(prompt_repl.target.text, match)
+                for prompt_repl in prompt_repls for match in
+                iter_token_matches(token_ids, prompt_repl.target.token_ids)]
+
+    def _replace_token_id_matches(
+        self,
+        token_ids: list[int],
+        prompt_repls: Sequence[_BoundPromptReplacement[_T]],
+        matches: Sequence[tuple[str, _TokenMatch]],
+        mm_items_by_modality: Mapping[str, list[_T]],
+        hf_inputs: BatchFeature,
+    ) -> list[int]:
+        prompt_repls_by_target_text = {
+            prompt_repl.target.text: prompt_repl
+            for prompt_repl in prompt_repls
+        }
+
+        # To ensure that later replacements don't affect
+        # the placeholder ranges of earlier ones
+        sorted_matches = sorted(matches, key=lambda x: x[1].start_idx)
+
+        out_token_ids = list[int]()
+        prev_end_idx = 0
+
+        for i, (target_text, (start_idx,
+                              end_idx)) in enumerate(sorted_matches):
+            prompt_repl = prompt_repls_by_target_text[target_text]
+            mm_items = mm_items_by_modality[prompt_repl.modality]
+
+            repl_count = prompt_repl.repl_count(mm_items[i], hf_inputs, i)
+            repl_ids = prompt_repl.repl_unit.token_ids * repl_count
+
+            out_token_ids.extend(token_ids[prev_end_idx:start_idx] + repl_ids)
+            prev_end_idx = end_idx
+
+        # Keep the tokens that follow the last match
+        out_token_ids.extend(token_ids[prev_end_idx:])
+
+        return out_token_ids
+
+    def _iter_text_matches(
+        self,
+        token_text: str,
+        prompt_repls: Sequence[_BoundPromptReplacement[_T]],
+    ) -> Iterable[tuple[str, re.Match]]:
+        for prompt_repl in prompt_repls:
+            target_text = prompt_repl.target.text
+            for match in re.finditer(re.escape(target_text), token_text):
+                yield target_text, match
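Two properties of the text path are worth calling out: `re.escape` keeps placeholder strings such as `<|image|>` (which contains the regex metacharacter `|`) from being interpreted as patterns, and `re.finditer` reports only non-overlapping occurrences, unlike `iter_token_matches` above. Both are quick to verify:

```python
import re

# Escaping: the "|" in "<|image|>" must not act as an alternation operator.
assert re.findall(re.escape("<|image|>"), "a<|image|>b") == ["<|image|>"]

# Non-overlap: finditer resumes after each match, so "aaa" yields a single
# occurrence of "aa", whereas the token matcher would report two.
assert len(list(re.finditer("aa", "aaa"))) == 1
```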
+
+    def _find_and_replace_token_text_matches(
+        self,
+        tokenizer: AnyTokenizer,
+        token_ids: list[int],
+        prompt_repls: Sequence[_BoundPromptReplacement[_T]],
+        mm_items_by_modality: Mapping[str, list[_T]],
+        hf_inputs: BatchFeature,
+    ) -> list[int]:
+        token_text = _decode(tokenizer, token_ids)
+
+        prompt_repls_by_target_text = {
+            prompt_repl.target.text: prompt_repl
+            for prompt_repl in prompt_repls
+        }
+
+        # To ensure that later replacements don't affect
+        # the placeholder ranges of earlier ones
+        sorted_matches = sorted(
+            self._iter_text_matches(token_text, prompt_repls),
+            key=lambda x: x[1].start(),
+        )
+
+        out_texts = list[str]()
+        prev_end_idx = 0
+
+        for i, (target_text, match) in enumerate(sorted_matches):
+            prompt_repl = prompt_repls_by_target_text[target_text]
+            mm_items = mm_items_by_modality[prompt_repl.modality]
+
+            repl_count = prompt_repl.repl_count(mm_items[i], hf_inputs, i)
+            repl_text = prompt_repl.repl_unit.text * repl_count
+
+            out_texts.append(token_text[prev_end_idx:match.start()] +
+                             repl_text)
+            prev_end_idx = match.end()
+
+        # Keep the text that follows the last match
+        out_texts.append(token_text[prev_end_idx:])
+
+        return _encode(tokenizer, "".join(out_texts))
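Why `apply()` below tries token-ID matching first and falls back to text matching: when the target is not a special token, a BPE tokenizer may merge its boundary characters with the surrounding text, so the IDs of the target encoded in isolation need not appear contiguously in the full prompt's IDs. A toy illustration (the helper and all token IDs here are hypothetical, not part of this diff):

```python
def contains_subsequence(haystack: list[int], needle: list[int]) -> bool:
    return any(haystack[i:i + len(needle)] == needle
               for i in range(len(haystack) - len(needle) + 1))


# Suppose "<image>" alone encodes to [32000], but in context the tokenizer
# fuses its edges with neighboring text into different IDs entirely:
assert contains_subsequence([1, 9833, 32000, 918], [32000])
assert not contains_subsequence([1, 9833, 31998, 31999, 918], [32000])
```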
 
     def apply(
         self,
         prompt: str,
@@ -372,68 +425,65 @@ def apply(
         tokenizer = self.ctx.tokenizer
         hf_processor = self.ctx.get_hf_processor()
 
-        processed_inputs = hf_processor(
+        hf_inputs = hf_processor(
             text=prompt,  # type: ignore
             **mm_data,
             **mm_processor_kwargs,
         )
-        new_token_ids, = processed_inputs.pop("input_ids").tolist()
-        mm_kwargs = MultiModalKwargs(processed_inputs)
-
-        new_prompt = prompt
-        mm_placeholders: Mapping[str, List[PlaceholderRange]] = {}
-
-        for modality, orig_inputs in to_multi_format(mm_data).items():
-            assert isinstance(orig_inputs, list)
-
-            metadata = self.metadata[modality]
-            placeholder_repls = metadata.placeholder_replacements
-            repl_token_ids = {
-                replacement["token_id"]
-                for replacement in placeholder_repls.values()
-            }
-
-            modality_placeholders: List[PlaceholderRange] = []
-
-            # In case HF processor already inserts placeholder tokens
-            for new_token_id, run_info in iter_token_runs(new_token_ids):
-                if new_token_id in repl_token_ids:
-                    modality_placeholders.append(run_info)
-
-            if modality_placeholders:
-                new_prompt = tokenizer.decode(new_token_ids)
-            else:  # Otherwise, we insert them ourselves
-                for item_idx, orig_item in enumerate(orig_inputs):
-                    for match_str, replacement in placeholder_repls.items():
-                        replacement_count = replacement["count"]
-                        if callable(replacement_count):
-                            replacement_count = replacement_count(
-                                orig_item,
-                                processed_inputs,
-                                item_idx,
-                            )
-
-                        (
-                            new_token_ids,
-                            new_prompt,
-                            placeholders,
-                        ) = replace_by_text(
-                            tokenizer,
-                            new_token_ids,
-                            new_prompt,
-                            match_str,
-                            replacement["token_id"],
-                            replacement_count,
-                        )
-
-                        if placeholders is not None:
-                            modality_placeholders.append(placeholders)
-
-            mm_placeholders[modality] = modality_placeholders  # type: ignore[index]  # yapf: disable
+        new_token_ids, = hf_inputs.pop("input_ids").tolist()
+        mm_kwargs = MultiModalKwargs(hf_inputs)
+
+        all_prompt_repls = [
+            prompt_repl for modality, metadata in self.metadata.items()
+            if modality in mm_data
+            for prompt_repl in metadata.bind_prompt_repls(modality, tokenizer)
+        ]
+
+        # In case HF processor already inserts placeholder tokens
+        all_placeholder_ranges = list(
+            self._extract_placeholder_ranges(all_prompt_repls, new_token_ids))
+
+        # Otherwise, we insert them ourselves
+        if not all_placeholder_ranges:
+            mm_items = to_multi_format(mm_data)
+
+            token_id_matches = self._find_token_id_matches(
+                new_token_ids,
+                all_prompt_repls,
+            )
+            if len(token_id_matches) == len(mm_items):
+                new_token_ids = self._replace_token_id_matches(
+                    new_token_ids,
+                    all_prompt_repls,
+                    token_id_matches,
+                    mm_items,
+                    hf_inputs,
+                )
+            else:
+                new_token_ids = self._find_and_replace_token_text_matches(
+                    tokenizer,
+                    new_token_ids,
+                    all_prompt_repls,
+                    mm_items,
+                    hf_inputs,
+                )
+
+            all_placeholder_ranges = list(
+                self._extract_placeholder_ranges(all_prompt_repls,
+                                                 new_token_ids))
+
+        mm_placeholders = {
+            modality: [
+                PlaceholderRange(offset=item["offset"], length=item["length"])
+                for item in items
+            ]
+            for modality, items in full_groupby(all_placeholder_ranges,
+                                                key=lambda x: x["modality"])
+        }
 
         return MultiModalInputsV2(
             type="multimodal",
-            prompt=new_prompt,
+            prompt=_decode(tokenizer, new_token_ids),
             prompt_token_ids=new_token_ids,
             mm_kwargs=mm_kwargs,
             mm_placeholders=mm_placeholders,
diff --git a/vllm/utils.py b/vllm/utils.py
index 2bbdc8d1ebde8..15d4400b232f5 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -19,7 +19,8 @@
 import warnings
 import weakref
 from asyncio import FIRST_COMPLETED, AbstractEventLoop, Future, Task
-from collections.abc import Mapping
+from collections import defaultdict
+from collections.abc import Iterable, Mapping
 from functools import lru_cache, partial, wraps
 from platform import uname
 from typing import (Any, AsyncGenerator, Awaitable, Callable, Dict, Generic,
@@ -899,6 +900,23 @@ def flatten_2d_lists(lists: List[List[T]]) -> List[T]:
     return [item for sublist in lists for item in sublist]
 
 
+_K = TypeVar("_K", bound=Hashable)
+_V = TypeVar("_V")
+
+
+def full_groupby(values: Iterable[_V], *, key: Callable[[_V], _K]):
+    """
+    Unlike :class:`itertools.groupby`, groups are not broken by
+    non-contiguous data.
+    """
+    groups: dict[_K, list[_V]] = defaultdict(list)
+
+    for value in values:
+        groups[key(value)].append(value)
+
+    return groups.items()
+
+
 # TODO: This function can be removed if transformer_modules classes are
 # serialized by value when communicating between processes
 def init_cached_hf_modules() -> None:
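For reference, the behavioral difference between the new `full_groupby` helper and `itertools.groupby`, which is what the modality grouping in `apply()` depends on; this standalone sketch mirrors the helper added to vllm/utils.py:

```python
from collections import defaultdict
from itertools import groupby


def full_groupby(values, *, key):
    # Same logic as the new helper: bucket by key(value), unaffected by
    # the order in which the values arrive.
    groups = defaultdict(list)
    for value in values:
        groups[key(value)].append(value)
    return groups.items()


data = ["image", "audio", "image"]

# itertools.groupby splits the non-contiguous "image" run into two groups:
assert [k for k, _ in groupby(data)] == ["image", "audio", "image"]

# full_groupby keeps them together:
assert dict(full_groupby(data, key=lambda x: x)) == {
    "image": ["image", "image"],
    "audio": ["audio"],
}
```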