Skip to content

Commit

Permalink
[bug fix] Fix llava next feature size calculation. (vllm-project#6339)
Browse files Browse the repository at this point in the history
Signed-off-by: Xiaowei Jiang <[email protected]>
Signed-off-by: Alvant <[email protected]>
  • Loading branch information
xwjiang2010 authored and Alvant committed Oct 26, 2024
1 parent de6888b commit 1f8b098
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 9 deletions.
14 changes: 13 additions & 1 deletion tests/models/test_llava_next.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from typing import List, Optional, Tuple

import pytest
from transformers import AutoTokenizer
from transformers import AutoConfig, AutoTokenizer

from vllm.model_executor.models.llava_next import (
get_llava_next_image_feature_size)
from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs

Expand Down Expand Up @@ -120,3 +122,13 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
name_0="hf",
name_1="vllm",
)


@pytest.mark.parametrize("height_and_width_and_result", [(1669, 2560, 2144),
(183, 488, 776)])
def test_image_feature_size(height_and_width_and_result):
height, width, result = height_and_width_and_result
config = AutoConfig.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
assert get_llava_next_image_feature_size(config,
input_height=height,
input_width=width) == result
18 changes: 10 additions & 8 deletions vllm/model_executor/models/llava_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,19 +74,21 @@ def _get_llava_next_num_unpadded_features(
) -> Tuple[int, int]:
current_height = npatches * num_patch_height
current_width = npatches * num_patch_width
current_height = torch.tensor(current_height).to("cuda")
current_width = torch.tensor(current_width).to("cuda")

aspect_ratio: float = width / height
current_aspect_ratio: float = current_width / current_height
if aspect_ratio > current_aspect_ratio:
new_height = (height * current_width) // width
if new_height % 2 == 1:
new_height += 1
current_height = new_height
scale_factor = current_width / width
new_height = int(height * scale_factor)
padding = (current_height - new_height) // 2
current_height -= padding * 2
else:
new_width = (width * current_height) // height
if new_width % 2 == 1:
new_width += 1
current_width = new_width
scale_factor = current_height / height
new_width = int(width * scale_factor)
padding = (current_width - new_width) // 2
current_width -= padding * 2

unpadded_features = current_height * current_width
newline_features = current_height
Expand Down

0 comments on commit 1f8b098

Please sign in to comment.