[Doc][V1] Update model implementation guide for V1 support #11998

Merged (5 commits) on Jan 13, 2025
docs/source/contributing/model/basic.md (11 additions & 1 deletion)
@@ -57,7 +57,17 @@ class MyModelForCausalLM(nn.Module):

### Computation Code

Rewrite the {meth}`~torch.nn.Module.forward` method of your model to remove any unnecessary code, such as training-specific code. Modify the input parameters to treat `input_ids` and `positions` as flattened tensors with a single batch size dimension, without a max-sequence length dimension.
- Add a `get_input_embeddings` method to the `MyModel` module that returns the text embeddings for the given `input_ids`. This is equivalent to directly calling the text embedding layer, but it provides a unified interface in case `MyModel` is used within a composite multimodal model.

```python
class MyModel(nn.Module):
...

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
...
```
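
  For a typical decoder-only model this can simply forward to the token embedding layer. A minimal sketch, assuming the embedding layer is stored as `self.embed_tokens` (the attribute name varies by model):

```python
class MyModel(nn.Module):
    ...

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        # `self.embed_tokens` is an assumed attribute name; use whatever your
        # model calls its token embedding layer.
        return self.embed_tokens(input_ids)
```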

- Rewrite the {meth}`~torch.nn.Module.forward` method of your model to remove any unnecessary code, such as training-specific code. Modify the input parameters to treat `input_ids` and `positions` as flattened tensors with a single batch size dimension, without a max-sequence length dimension.

```python
def forward(
    ...
```
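
The remainder of this code block is collapsed by the diff view. As a rough sketch of the flattened signature being described, mirroring the parameters shown in the `multimodal.md` example below (the return annotation is an assumption, not the file's exact contents):

```python
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    kv_caches: List[torch.Tensor],
    attn_metadata: AttentionMetadata,
) -> torch.Tensor:  # assumed return annotation
    ...
```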
docs/source/contributing/model/multimodal.md (72 additions & 15 deletions)
@@ -9,7 +9,78 @@ This document walks you through the steps to extend a basic model so that it acc…
It is assumed that you have already implemented the model in vLLM according to [these steps](#new-model-basic).
Further update the model as follows:

- Implement the {class}`~vllm.model_executor.models.interfaces.SupportsMultiModal` interface.
- Reserve a keyword parameter in {meth}`~torch.nn.Module.forward` for each input tensor that corresponds to a multi-modal input, as shown in the following example:

```diff
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
+ pixel_values: torch.Tensor,
) -> SamplerOutput:
```

More conveniently, you can simply pass `**kwargs` to the {meth}`~torch.nn.Module.forward` method and retrieve the keyword parameters for multimodal inputs from it.
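
  For example, a sketch of the `**kwargs` approach (using the `pixel_values` name from the example above):

```python
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    kv_caches: List[torch.Tensor],
    attn_metadata: AttentionMetadata,
    **kwargs: object,
) -> SamplerOutput:
    # Multimodal inputs arrive as keyword arguments; pull out the ones this
    # model cares about (here, the illustrative `pixel_values` tensor).
    pixel_values = kwargs.pop("pixel_values", None)
    ...
```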

- Implement {meth}`~vllm.model_executor.models.interfaces.SupportsMultiModal.get_multimodal_embeddings` that returns the embeddings from running the multimodal inputs through the multimodal tokenizer of the model. Below we provide a boilerplate of a typical implementation pattern, but feel free to adjust it to your own needs.

```python
class YourModelForImage2Seq(nn.Module):
...

def _process_image_input(self, image_input: YourModelImageInputs) -> torch.Tensor:

assert self.vision_encoder is not None
image_features = self.vision_encoder(image_input)
return self.multi_modal_projector(image_features)

def get_multimodal_embeddings(self, **kwargs: object) -> Optional[NestedTensors]:

# Validate the multimodal input keyword arguments
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None

# Run multimodal inputs through encoder and projector
vision_embeddings = self._process_image_input(image_input)
return vision_embeddings
```

```{important}
The returned `multimodal_embeddings` must be either a **3D {class}`torch.Tensor`** of shape `(num_items, feature_size, hidden_size)`, or a **list / tuple of 2D {class}`torch.Tensor`s** of shape `(feature_size, hidden_size)`, so that `multimodal_embeddings[i]` retrieves the embeddings generated from the `i`-th multimodal data item (e.g. image) of the request.
```
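
  The `_parse_and_validate_image_input` helper called above is model-specific rather than part of a vLLM interface. A minimal sketch, assuming the image input arrives as a single `pixel_values` tensor (the `YourModelImageInputs` fields shown here are illustrative):

```python
from typing import Literal, Optional, TypedDict


class YourModelImageInputs(TypedDict):
    type: Literal["pixel_values"]
    data: torch.Tensor
    """Assumed shape: `(num_images, num_channels, height, width)`."""


class YourModelForImage2Seq(nn.Module):
    ...

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[YourModelImageInputs]:
        pixel_values = kwargs.pop("pixel_values", None)
        if pixel_values is None:
            return None

        if not isinstance(pixel_values, torch.Tensor):
            raise ValueError("Incorrect type of pixel values. "
                             f"Got type: {type(pixel_values)}")

        return YourModelImageInputs(type="pixel_values", data=pixel_values)
```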

- Implement {meth}`~vllm.model_executor.models.interfaces.SupportsMultiModal.get_input_embeddings` to merge `multimodal_embeddings` with text embeddings from the `input_ids`. If input processing for the model is implemented correctly (see sections below), then you can leverage the utility function we provide to easily merge the embeddings.

```python
from .utils import merge_multimodal_embeddings

class YourModelForImage2Seq(nn.Module):
...

def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None,
) -> torch.Tensor:

# `get_input_embeddings` should already be implemented for the language
# model as one of the requirements of basic vLLM model implementation.
inputs_embeds = self.language_model.get_input_embeddings(input_ids)

if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
multimodal_embeddings=multimodal_embeddings,
placeholder_token_id=self.config.image_token_index)

return inputs_embeds
```
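
  Conceptually, `merge_multimodal_embeddings` scatters the multimodal embeddings into the positions occupied by the placeholder tokens. A simplified stand-alone sketch of the idea (not the actual vLLM implementation):

```python
import torch


def merge_multimodal_embeddings_sketch(
    input_ids: torch.Tensor,                    # (num_tokens,)
    inputs_embeds: torch.Tensor,                # (num_tokens, hidden_size)
    multimodal_embeddings: list[torch.Tensor],  # each (feature_size, hidden_size)
    placeholder_token_id: int,
) -> torch.Tensor:
    # Boolean mask marking the placeholder tokens (e.g. image tokens) that
    # prompt processing inserted into `input_ids`.
    is_placeholder = input_ids == placeholder_token_id

    # The total number of placeholder positions must equal the combined
    # feature size of all multimodal items.
    flattened = torch.cat(list(multimodal_embeddings), dim=0)

    inputs_embeds = inputs_embeds.clone()
    inputs_embeds[is_placeholder] = flattened.to(inputs_embeds.dtype)
    return inputs_embeds
```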

- Once the above steps are done, update the model class with the {class}`~vllm.model_executor.models.interfaces.SupportsMultiModal` interface.

```diff
+ from vllm.model_executor.models.interfaces import SupportsMultiModal
@@ -23,20 +94,6 @@ Further update the model as follows:
Check out [the HuggingFace Transformers documentation](https://huggingface.co/docs/transformers/model_doc/auto#multimodal) for some examples.
```
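
  The collapsed portion of the diff above presumably also adds the interface to the class bases. A hedged reconstruction of that change:

```diff
- class YourModelForImage2Seq(nn.Module):
+ class YourModelForImage2Seq(nn.Module, SupportsMultiModal):
```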

- If you haven't already done so, reserve a keyword parameter in {meth}`~torch.nn.Module.forward`
for each input tensor that corresponds to a multi-modal input, as shown in the following example:

```diff
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
+ pixel_values: torch.Tensor,
) -> SamplerOutput:
```

## 2. Specify processing information

Next, create a subclass of {class}`~vllm.multimodal.processing.BaseProcessingInfo`
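
The diff is collapsed beyond this point. For orientation only, a rough sketch of such a subclass; the method shown is an assumption based on the vLLM multimodal processing API of this period and may not match the collapsed content:

```python
from typing import Mapping, Optional

from vllm.multimodal.processing import BaseProcessingInfo


class YourProcessingInfo(BaseProcessingInfo):

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # Assumed hook: declare how many items of each modality a single
        # prompt may contain (`None` meaning no limit).
        return {"image": 1}
```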