
Commit 1122c3f: Cleanup
Signed-off-by: DarkLight1337 <[email protected]>
DarkLight1337 committed Nov 22, 2024
1 parent 6ab14df commit 1122c3f
Showing 2 changed files with 106 additions and 79 deletions.
tests/multimodal/test_processing.py (2 additions, 2 deletions)
@@ -41,10 +41,10 @@ def test_iter_token_runs(token_ids, expected):
     print("result:", result)

     # Manually constructed results
-    assert result == expected
+    assert [item._asdict() for item in result] == expected

     # Invariants
-    assert sum(run_info["length"] for run_info in result) == len(token_ids)
+    assert sum(run_info.length for run_info in result) == len(token_ids)


 # yapf: disable
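These test updates track the _TokenRun change below: it is now a NamedTuple rather than a TypedDict, so runs are compared via ._asdict() and fields are read as attributes. A minimal standalone sketch with illustrative values (not taken from the test suite):

from typing import NamedTuple

class TokenRun(NamedTuple):
    token_id: int
    start_idx: int
    length: int

run = TokenRun(token_id=9, start_idx=3, length=2)

# Attribute access replaces the old dict-style indexing...
assert run.length == 2
# ...and _asdict() recovers a plain dict for comparison against dict literals.
assert run._asdict() == {"token_id": 9, "start_idx": 3, "length": 2}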
vllm/multimodal/processing.py (104 additions, 77 deletions)
@@ -1,10 +1,10 @@
 import re
 from abc import ABC, abstractmethod
+from collections.abc import Callable, ItemsView, Iterable, Mapping, Sequence
 from dataclasses import dataclass
 from functools import lru_cache
 from itertools import groupby
-from typing import (Any, Callable, Generic, Iterable, Mapping, NamedTuple,
-                    Optional, Sequence, TypeVar, Union)
+from typing import Any, Generic, NamedTuple, Optional, Protocol, TypeVar, Union

 import numpy as np
 from transformers import BatchFeature
@@ -18,72 +18,21 @@
                     MultiModalInputsV2, MultiModalKwargs, PlaceholderRange,
                     VideoItem)


-def _encode(
-    tokenizer: AnyTokenizer,
-    text: str,
-    *,
-    add_special_tokens: bool = False,
-) -> list[int]:
-    """
-    Backend-agnostic equivalent of HF's
-    :code:`tokenizer.encode(text, add_special_tokens=...)`.
-    """
-    if isinstance(tokenizer, MistralTokenizer):
-        return tokenizer.tokenizer.encode(text,
-                                          bos=add_special_tokens,
-                                          eos=add_special_tokens)
-
-    return tokenizer.encode(text, add_special_tokens=add_special_tokens)
-
-
-@lru_cache(maxsize=2048)
-def _cached_encode(
-    tokenizer: AnyTokenizer,
-    text: str,
-    *,
-    add_special_tokens: bool = False,
-) -> list[int]:
-    return _encode(tokenizer, text, add_special_tokens=add_special_tokens)
-
-
-def _decode(
-    tokenizer: AnyTokenizer,
-    token_ids: list[int],
-    *,
-    skip_special_tokens: bool = False,
-) -> str:
-    """
-    Backend-agnostic equivalent of HF's
-    :code:`tokenizer.decode(token_ids, skip_special_tokens=...)`.
-    """
-    return tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
-
-
-@lru_cache(maxsize=2048)
-def _cached_decode(
-    tokenizer: AnyTokenizer,
-    token_ids: tuple[int, ...],
-    *,
-    skip_special_tokens: bool = False,
-) -> str:
-    return _decode(tokenizer,
-                   list(token_ids),
-                   skip_special_tokens=skip_special_tokens)
-
-
 PromptSegment: TypeAlias = Union[str, list[int]]


 def bind_segment(
-    prompt_segment: PromptSegment,
+    segment: PromptSegment,
     tokenizer: AnyTokenizer,
 ) -> "_BoundPromptSegment":
     """
     Bind a text or token prompt to a tokenizer so that it can be
     lazily converted into the other format on demand.
     """
     return _BoundPromptSegment(
         tokenizer=tokenizer,
-        _text=prompt_segment if isinstance(prompt_segment, str) else None,
-        _token_ids=prompt_segment
-        if isinstance(prompt_segment, list) else None,
+        _text=segment if isinstance(segment, str) else None,
+        _token_ids=segment if isinstance(segment, list) else None,
     )
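The docstring above describes lazy two-way conversion between text and token IDs. A rough standalone sketch of that idea; BoundSegmentSketch and ToyTokenizer are illustrative names, not vLLM's API:

class ToyTokenizer:
    # Illustrative stand-in: maps each character to its code point.
    def encode(self, text: str) -> list[int]:
        return [ord(c) for c in text]

    def decode(self, token_ids: list[int]) -> str:
        return "".join(chr(t) for t in token_ids)

class BoundSegmentSketch:
    def __init__(self, tokenizer, text=None, token_ids=None):
        self._tokenizer = tokenizer
        self._text = text
        self._token_ids = token_ids

    @property
    def text(self) -> str:
        # Convert tokens to text only on first access.
        if self._text is None:
            self._text = self._tokenizer.decode(self._token_ids)
        return self._text

    @property
    def token_ids(self) -> list[int]:
        # Convert text to tokens only on first access.
        if self._token_ids is None:
            self._token_ids = self._tokenizer.encode(self._text)
        return self._token_ids

segment = BoundSegmentSketch(ToyTokenizer(), text="hi")
assert segment.token_ids == [104, 105]  # computed lazily on demand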


@@ -163,6 +112,78 @@ class MultiModalProcessingMetadataBuiltins(TypedDict, total=False):
     """


+def _encode(
+    tokenizer: AnyTokenizer,
+    text: str,
+    *,
+    add_special_tokens: bool = False,
+) -> list[int]:
+    """
+    Backend-agnostic equivalent of HF's
+    :code:`tokenizer.encode(text, add_special_tokens=...)`.
+    """
+    if isinstance(tokenizer, MistralTokenizer):
+        return tokenizer.tokenizer.encode(text,
+                                          bos=add_special_tokens,
+                                          eos=add_special_tokens)
+
+    return tokenizer.encode(text, add_special_tokens=add_special_tokens)
+
+
+@lru_cache(maxsize=2048)
+def _cached_encode(
+    tokenizer: AnyTokenizer,
+    text: str,
+    *,
+    add_special_tokens: bool = False,
+) -> list[int]:
+    return _encode(tokenizer, text, add_special_tokens=add_special_tokens)
+
+
+def _decode(
+    tokenizer: AnyTokenizer,
+    token_ids: list[int],
+    *,
+    skip_special_tokens: bool = False,
+) -> str:
+    """
+    Backend-agnostic equivalent of HF's
+    :code:`tokenizer.decode(token_ids, skip_special_tokens=...)`.
+    """
+    return tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
+
+
+@lru_cache(maxsize=2048)
+def _cached_decode(
+    tokenizer: AnyTokenizer,
+    token_ids: tuple[int, ...],
+    *,
+    skip_special_tokens: bool = False,
+) -> str:
+    return _decode(tokenizer,
+                   list(token_ids),
+                   skip_special_tokens=skip_special_tokens)
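_cached_decode takes token_ids as tuple[int, ...] rather than list[int] because functools.lru_cache hashes its arguments; the wrapper converts back to a list before delegating to _decode. A standalone sketch of that constraint (the stand-in decoder is illustrative):

from functools import lru_cache

@lru_cache(maxsize=2048)
def cached_decode_sketch(token_ids: tuple[int, ...]) -> str:
    # Stand-in for the real tokenizer.decode call.
    return " ".join(str(t) for t in token_ids)

# Tuples are hashable, so the second call is served from the cache.
assert cached_decode_sketch((1, 2, 3)) == cached_decode_sketch((1, 2, 3))

# Passing a list instead would raise "TypeError: unhashable type: 'list'",
# which is why the cached wrapper's signature differs from _decode's.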


+class _HasModalityAttr(Protocol):
+    modality: str
+
+
+class _HasModalityProp(Protocol):
+
+    @property
+    def modality(self) -> str:
+        ...
+
+
+_M = TypeVar("_M", bound=Union[_HasModalityAttr, _HasModalityProp])
+
+
+def full_groupby_modality(values: Iterable[_M]) -> ItemsView[str, list[_M]]:
+    """Convenience function to apply :func:`full_groupby` based on modality."""
+    return full_groupby(values, key=lambda x: x.modality)
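The two Protocol classes exist because `modality` may be either a plain attribute or a read-only property on the grouped items; the union bound lets full_groupby_modality accept both. A sketch of the grouping semantics, assuming full_groupby buckets by key in encounter order and, unlike itertools.groupby, does not require sorted input:

from collections import defaultdict
from dataclasses import dataclass

@dataclass
class FakeItem:
    # Satisfies the attribute-style protocol; illustrative, not a vLLM type.
    modality: str
    value: int

items = [FakeItem("image", 0), FakeItem("video", 1), FakeItem("image", 2)]

# Assumed behavior of full_groupby(values, key=...): bucket every item by
# key without pre-sorting, then expose the buckets as an items() view.
groups: defaultdict[str, list[FakeItem]] = defaultdict(list)
for item in items:
    groups[item.modality].append(item)

assert [i.value for i in groups["image"]] == [0, 2]
assert [i.value for i in groups["video"]] == [1]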


 @dataclass
 class _BoundPromptSegment:
     tokenizer: AnyTokenizer

@@ -237,7 +258,7 @@ def to_multi_format(data: MultiModalDataDict) -> dict[str, list[Any]]:
     return multi_data


-class _TokenRun(TypedDict):
+class _TokenRun(NamedTuple):
     token_id: int

     start_idx: int
@@ -257,19 +278,20 @@ def iter_token_runs(token_ids: list[int]) -> Iterable[_TokenRun]:
         start_idx += length
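Only the tail of iter_token_runs is visible in this hunk. Based on the itertools.groupby import and the test expectations above, a plausible reconstruction of its logic; a sketch, not the verbatim vLLM source:

from itertools import groupby
from typing import Iterable, NamedTuple

class TokenRun(NamedTuple):
    token_id: int
    start_idx: int
    length: int

def iter_token_runs_sketch(token_ids: list[int]) -> Iterable[TokenRun]:
    # Group maximal runs of consecutive equal token IDs,
    # tracking each run's starting offset.
    start_idx = 0
    for token_id, group in groupby(token_ids):
        length = sum(1 for _ in group)
        yield TokenRun(token_id, start_idx, length)
        start_idx += length

assert list(iter_token_runs_sketch([9, 9, 5, 9])) == [
    TokenRun(9, 0, 2),
    TokenRun(5, 2, 1),
    TokenRun(9, 3, 1),
]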


-class _BoundPlaceholderRange(TypedDict):
+class _BoundPlaceholderRange(NamedTuple):
     modality: str
     offset: int
     length: int


 def iter_placeholders(
     prompt_repls: Sequence[_BoundPromptReplacement[Any]],
-    new_token_ids: list[int],
+    token_ids: list[int],
     *,
     min_placeholder_count: int,
 ) -> Iterable[_BoundPlaceholderRange]:
-    repls_by_modality = full_groupby(prompt_repls, key=lambda x: x.modality)
+    """Yield each set of placeholder tokens found in :code:`token_ids`."""
+    repls_by_modality = full_groupby_modality(prompt_repls)

     placeholder_ids_by_modality = {
         modality: {

@@ -280,15 +302,15 @@
         for modality, repls in repls_by_modality
     }

-    for run_info in iter_token_runs(new_token_ids):
-        if run_info["length"] > min_placeholder_count:
+    for run_info in iter_token_runs(token_ids):
+        if run_info.length > min_placeholder_count:
             for (modality,
                  placeholder_ids) in placeholder_ids_by_modality.items():
-                if run_info["token_id"] in placeholder_ids:
+                if run_info.token_id in placeholder_ids:
                     yield _BoundPlaceholderRange(
                         modality=modality,
-                        offset=run_info["start_idx"],
-                        length=run_info["length"],
+                        offset=run_info.start_idx,
+                        length=run_info.length,
                     )
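Renaming and attribute access aside, the logic here is: a run of identical tokens is reported as a placeholder when it is longer than min_placeholder_count and its token ID belongs to some modality's placeholder set. Using the iter_token_runs sketch above and illustrative IDs:

# Illustrative: token 32000 is the "image" placeholder.
placeholder_ids_by_modality = {"image": {32000}}
min_placeholder_count = 1

token_ids = [1, 32000, 32000, 32000, 2]

found = [
    (modality, run.start_idx, run.length)
    for run in iter_token_runs_sketch(token_ids)
    if run.length > min_placeholder_count
    for modality, ids in placeholder_ids_by_modality.items()
    if run.token_id in ids
]
# Only the run of three 32000 tokens qualifies.
assert found == [("image", 1, 3)]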


@@ -398,6 +420,7 @@ def find_token_matches(
     prompt: list[int],
     prompt_repls: Sequence[_BoundPromptReplacement[_T]],
 ) -> list[_PromptReplacementTokenMatch[_T]]:
+    """Return each target of :code:`prompt_repls` found in :code:`prompt`."""
     return [
         _PromptReplacementTokenMatch(prompt_repl, match)
         for prompt_repl in prompt_repls
@@ -409,17 +432,22 @@ def find_text_matches(
     prompt: str,
     prompt_repls: Sequence[_BoundPromptReplacement[_T]],
 ) -> list[_PromptReplacementTextMatch[_T]]:
+    """Return each target of :code:`prompt_repls` found in :code:`prompt`."""
     return [
         _PromptReplacementTextMatch(prompt_repl, match)
         for prompt_repl in prompt_repls
         for match in re.finditer(re.escape(prompt_repl.target.text), prompt)
     ]
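find_text_matches relies on re.escape so that placeholder strings are matched literally even when they contain regex metacharacters such as . or |. A minimal illustration with made-up prompt text:

import re

prompt = "<image> Describe <image> in detail."
target_text = "<image>"  # illustrative placeholder string

# re.escape ensures the target is matched literally, not as a pattern.
spans = [(m.start(), m.end())
         for m in re.finditer(re.escape(target_text), prompt)]
assert spans == [(0, 7), (17, 24)]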


-def unique_sort_matches(
+def resolve_matches(
     prompt: _S,
     matches: Sequence[_PromptReplacementMatch[_T, _S]],
 ) -> list[_PromptReplacementMatch[_T, _S]]:
+    """
+    Resolve :code:`matches` to ensure that there are no overlapping matches,
+    and sort them such that earlier matches take priority over later ones.
+    """
     num_matches_by_idx = np.zeros(len(prompt), dtype=int)
     for match in matches:
         num_matches_by_idx[match.start_idx:match.end_idx] += 1
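The visible part of resolve_matches counts, for every prompt index, how many matches cover it; any count above one signals an overlap. A standalone sketch of that check with made-up spans (the rest of the function lies outside this hunk):

import numpy as np

prompt_len = 10
# (start_idx, end_idx) spans, end exclusive; the first two overlap at index 2.
matches = [(0, 3), (2, 5), (7, 9)]

num_matches_by_idx = np.zeros(prompt_len, dtype=int)
for start_idx, end_idx in matches:
    num_matches_by_idx[start_idx:end_idx] += 1

# Index 2 is covered twice, so these matches cannot all be applied as-is.
assert int(num_matches_by_idx.max()) == 2
assert bool((num_matches_by_idx > 1).any())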
@@ -443,8 +471,7 @@ def _replace_matches(
     prev_end_idx = 0
     next_idx_by_modality = {modality: 0 for modality in mm_items_by_modality}

-    # Earlier matches take priority over later ones
-    for match in unique_sort_matches(prompt, matches):
+    for match in resolve_matches(prompt, matches):
         modality = match.modality
         mm_items = mm_items_by_modality[modality]
@@ -471,6 +498,7 @@ def replace_token_matches(
     mm_items_by_modality: Mapping[str, list[_T]],
     hf_inputs: BatchFeature,
 ) -> list[int]:
+    """Apply :code:`prompt_repls` to :code:`prompt`."""
     if not matches:
         return prompt

@@ -490,6 +518,7 @@ def replace_text_matches(
     mm_items_by_modality: Mapping[str, list[_T]],
     hf_inputs: BatchFeature,
 ) -> str:
+    """Apply :code:`prompt_repls` to :code:`prompt`."""
     if not matches:
         return prompt

@@ -581,8 +610,7 @@ def _apply_prompt_replacements(

     if all(
         len(matches) >= len(mm_data[modality])
-        for modality, matches in full_groupby(token_matches,
-                                              key=lambda x: x.modality)
+        for modality, matches in full_groupby_modality(token_matches)
     ):  # yapf: disable
         token_ids = replace_token_matches(
             token_ids,
@@ -647,11 +675,10 @@ def apply(

     mm_placeholders = {
         modality: [
-            PlaceholderRange(offset=item["offset"], length=item["length"])
+            PlaceholderRange(offset=item.offset, length=item.length)
             for item in items
         ]
-        for modality, items in full_groupby(all_placeholders,
-                                            key=lambda x: x["modality"])
+        for modality, items in full_groupby_modality(all_placeholders)
     }
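Since all_placeholders now holds _BoundPlaceholderRange named tuples, fields are read as attributes and the grouping reuses full_groupby_modality. Sketched here with plain dicts and illustrative ranges; PlaceholderRange is simplified to a dict:

from typing import NamedTuple

class BoundPlaceholderRange(NamedTuple):
    modality: str
    offset: int
    length: int

all_placeholders = [
    BoundPlaceholderRange("image", 1, 3),
    BoundPlaceholderRange("image", 8, 3),
    BoundPlaceholderRange("audio", 15, 5),
]

# Group by modality, then keep only offset/length per range.
mm_placeholders: dict[str, list[dict[str, int]]] = {}
for item in all_placeholders:
    mm_placeholders.setdefault(item.modality, []).append(
        {"offset": item.offset, "length": item.length})

assert mm_placeholders["image"] == [{"offset": 1, "length": 3},
                                    {"offset": 8, "length": 3}]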

     return MultiModalInputsV2(
