Skip to content

Commit

Permalink
[Model] Remove hardcoded image tokens ids from Pixtral (vllm-project#…
Browse files Browse the repository at this point in the history
…11582)

Signed-off-by: Roger Wang <[email protected]>
  • Loading branch information
ywang96 authored and Ubuntu committed Jan 19, 2025
1 parent 3ca5715 commit 1429cf9
Showing 1 changed file with 13 additions and 14 deletions.
27 changes: 13 additions & 14 deletions vllm/model_executor/models/pixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,6 @@
except ImportError:
USE_XFORMERS_OPS = False

# These token ids cannot be retrieved from model config
# so we hardcode them here.
PIXTRAL_12B_IMAGE_BREAK_ID = 12
PIXTRAL_12B_IMAGE_END_ID = 13
PIXTRAL_LARGE_IMAGE_BREAK_ID = 14
PIXTRAL_LARGE_IMAGE_END_ID = 15


def get_max_pixtral_image_tokens(ctx: InputContext):
tokenizer = cached_get_tokenizer(
Expand Down Expand Up @@ -201,6 +194,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
if key in dataclass_fields
}

if not ("image_break_token_id" in vision_args
and "image_end_token_id" in vision_args):
raise ValueError(
"'image_break_token_id' and 'image_end_token_id' not found "
"in the vision_encoder arguments. Please download the latest "
"version of 'params.json' from the model repository.")

self.vision_args = VisionEncoderArgs(**vision_args)

# init MistralForCausalLM
Expand Down Expand Up @@ -240,9 +240,8 @@ def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:

# NOTE: Image embeddings are split into separate tensors for each image
# by the indices of `[IMG_END]` token.
image_end_condition = (image_tokens == PIXTRAL_12B_IMAGE_END_ID) | (
image_tokens == PIXTRAL_LARGE_IMAGE_END_ID)
split_indices = torch.where(image_end_condition)[0] + 1
image_end_mask = image_tokens == self.vision_args.image_end_token_id
split_indices = torch.where(image_end_mask)[0] + 1
if len(split_indices) <= 1:
# Do not split, return as tensor of shape [1, fs, hs]
return image_embeds.unsqueeze(0)
Expand All @@ -265,10 +264,8 @@ def get_input_embeddings(
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings, [
self.vision_args.image_token_id,
PIXTRAL_12B_IMAGE_END_ID,
PIXTRAL_12B_IMAGE_BREAK_ID,
PIXTRAL_LARGE_IMAGE_BREAK_ID,
PIXTRAL_LARGE_IMAGE_END_ID,
self.vision_args.image_break_token_id,
self.vision_args.image_end_token_id,
])
return inputs_embeds

Expand Down Expand Up @@ -409,6 +406,8 @@ class VisionEncoderArgs:
num_attention_heads: int
rope_theta: float # for rope-2D
image_token_id: int
image_break_token_id: int
image_end_token_id: int
adapter_bias: bool = True


Expand Down

0 comments on commit 1429cf9

Please sign in to comment.