Support and test Mantis model
Signed-off-by: DarkLight1337 <[email protected]>
DarkLight1337 committed Nov 27, 2024
1 parent 8539008 commit 1e82a4a
Showing 11 changed files with 129 additions and 27 deletions.
6 changes: 4 additions & 2 deletions .buildkite/test-pipeline.yaml
@@ -334,7 +334,6 @@ steps:
commands:
- pytest -v -s models/decoder_only/language -m 'core_model or quant_model'
- pytest -v -s models/embedding/language -m core_model
- pytest -v -s models/embedding/vision_language -m core_model

- label: Language Models Test (Extended) # 50min
optional: true
@@ -346,7 +345,6 @@
commands:
- pytest -v -s models/decoder_only/language -m 'not core_model and not quant_model'
- pytest -v -s models/embedding/language -m 'not core_model'
- pytest -v -s models/embedding/vision_language -m 'not core_model'

- label: Multi-Modal Models Test (Standard) # 26min
#mirror_hardwares: [amd]
@@ -357,8 +355,10 @@
- tests/models/embedding/vision_language
- tests/models/encoder_decoder/vision_language
commands:
- python -m pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
- pytest -v -s models/decoder_only/audio_language -m 'core_model or quant_model'
- pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'core_model or quant_model'
- pytest -v -s models/embedding/vision_language -m core_model
- pytest -v -s models/encoder_decoder/language -m core_model
- pytest -v -s models/encoder_decoder/vision_language -m core_model

@@ -371,11 +371,13 @@
- tests/models/embedding/vision_language
- tests/models/encoder_decoder/vision_language
commands:
- python -m pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
- pytest -v -s models/decoder_only/audio_language -m 'not core_model and not quant_model'
# HACK - run phi3v tests separately to sidestep this transformers bug
# https://github.com/huggingface/transformers/issues/34307
- pytest -v -s models/decoder_only/vision_language/test_phi3v.py
- pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'not core_model and not quant_model'
- pytest -v -s models/embedding/vision_language -m 'not core_model'
- pytest -v -s models/encoder_decoder/language -m 'not core_model'
- pytest -v -s models/encoder_decoder/vision_language -m 'not core_model'

17 changes: 17 additions & 0 deletions examples/offline_inference_vision_language.py
@@ -419,6 +419,22 @@ def run_aria(question: str, modality: str):
return llm, prompt, stop_token_ids


# Mantis
def run_mantis(question: str, modality: str):
assert modality == "image"

llama3_template = '<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n' # noqa: E501
prompt = llama3_template.format(f"{question}\n<image>")

llm = LLM(
model="TIGER-Lab/Mantis-8B-siglip-llama3",
max_model_len=4096,
hf_overrides={"architectures": ["MantisForConditionalGeneration"]},
)
stop_token_ids = [128009]
return llm, prompt, stop_token_ids


model_example_map = {
"llava": run_llava,
"llava-next": run_llava_next,
@@ -441,6 +457,7 @@ def run_aria(question: str, modality: str):
"glm4v": run_glm4v,
"idefics3": run_idefics3,
"aria": run_aria,
"mantis": run_mantis,
}


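For context, a minimal sketch of how the new run_mantis entry can be driven end to end, following the same pattern as the rest of offline_inference_vision_language.py; the image path and sampling settings below are illustrative placeholders, not part of this commit:

from PIL import Image

from vllm import SamplingParams

# run_mantis builds the engine and the Llama-3-style prompt shown above.
llm, prompt, stop_token_ids = run_mantis("What is in the picture?", modality="image")

# Any RGB image works here; the filename is a placeholder.
image = Image.open("cherry_blossom.jpg")

sampling_params = SamplingParams(temperature=0.0,
                                 max_tokens=64,
                                 stop_token_ids=stop_token_ids)
outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"image": image}},
    sampling_params=sampling_params,
)
print(outputs[0].outputs[0].text)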
3 changes: 0 additions & 3 deletions requirements-test.in
@@ -24,9 +24,6 @@ mistral_common[opencv] >= 1.4.4 # required for pixtral test
datamodel_code_generator # required for minicpm3 test
lm-eval[api]==0.4.4 # required for model evaluation test

# TODO: Add this after fully implementing llava(mantis)
# git+https://github.com/TIGER-AI-Lab/Mantis.git # required for llava(mantis) test

# quantization
bitsandbytes>=0.44.0
buildkite-test-collector==0.1.9
1 change: 0 additions & 1 deletion tests/conftest.py
@@ -263,7 +263,6 @@ def __init__(
dtype: str = "half",
*,
model_kwargs: Optional[Dict[str, Any]] = None,
is_embedding_model: bool = False,
is_sentence_transformer: bool = False,
is_cross_encoder: bool = False,
skip_tokenizer_init: bool = False,
28 changes: 21 additions & 7 deletions tests/models/decoder_only/vision_language/test_models.py
@@ -34,7 +34,7 @@
"dtype": "half",
"max_tokens": 5,
"tensor_parallel_size": 2,
"model_kwargs": {"device_map": "auto"},
"hf_model_kwargs": {"device_map": "auto"},
"image_size_factors": [(.25, 0.5, 1.0)],
"distributed_executor_backend": (
"ray",
@@ -108,7 +108,7 @@
"cherry_blossom": "What is in the picture?",
}),
auto_cls=AutoModelForVision2Seq,
postprocess_inputs=model_utils.get_key_type_post_processor(
postprocess_inputs=model_utils.cast_dtype_post_processor(
"pixel_values"
),
vllm_output_post_proc=model_utils.paligemma_vllm_to_hf_output,
@@ -148,7 +148,7 @@
prompt_formatter=lambda img_prompt: f"USER: {img_prompt}\nASSISTANT:",
max_model_len=4096,
auto_cls=AutoModelForVision2Seq,
postprocess_inputs=model_utils.get_key_type_post_processor(
postprocess_inputs=model_utils.cast_dtype_post_processor(
"pixel_values"
),
# For chameleon, we only compare the sequences
@@ -264,7 +264,7 @@
prompt_formatter=lambda vid_prompt: f"<|im_start|>user\n{vid_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501
num_video_frames=16,
max_model_len=16384,
postprocess_inputs=model_utils.get_key_type_post_processor(
postprocess_inputs=model_utils.cast_dtype_post_processor(
"pixel_values_videos"
),
auto_cls=AutoModelForVision2Seq,
@@ -295,6 +295,20 @@
)
],
),
"mantis": VLMTestInfo(
models=["TIGER-Lab/Mantis-8B-siglip-llama3"],
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
prompt_formatter=lambda img_prompt: f"<|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501
max_model_len=4096,
postprocess_inputs=model_utils.cast_dtype_post_processor(
"pixel_values"
),
vllm_runner_kwargs={"hf_overrides": {"architectures": ["MantisForConditionalGeneration"]}}, # noqa: E501
get_stop_token_ids=lambda tok: [128009],
auto_cls=AutoModelForVision2Seq,
vllm_output_post_proc=model_utils.mantis_vllm_to_hf_output,
patch_hf_runner=model_utils.mantis_patch_hf_runner,
),
"minicpmv": VLMTestInfo(
models=["openbmb/MiniCPM-Llama3-V-2_5"],
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
@@ -318,7 +332,7 @@
# max_num_seqs=2,
# task="generate",
# # use eager mode for hf runner since phi3v didn't work with flash_attn
# model_kwargs={"_attn_implementation": "eager"},
# hf_model_kwargs={"_attn_implementation": "eager"},
# use_tokenizer_eos=True,
# vllm_output_post_proc=model_utils.phi3v_vllm_to_hf_output,
# num_logprobs=10,
@@ -349,7 +363,7 @@
prompt_formatter=lambda img_prompt: f"USER: {img_prompt}\nASSISTANT:",
max_model_len=4096,
auto_cls=AutoModelForVision2Seq,
postprocess_inputs=model_utils.get_key_type_post_processor(
postprocess_inputs=model_utils.cast_dtype_post_processor(
"pixel_values"
),
vllm_output_post_proc = lambda vllm_output, model: vllm_output[:2],
@@ -418,7 +432,7 @@
test_type=VLMTestType.CUSTOM_INPUTS,
max_model_len=16384,
max_num_seqs=2,
postprocess_inputs=model_utils.get_key_type_post_processor(
postprocess_inputs=model_utils.cast_dtype_post_processor(
"pixel_values"
),
auto_cls=AutoModelForVision2Seq,
16 changes: 11 additions & 5 deletions tests/models/decoder_only/vision_language/vlm_utils/core.py
@@ -3,7 +3,7 @@

import torch
from PIL.Image import Image
from transformers import AutoTokenizer, BatchEncoding
from transformers import AutoTokenizer, BatchEncoding, PreTrainedTokenizerBase
from transformers.models.auto.auto_factory import _BaseAutoModelClass

from .....conftest import HfRunner, VllmRunner
@@ -28,9 +28,11 @@ def run_test(
use_tokenizer_eos: bool,
postprocess_inputs: Callable[[BatchEncoding], BatchEncoding],
comparator: Callable[..., None],
get_stop_token_ids: Optional[Callable[[AutoTokenizer], List[int]]],
get_stop_token_ids: Optional[Callable[[PreTrainedTokenizerBase],
List[int]]],
limit_mm_per_prompt: Dict[str, int],
model_kwargs: Optional[Dict[str, Any]],
vllm_runner_kwargs: Optional[Dict[str, Any]],
hf_model_kwargs: Optional[Dict[str, Any]],
patch_hf_runner: Optional[Callable[[HfRunner], HfRunner]],
task: str = "auto",
runner_mm_key: str = "images",
@@ -54,6 +56,9 @@ def run_test(
if get_stop_token_ids is not None:
vllm_kwargs["stop_token_ids"] = get_stop_token_ids(tokenizer)

if vllm_runner_kwargs is None:
vllm_runner_kwargs = {}

with vllm_runner(model,
max_model_len=max_model_len,
max_num_seqs=max_num_seqs,
@@ -62,7 +67,8 @@
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
enforce_eager=enforce_eager,
task=task) as vllm_model:
task=task,
**vllm_runner_kwargs) as vllm_model:
for prompts, media in vllm_inputs:
vllm_kwargs[runner_mm_key] = media
vllm_output = vllm_model.generate_greedy_logprobs(
@@ -73,7 +79,7 @@
dtype=dtype,
auto_cls=auto_cls,
postprocess_inputs=postprocess_inputs,
model_kwargs=model_kwargs)
model_kwargs=hf_model_kwargs)

# Some models need to patch things like the model processor, e.g., internvl
if patch_hf_runner is not None:
35 changes: 34 additions & 1 deletion tests/models/decoder_only/vision_language/vlm_utils/model_utils.py
@@ -126,6 +126,16 @@ def llava_onevision_vllm_to_hf_output(vllm_output: RunnerOutput,
return hf_output_ids, hf_output_str, out_logprobs


def mantis_vllm_to_hf_output(vllm_output: RunnerOutput,
model: str) -> RunnerOutput:
"""Sanitize vllm output [mantis] to compare with hf output."""
output_ids, output_str, out_logprobs = vllm_output

hf_output_str = output_str + "<|eot_id|>"

return output_ids, hf_output_str, out_logprobs


def phi3v_vllm_to_hf_output(vllm_output: RunnerOutput,
model: str) -> RunnerOutput:
"""Sanitize vllm output [phi3v] to be comparable with hf output."""
@@ -184,7 +194,7 @@ def get_llava_embeddings(image_assets: _ImageAssets):


####### postprocessors to run on HF BatchEncoding
def get_key_type_post_processor(
def cast_dtype_post_processor(
hf_inp_key: str) -> Callable[[BatchEncoding, str], BatchEncoding]:
"""Gets a handle to a post processor which converts a given key into a
target data type."""
@@ -407,3 +417,26 @@ def _internvl_generate(
)

return outputs


def mantis_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
from mantis.models.mllava import MLlavaProcessor

hf_model.processor = MLlavaProcessor.from_pretrained(hf_model.model_name)

orig_generate = hf_model.model.generate
tokenizer = hf_model.processor.tokenizer

def _generate(self, *args, **kwargs):
return orig_generate(
*args,
**kwargs,
eos_token_id=[
tokenizer.eos_token_id,
tokenizer.convert_tokens_to_ids("<|eot_id|>"),
],
)

hf_model.model.generate = types.MethodType(_generate, hf_model.model)

return hf_model
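Taken together, the two Mantis helpers keep the runners comparable: the patched HF generate stops on <|eot_id|> and leaves it in the decoded text, while vLLM's returned text does not include the stop token, so mantis_vllm_to_hf_output appends it before comparison. A tiny illustrative check with made-up values (not part of the test suite):

# Pretend vLLM output: (token ids, decoded text, logprobs).
fake_vllm_output = ([791, 8415], "The cat", None)

assert mantis_vllm_to_hf_output(
    fake_vllm_output, "TIGER-Lab/Mantis-8B-siglip-llama3",
) == ([791, 8415], "The cat<|eot_id|>", None)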
16 changes: 10 additions & 6 deletions tests/models/decoder_only/vision_language/vlm_utils/types.py
@@ -7,7 +7,8 @@
import torch
from PIL.Image import Image
from pytest import MarkDecorator
from transformers import AutoModelForCausalLM, AutoTokenizer, BatchEncoding
from transformers import (AutoModelForCausalLM, BatchEncoding,
PreTrainedTokenizerBase)
from transformers.models.auto.auto_factory import _BaseAutoModelClass

from vllm.sequence import SampleLogprobs
@@ -66,7 +67,7 @@ class ImageSizeWrapper(NamedTuple):
class VLMTestInfo(NamedTuple):
"""Holds the configuration for 1+ tests for one model architecture."""

models: Union[List[str]]
models: List[str]
test_type: Union[VLMTestType, Iterable[VLMTestType]]

# Should be None only if this is a CUSTOM_INPUTS test
@@ -94,13 +95,15 @@ class VLMTestInfo(NamedTuple):
max_num_seqs: int = 256
task: str = "auto"
tensor_parallel_size: int = 1
vllm_runner_kwargs: Optional[Dict[str, Any]] = None

# Optional callable which gets a list of token IDs from the model tokenizer
get_stop_token_ids: Optional[Callable[[AutoTokenizer], List[int]]] = None
get_stop_token_ids: Optional[Callable[[PreTrainedTokenizerBase],
List[int]]] = None

# Exposed options for HF runner
model_kwargs: Optional[Dict[str, Any]] = None
# Indicates we should explicitly pass the EOS from the tokeniezr
hf_model_kwargs: Optional[Dict[str, Any]] = None
# Indicates we should explicitly pass the EOS from the tokenizer
use_tokenizer_eos: bool = False
auto_cls: Type[_BaseAutoModelClass] = AutoModelForCausalLM
# Callable to pass to the HF runner to run on inputs; for now, we also pass
@@ -159,14 +162,15 @@ def get_non_parametrized_runner_kwargs(self):
"max_num_seqs": self.max_num_seqs,
"task": self.task,
"tensor_parallel_size": self.tensor_parallel_size,
"vllm_runner_kwargs": self.vllm_runner_kwargs,
"hf_output_post_proc": self.hf_output_post_proc,
"vllm_output_post_proc": self.vllm_output_post_proc,
"auto_cls": self.auto_cls,
"use_tokenizer_eos": self.use_tokenizer_eos,
"postprocess_inputs": self.postprocess_inputs,
"comparator": self.comparator,
"get_stop_token_ids": self.get_stop_token_ids,
"model_kwargs": self.model_kwargs,
"hf_model_kwargs": self.hf_model_kwargs,
"patch_hf_runner": self.patch_hf_runner,
}

1 change: 1 addition & 0 deletions tests/models/registry.py
@@ -173,6 +173,7 @@ class _HfExamplesInfo:
"LlavaNextForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-v1.6-mistral-7b-hf"), # noqa: E501
"LlavaNextVideoForConditionalGeneration": _HfExamplesInfo("llava-hf/LLaVA-NeXT-Video-7B-hf"), # noqa: E501
"LlavaOnevisionForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-onevision-qwen2-0.5b-ov-hf"), # noqa: E501
"MantisForConditionalGeneration": _HfExamplesInfo("TIGER-Lab/Mantis-8B-siglip-llama3"), # noqa: E501
"MiniCPMV": _HfExamplesInfo("openbmb/MiniCPM-Llama3-V-2_5",
trust_remote_code=True),
"MolmoForCausalLM": _HfExamplesInfo("allenai/Molmo-7B-D-0924",
Expand Down
32 changes: 30 additions & 2 deletions vllm/model_executor/models/llava.py
@@ -7,7 +7,7 @@
from PIL.Image import Image
from transformers import (BatchFeature, CLIPVisionConfig, LlavaConfig,
PixtralVisionConfig, PretrainedConfig,
SiglipVisionConfig)
ProcessorMixin, SiglipVisionConfig)

from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
@@ -21,7 +21,7 @@
from vllm.multimodal.processing import (InputProcessingContext,
ModalityProcessingMetadata,
MultiModalProcessingMetadata,
PromptReplacement)
MultiModalProcessor, PromptReplacement)
from vllm.sequence import IntermediateTensors

from .clip import CLIPVisionModel, get_max_clip_image_tokens
@@ -499,3 +499,31 @@ def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)


class MantisProcessor(MultiModalProcessor):

def _get_hf_processor(self) -> ProcessorMixin:
try:
from mantis.models.mllava import MLlavaProcessor
except ModuleNotFoundError as exc:
raise ModuleNotFoundError(
"You need to `pip install "
"git+https://github.com/TIGER-AI-Lab/Mantis.git` "
"to use this model") from exc

processor = MLlavaProcessor.from_pretrained(
self.ctx.model_config.tokenizer)
assert isinstance(processor, ProcessorMixin)
return processor


# To use this model, please use
# `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@MULTIMODAL_REGISTRY.register_processor(lambda ctx: MantisProcessor(
ctx=ctx,
metadata=create_metadata_for_llava(ctx),
))
class MantisForConditionalGeneration(LlavaForConditionalGeneration):
pass
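Since the Hugging Face checkpoint still advertises LlavaForConditionalGeneration in its config, the architectures override is what routes a request to MantisProcessor; the weights themselves load through the unchanged LLaVA implementation above. A short offline-inference sketch of the override, mirroring the run_mantis example added in this commit; the comment above shows the equivalent command-line form:

from vllm import LLM

llm = LLM(
    model="TIGER-Lab/Mantis-8B-siglip-llama3",
    max_model_len=4096,
    hf_overrides={"architectures": ["MantisForConditionalGeneration"]},
)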
1 change: 1 addition & 0 deletions vllm/model_executor/models/registry.py
@@ -147,6 +147,7 @@
"LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501
"LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"), # noqa: E501
"LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"), # noqa: E501
"MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"), # noqa: E501
"MiniCPMV": ("minicpmv", "MiniCPMV"),
"MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
"NVLM_D": ("nvlm_d", "NVLM_D_Model"),
