Commit 4ae3aa0
[Doc] Add docs for prompt replacement
Signed-off-by: DarkLight1337 <[email protected]>
DarkLight1337 committed Jan 22, 2025
1 parent 528dbca commit 4ae3aa0
Showing 2 changed files with 79 additions and 15 deletions.
2 changes: 1 addition & 1 deletion vllm/model_executor/models/ultravox.py
@@ -218,7 +218,7 @@ def get_replacement_ultravox(item_idx: int):
         return [
             PromptReplacement(
                 modality="audio",
-                target='<|audio|>',
+                target="<|audio|>",
                 replacement=get_replacement_ultravox,
             )
         ]
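For context, the ``replacement`` argument here is a callable: it receives the index of the audio item within the prompt and returns the expanded placeholder sequence for that item. A minimal sketch of how such a callable can work — the feature sizes below are made up for illustration, and this is not the actual Ultravox implementation:

    # Hypothetical per-item feature sizes; in the real processor these are
    # derived from the audio feature extractor's output for each clip.
    audio_feature_sizes = [80, 124]

    def get_replacement_ultravox(item_idx: int) -> str:
        # Expand the single "<|audio|>" placeholder into one placeholder
        # per embedding produced by the audio encoder for this clip.
        return "<|audio|>" * audio_feature_sizes[item_idx]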
92 changes: 78 additions & 14 deletions vllm/multimodal/processing.py
@@ -29,41 +29,101 @@
 logger = init_logger(__name__)
 
 _S = TypeVar("_S", str, list[int])
-_PromptSeq = Union[str, list[int]]
+
+PromptSeq = Union[str, list[int]]
+"""A token sequence (list of token IDs) or text."""
 
 
 @dataclass
 class PromptReplacementDetails:
-    full: _PromptSeq
+    """Details about the replacement token sequence or text."""
+
+    full: PromptSeq
     """The full replacement."""
 
-    features: _PromptSeq
+    features: PromptSeq
     """
-    The part of the replacement that corresponds to placeholder feature tokens.
+    The part of the replacement that corresponds to feature placeholders;
+    this will be replaced by the output of the vision encoder during model
+    inference.
     """
 
     @staticmethod
-    def from_seq(seq: _PromptSeq) -> "PromptReplacementDetails":
+    def from_seq(seq: PromptSeq) -> "PromptReplacementDetails":
         return PromptReplacementDetails(full=seq, features=seq)
 
 
-_PromptRepl = Union[_PromptSeq, PromptReplacementDetails]
+PromptRepl = Union[PromptSeq, PromptReplacementDetails]
 """
 The replacement token sequence or text.
+
+If only part of the replacement corresponds to feature placeholders, you can
+use :class:`PromptReplacementDetails` to specify which part.
 """
 
 
 @dataclass
 class PromptReplacement:
+    """
+    Defines how to replace portions of an input prompt with placeholder tokens.
+
+    Example:
+
+        For each image, replace one ``<image>`` input placeholder in the prompt
+        with a number of ``<image>`` feature placeholders
+        equal to the feature size of the vision encoder:
+
+        .. code-block:: python
+
+            PromptReplacement(
+                modality="image",
+                target="<image>",
+                replacement="<image>" * image_feature_size,
+            )
+
+        As above, but further pad the feature placeholders with ``<image_bos>``
+        and ``<image_eos>``, which are not supposed to be passed to the vision
+        encoder:
+
+        .. code-block:: python
+
+            PromptReplacement(
+                modality="image",
+                target="<image>",
+                replacement=PromptReplacementDetails(
+                    full="".join([
+                        "<image_bos>",
+                        "<image>" * image_feature_size,
+                        "<image_eos>",
+                    ]),
+                    features="<image>" * image_feature_size,
+                ),
+            )
+
+        To avoid unnecessary tokenization during prompt replacement,
+        we recommend passing token sequences instead of text:
+
+        .. code-block:: python
+
+            PromptReplacement(
+                modality="image",
+                target=[image_token_id],
+                replacement=PromptReplacementDetails(
+                    full=([image_bos_id] + [image_token_id] * image_feature_size
+                          + [image_eos_id]),
+                    features=[image_token_id] * image_feature_size,
+                ),
+            )
+    """
+
     modality: str
     """The modality for which the replacement is made."""
 
-    target: _PromptSeq
+    target: PromptSeq
     """The token sequence (or text) to find and replace."""
 
-    replacement: Union[Callable[[int], _PromptRepl],
-                       _PromptRepl] = field(repr=False)
+    replacement: Union[Callable[[int], PromptRepl],
+                       PromptRepl] = field(repr=False)
     """
     Given the index of the processed item within :attr:`modality`,
     output the replacement token sequence (or text).
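Note that ``PromptReplacementDetails.from_seq``, shown in the diff above, is shorthand for the common case where the entire replacement consists of feature placeholders. A small sketch of the equivalence — the token ID and feature size are made-up values:

    from vllm.multimodal.processing import PromptReplacementDetails

    image_token_id = 32000    # hypothetical placeholder token ID
    image_feature_size = 576  # hypothetical vision encoder feature size

    seq = [image_token_id] * image_feature_size

    # With no padding tokens around the features, the full replacement
    # *is* the feature sequence, so the two forms are equal.
    details_a = PromptReplacementDetails.from_seq(seq)
    details_b = PromptReplacementDetails(full=seq, features=seq)
    assert details_a == details_b  # dataclasses compare by field values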
@@ -126,6 +186,10 @@ def full_groupby_modality(values: Iterable[_M]) -> ItemsView[str, list[_M]]:

 @dataclass
 class _BoundPromptSequence:
+    """
+    A :data:`PromptSeq` bound to a tokenizer to automatically
+    convert between token sequence and text representations.
+    """
     tokenizer: AnyTokenizer = field(repr=False)
 
     _text: Optional[str]
@@ -134,7 +198,7 @@ class _BoundPromptSequence:
     @staticmethod
     def from_seq(
         tokenizer: AnyTokenizer,
-        seq: _PromptSeq,
+        seq: PromptSeq,
     ) -> "_BoundPromptSequence":
         return _BoundPromptSequence(
             tokenizer=tokenizer,
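The idea behind ``_BoundPromptSequence`` — hold whichever representation was supplied (the ``_text``/``_token_ids`` fields above) and derive the other on demand through the tokenizer — can be sketched in isolation. This is an illustrative pattern, not the actual class; ``ToyTokenizer`` is a stand-in for ``AnyTokenizer``:

    from dataclasses import dataclass, field
    from typing import Optional

    class ToyTokenizer:
        # Stand-in tokenizer: maps whitespace-separated words to vocab IDs.
        vocab = ["<image>", "hello", "world"]

        def encode(self, text: str) -> list[int]:
            return [self.vocab.index(word) for word in text.split()]

        def decode(self, token_ids: list[int]) -> str:
            return " ".join(self.vocab[i] for i in token_ids)

    @dataclass
    class BoundSeq:
        tokenizer: ToyTokenizer = field(repr=False)
        _text: Optional[str] = None
        _token_ids: Optional[list[int]] = None

        @property
        def text(self) -> str:
            if self._text is None:
                # Lazily derive (and cache) the text form.
                self._text = self.tokenizer.decode(self._token_ids)
            return self._text

        @property
        def token_ids(self) -> list[int]:
            if self._token_ids is None:
                # Lazily derive (and cache) the token-ID form.
                self._token_ids = self.tokenizer.encode(self._text)
            return self._token_ids

    seq = BoundSeq(tokenizer=ToyTokenizer(), _text="hello world")
    assert seq.token_ids == [1, 2]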
@@ -180,9 +244,9 @@ class BoundPromptReplacement:
     tokenizer: AnyTokenizer = field(repr=False)
     modality: str
 
-    _target: _PromptSeq
-    _replacement: Union[Callable[[int], _PromptRepl],
-                        _PromptRepl] = field(repr=False)
+    _target: PromptSeq
+    _replacement: Union[Callable[[int], PromptRepl],
+                        PromptRepl] = field(repr=False)
 
     def __post_init__(self) -> None:
         self._replacement_cache = dict[int, _BoundPromptReplacementGroup]()
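The ``_replacement_cache`` initialized in ``__post_init__`` suggests the pattern used when ``_replacement`` is a callable: invoke it at most once per item index and reuse the result. Sketched standalone — the class and method names here are hypothetical, not vLLM's API:

    from typing import Callable, Union

    Repl = Union[str, list[int]]

    class ReplacementResolver:
        def __init__(self, replacement: Union[Callable[[int], Repl], Repl]):
            self._replacement = replacement
            self._cache: dict[int, Repl] = {}

        def get(self, item_idx: int) -> Repl:
            if item_idx not in self._cache:
                repl = self._replacement
                # Static replacements are reused as-is; per-item callables
                # run once per index and are memoized.
                self._cache[item_idx] = repl(item_idx) if callable(repl) else repl
            return self._cache[item_idx]

    resolver = ReplacementResolver(lambda i: "<image>" * (i + 1))
    assert resolver.get(1) == "<image><image>"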
@@ -350,7 +414,7 @@ def find_text_matches(
 
 
 def _resolve_matches(
-    prompt: _PromptSeq,
+    prompt: PromptSeq,
     mm_matches: Mapping[str, Sequence[_PromptReplacementMatch]],
 ) -> list[_PromptReplacementMatch]:
     """
