From 1c6f7d844241206f6125b715a0dcb3e2e9f7b37c Mon Sep 17 00:00:00 2001 From: hzh Date: Wed, 22 Jan 2025 13:51:31 +0000 Subject: [PATCH] format Signed-off-by: hzh --- vllm/model_executor/models/minicpmv.py | 1040 ++++++++++-------------- 1 file changed, 440 insertions(+), 600 deletions(-) diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 89ab5d02995be..bdcacfa19c532 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -23,23 +23,20 @@ import math import re from functools import cached_property, partial -from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional, - Set, Tuple, TypedDict, Union, Dict) +from itertools import accumulate +from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping, + Optional, Set, Tuple, TypedDict, Union) +import numpy as np import torch import torch.types from PIL import Image from torch import nn -import numpy as np from transformers import BatchFeature, PretrainedConfig +from transformers.cache_utils import DynamicCache, EncoderDecoderCache from transformers.modeling_outputs import BaseModelOutputWithPast -from transformers.cache_utils import EncoderDecoderCache, DynamicCache -from transformers.models.whisper.modeling_whisper import (ACT2FN, - WHISPER_ATTENTION_CLASSES, - WhisperConfig, - WhisperEncoder) -from typing_extensions import NotRequired -from itertools import accumulate +from transformers.models.whisper.modeling_whisper import ( + ACT2FN, WHISPER_ATTENTION_CLASSES, WhisperConfig, WhisperEncoder) from vllm.attention import AttentionMetadata from vllm.config import VllmConfig @@ -48,33 +45,27 @@ get_2d_sincos_pos_embed) from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.model_loader.utils import set_default_torch_dtype -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.model_executor.models.minicpm import MiniCPMForCausalLM from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs -from vllm.multimodal.inputs import (MultiModalFieldConfig, - MultiModalDataDict, - PlaceholderRange, - MultiModalInputsV2) -from vllm.multimodal.parse import (ImageSize, - ImageItem, - VideoItem, - ModalityData, - ModalityDataItems, - MultiModalDataItems, - MultiModalDataParser) -from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalInputsV2, PlaceholderRange) +from vllm.multimodal.parse import (ImageItem, ImageSize, ModalityData, + ModalityDataItems, MultiModalDataItems, + MultiModalDataParser, VideoItem) from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, - PromptReplacement) -from vllm.sequence import IntermediateTensors, SequenceData + BaseProcessingInfo, PromptReplacement) +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs +from vllm.sequence import IntermediateTensors from .idefics2_vision_model import Idefics2VisionTransformer from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP -from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix +from .utils import AutoWeightsLoader, maybe_prefix + +CPU_DEVICE = torch.device("cpu") RawImageType = Union[Image.Image, torch.Tensor] @@ -159,13 +150,14 @@ class MiniCPMVAudioEmbeddingInputs(TypedDict): class MiniCPMVEmbeddingItems(ModalityDataItems[dict[str, torch.Tensor], - dict[str, torch.Tensor]]): + dict[str, torch.Tensor]]): + def __init__(self, data: Dict, modality: str) -> None: super().__init__(data, modality) def get_processor_data(self) -> Mapping[str, object]: return self.data - + def get_passthrough_data(self) -> Mapping[str, object]: return {} @@ -180,50 +172,53 @@ def get(self, index: int) -> Dict[str, torch.Tensor]: class MiniCPMVImageEmbeddingItems(MiniCPMVEmbeddingItems): + def __init__(self, data: Dict) -> None: super().__init__(data, "image") image_embeds = self.data.get("image_embeds", None) image_sizes = self.data.get("image_sizes", None) if image_embeds is None: - raise ValueError(f"In correct type of image_embeds", - f"Got type: None") + raise ValueError("In correct type of image_embeds", + "Got type: None") if not isinstance(image_embeds[0], torch.Tensor): - raise ValueError(f"In correct type of image_embeds", + raise ValueError("In correct type of image_embeds", f"Got type: {type(image_embeds[0])}") if image_sizes is None: - raise ValueError(f"In correct type of image_sizes", - f"Got type: None." - "If you're using `image_size_list`, please rename it to `image_sizes`") + raise ValueError( + "In correct type of image_sizes", "Got type: None." + "If you're using `image_size_list`, " + "please rename it to `image_sizes`") if len(image_embeds[0].shape) == 2: image_embeds = [image_embeds] image_sizes = [image_sizes] self.data["image_embeds"] = image_embeds self.data["image_sizes"] = image_sizes - + def get_image_size(self, index: int) -> ImageSize: image_size = self.data["image_sizes"][index] return ImageSize(width=image_size[0], height=image_size[1]) - + class MiniCPMVVideoEmbeddingItems(MiniCPMVEmbeddingItems): + def __init__(self, data: Dict) -> None: super().__init__(data, "video") video_embeds = self.data.get("video_embeds", None) image_sizes = self.data.get("image_sizes", None) num_frames = self.data.get("num_frames", None) if video_embeds is None: - raise ValueError(f"In correct type of video_embeds", - f"Got type: None") + raise ValueError("In correct type of video_embeds", + "Got type: None") if not isinstance(video_embeds[0], torch.Tensor): - raise ValueError(f"In correct type of video_embeds", + raise ValueError("In correct type of video_embeds", f"Got type: {type(video_embeds[0])}") if image_sizes is None: - raise ValueError(f"In correct type of image_sizes", - f"Got type: None." - "If you're using `image_size_list`, please rename it to `image_sizes`") + raise ValueError( + "In correct type of image_sizes", "Got type: None." + "If you're using `image_size_list`, " + "please rename it to `image_sizes`") if num_frames is None: - raise ValueError(f"In correct type of numframes", - f"Got type: None") + raise ValueError("In correct type of numframes", "Got type: None") if len(video_embeds[0].shape) == 2: video_embeds = [video_embeds] image_sizes = [image_sizes] @@ -231,24 +226,25 @@ def __init__(self, data: Dict) -> None: self.data["video_embeds"] = video_embeds self.data["image_sizes"] = image_sizes self.data["num_frames"] = num_frames - + def get_frame_size(self, index: int) -> ImageSize: frame_size = self.data["image_sizes"][index] return ImageSize(width=frame_size[0], height=frame_size[1]) def get_num_frames(self, index: int) -> int: return self.data["num_frames"][index] - + class MiniCPMVAudioEmbeddingItems(MiniCPMVEmbeddingItems): + def __init__(self, data: Dict) -> None: super().__init__(data, "audio") audio_embeds = self.data.get("audio_embeds", None) if audio_embeds is None: - raise ValueError(f"In correct type of video_embeds", - f"Got type: None") + raise ValueError("In correct type of video_embeds", + "Got type: None") self.data["audio_embeds"] = audio_embeds - + def get(self, index: int) -> object: return self.data["audio_embeds"][index] @@ -398,7 +394,7 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo): def get_hf_config(self): return self.ctx.get_hf_config() - + def get_hf_processor( self, **kwargs: object, @@ -407,13 +403,11 @@ def get_hf_processor( # image_processor = hf_processor.image_processor return hf_processor - def get_image_processor( - self, - ): + def get_image_processor(self, ): hf_processor = self.get_hf_processor() image_processor = hf_processor.image_processor # type: ignore return image_processor - + def get_model_version(self): return get_version_by_config(self.get_hf_config()) @@ -432,11 +426,9 @@ def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None, "video": None} else: return {"image": None} - + def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: - mm_max_tokens = { - "image": self.get_max_image_tokens() - } + mm_max_tokens = {"image": self.get_max_image_tokens()} if self.get_model_version() == (2, "6O"): mm_max_tokens["audio"] = self.get_max_audio_tokens() if self.get_model_version() in [(2, 6), (2, "6O")]: @@ -445,19 +437,22 @@ def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: def get_max_video_frame_tokens(self) -> int: frame_size = self.get_video_frame_size_with_most_features() - return self.get_num_image_tokens(frame_size, self.get_video_max_slice_num()) - + return self.get_num_image_tokens(frame_size, + self.get_video_max_slice_num()) + def get_max_video_tokens(self, seq_len: int) -> int: - return self.get_max_video_frame_tokens() * self.get_num_frames_with_most_features(seq_len) - + return self.get_max_video_frame_tokens( + ) * self.get_num_frames_with_most_features(seq_len) + def get_max_audio_tokens(self) -> int: - return self.get_max_audio_tokens_per_chunk() * self.get_max_audio_chunks_with_most_features() + return self.get_max_audio_tokens_per_chunk( + ) * self.get_max_audio_chunks_with_most_features() def get_slice_query_num(self) -> int: hf_config = self.get_hf_config() query_num = getattr(hf_config, "query_num", 64) return query_num - + def get_max_slice_num(self) -> int: hf_config = self.get_hf_config() max_slice_num = getattr(hf_config, "max_slice_num", 9) @@ -465,20 +460,25 @@ def get_max_slice_num(self) -> int: def get_sliced_grid(self, image_size, max_slice_num) -> Tuple[int, int]: if self.get_model_version() in [(2, 6), (2, "6O")]: - slice_grid = self.get_image_processor().get_sliced_grid(image_size, max_slice_num) + slice_grid = self.get_image_processor().get_sliced_grid( + image_size, max_slice_num) else: slice_grid = self.get_image_processor().get_sliced_grid(image_size) return slice_grid - def get_num_image_tokens(self, image_size: ImageSize, max_slice_num: int) -> int: + def get_num_image_tokens(self, image_size: ImageSize, + max_slice_num: int) -> int: slice_grid = self.get_sliced_grid(image_size, max_slice_num) - num_tokens = self.get_slice_query_num() + 2 # ( * query_num) + num_tokens = self.get_slice_query_num( + ) + 2 # ( * query_num) if slice_grid is not None: if self.get_model_version() in [(2, 6), (2, "6O")]: - num_additional_tokens = 0 # ( * query_num) + num_additional_tokens = 0 else: - num_additional_tokens = 2 # ( * query_num) - num_tokens += (self.get_slice_query_num() + 2) * slice_grid[0] * slice_grid[1] \ + # ( * query_num) + num_additional_tokens = 2 + num_tokens += ((self.get_slice_query_num() + 2) \ + * slice_grid[0] * slice_grid[1]) \ + slice_grid[1] - 1 + num_additional_tokens return num_tokens @@ -489,40 +489,39 @@ def get_image_slice_nums(self, image_size: torch.Tensor, max_slice_nums): def get_max_image_tokens(self) -> int: image_size = self.get_image_size_with_most_features() return self.get_num_image_tokens(image_size, self.get_max_slice_num()) - + def get_image_size_with_most_features(self) -> ImageSize: - # Result in the max possible feature size (h:w = 9:1) - return self.get_defaul_image_sizes( - self.get_max_slice_num() - ) + # Result in the max possible feature size (h:w = 9:1) + return self.get_defaul_image_sizes(self.get_max_slice_num()) def get_video_max_slice_num(self) -> int: return 1 def get_video_frame_size_with_most_features(self) -> ImageSize: - return self.get_defaul_image_sizes( - self.get_video_max_slice_num() - ) + return self.get_defaul_image_sizes(self.get_video_max_slice_num()) def get_max_video_frames(self, max_tokens: int) -> int: num_frame_tokens = self.get_max_video_frame_tokens() num_frames = max_tokens // num_frame_tokens return num_frames - + def get_num_frames_with_most_features(self, seq_len: int) -> int: mm_config = self.ctx.get_mm_config() max_images = mm_config.limit_per_prompt.get("image", 1) max_videos = mm_config.limit_per_prompt.get("video", 1) max_audios = mm_config.limit_per_prompt.get("audio", 1) - # count tokens which are not in get_max_image_tokens - max_image_tokens = self.get_max_image_tokens() * max_images + 4 * max_images + # count tokens + # which are not in get_max_image_tokens + max_image_tokens = self.get_max_image_tokens( + ) * max_images + 4 * max_images seq_len = seq_len - max_image_tokens if "audio" in self.get_supported_mm_modalities(): - max_audio_tokens = self.get_max_audio_tokens() * max_audios + 2 * max_audios + max_audio_tokens = self.get_max_audio_tokens( + ) * max_audios + 2 * max_audios seq_len = seq_len - max_audio_tokens max_total_frames = self.get_max_video_frames(seq_len) - + num_frames = max(max_total_frames // max(max_videos, 1), 1) return num_frames @@ -533,19 +532,19 @@ def get_defaul_image_sizes(self, num_slices: int) -> ImageSize: def get_default_audio_pool_step(self): return 2 - + def get_default_audio_sampling_rate(self): return 16000 def get_chunk_length(self) -> int: return self.get_hf_config().audio_chunk_length - + def get_max_audio_tokens_per_chunk(self) -> int: pool_step = self.get_default_audio_pool_step() fbank_feat_in_chunk = 100 cnn_feat_in_chunk = (fbank_feat_in_chunk - 1) // 2 + 1 num_audio_tokens = (cnn_feat_in_chunk - pool_step) // pool_step + 1 - return num_audio_tokens + 2 # + return num_audio_tokens + 2 # def get_max_audio_chunks_with_most_features(self) -> int: return 30 @@ -557,12 +556,12 @@ def get_audio_len_by_num_chunks(self, num_chunks: int) -> int: return int(num_chunks * sampling_rate / num_tokens_per_chunk) + 1 -class MiniCPMVDummyInputsBuilder(BaseDummyInputsBuilder[MiniCPMVProcessingInfo]): +class MiniCPMVDummyInputsBuilder(BaseDummyInputsBuilder[MiniCPMVProcessingInfo] + ): + def get_dummy_processor_inputs( - self, - seq_len: int, - mm_counts: Mapping[str, int] - ) -> ProcessorInputs: + self, seq_len: int, mm_counts: Mapping[str, + int]) -> ProcessorInputs: num_images = mm_counts.get("image", 0) num_videos = mm_counts.get("video", 0) num_audios = mm_counts.get("audio", 0) @@ -578,41 +577,36 @@ def get_dummy_processor_inputs( self.info.get_default_audio_sampling_rate() mm_data = { - "image": self._get_dummy_images( - width=image_width, - height=image_height, - num_images=num_images - ), - "video": [self._get_dummy_images( - width=video_width, - height=video_height, - num_images=num_video_frames - )] * num_videos, - "audio": self._get_dummy_audios( - length=audio_len, - num_audios=num_audios - ) + "image": + self._get_dummy_images(width=image_width, + height=image_height, + num_images=num_images), + "video": [ + self._get_dummy_images(width=video_width, + height=video_height, + num_images=num_video_frames) + ] * num_videos, + "audio": + self._get_dummy_audios(length=audio_len, num_audios=num_audios) } - + image_prompt_texts = self.info.image_pattern * num_images video_prompt_texts = self.info.video_pattern * num_videos audio_prompt_texts = self.info.audio_pattern * num_audios - return ProcessorInputs( - prompt_text=image_prompt_texts + video_prompt_texts + audio_prompt_texts, - mm_data=mm_data - ) + return ProcessorInputs(prompt_text=image_prompt_texts + + video_prompt_texts + audio_prompt_texts, + mm_data=mm_data) + +class MiniCPMVMultiModalProcessor( + BaseMultiModalProcessor[MiniCPMVProcessingInfo]): -class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[MiniCPMVProcessingInfo]): def _get_data_parser(self) -> MultiModalDataParser: return MiniCPMVMultiModalDataParser() - def get_slice_image_placeholder( - self, - image_size: ImageSize, - **kwargs - ) -> str: + def get_slice_image_placeholder(self, image_size: ImageSize, + **kwargs) -> str: image_processor = self.info.get_image_processor() version = self.info.get_model_version() if version == (2, 0) or version == (2, 5): @@ -620,30 +614,30 @@ def get_slice_image_placeholder( return image_processor.get_slice_image_placeholder( image_size, **kwargs) - def get_image_prompt_texts(self, image_size: ImageSize, image_idx: int = 0) -> str: - prompt_texts = self.get_slice_image_placeholder( - image_size, - image_idx=image_idx - ) + def get_image_prompt_texts(self, + image_size: ImageSize, + image_idx: int = 0) -> str: + prompt_texts = self.get_slice_image_placeholder(image_size, + image_idx=image_idx) return prompt_texts - def get_video_prompt_texts(self, image_size: ImageSize, num_frames: int) -> str: + def get_video_prompt_texts(self, image_size: ImageSize, + num_frames: int) -> str: prompt_texts = "".join([ self.get_slice_image_placeholder( - image_size=image_size, - image_idx=0, + image_size=image_size, + image_idx=0, max_slice_nums=self.info.get_video_max_slice_num(), - use_image_id=False - ) for image_idx in range(num_frames) + use_image_id=False) for image_idx in range(num_frames) ]) return prompt_texts - - def get_audio_prompt_texts(self, audio_lens: int, chunk_input: bool = True, chunk_length: int = 1): + + def get_audio_prompt_texts(self, + audio_lens: int, + chunk_input: bool = True, + chunk_length: int = 1): return self.info.get_hf_processor().get_audio_placeholder( - audio_lens, - chunk_input, - chunk_length - ) + audio_lens, chunk_input, chunk_length) def get_special_tokens(self): tokenizer = self.info.get_tokenizer() @@ -652,11 +646,15 @@ def get_special_tokens(self): "im_end_id": torch.tensor(tokenizer.im_end_id) } if hasattr(tokenizer, "slice_start_id"): - special_tokens["slice_start_id"] = torch.tensor(tokenizer.slice_start_id) - special_tokens["slice_end_id"] = torch.tensor(tokenizer.slice_end_id) + special_tokens["slice_start_id"] = torch.tensor( + tokenizer.slice_start_id) + special_tokens["slice_end_id"] = torch.tensor( + tokenizer.slice_end_id) if hasattr(tokenizer, "audio_start_id"): - special_tokens["audio_start_id"] = torch.tensor(tokenizer.audio_start_id) - special_tokens["audio_end_id"] = torch.tensor(tokenizer.audio_end_id) + special_tokens["audio_start_id"] = torch.tensor( + tokenizer.audio_start_id) + special_tokens["audio_end_id"] = torch.tensor( + tokenizer.audio_end_id) return special_tokens @staticmethod @@ -665,20 +663,17 @@ def repack_processor_outputs(outputs: Any) -> BatchFeature: outputs = {key: outputs[key][0] for key in valid_keys} return outputs - def process_images( - self, - mm_data: Mapping[str, object], - mm_kwargs: Mapping[str, object] - ) -> Dict[str, object]: + def process_images(self, mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object]) -> Dict[str, object]: images = mm_data.pop("images", []) image_embeds = mm_data.pop("image_embeds", []) if isinstance(images, (list, torch.Tensor)) and len(images) > 0: image_outputs = super()._call_hf_processor( prompt=self.info.image_pattern * len(images), mm_data={"images": images}, - mm_kwargs=mm_kwargs - ) - image_outputs = MiniCPMVMultiModalProcessor.repack_processor_outputs(image_outputs) + mm_kwargs=mm_kwargs) + image_outputs = MiniCPMVMultiModalProcessor.\ + repack_processor_outputs(image_outputs) elif len(image_embeds) > 0: image_sizes = mm_data.pop("image_sizes", None) image_outputs = { @@ -688,12 +683,9 @@ def process_images( else: image_outputs = {} return image_outputs - - def process_videos( - self, - mm_data: Mapping[str, object], - mm_kwargs: Mapping[str, object] - ): + + def process_videos(self, mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object]): videos = mm_data.pop("videos", []) video_embeds = mm_data.pop("video_embeds", []) if len(videos) > 0 and isinstance(videos[0], Image.Image): @@ -710,17 +702,18 @@ def process_videos( prompt=self.info.image_pattern * len(video), mm_data={"images": video}, mm_kwargs={ - **mm_kwargs, - "max_slice_nums": self.info.get_video_max_slice_num() - } - ) + **mm_kwargs, "max_slice_nums": + self.info.get_video_max_slice_num() + }) video_outputs["num_frames"].append(len(video)) for key in single_video_outputs: if "video_" + key in video_outputs: if key == "image_sizes": - video_outputs["video_" + key].append(single_video_outputs[key][0][0]) + video_outputs["video_" + key].append( + single_video_outputs[key][0][0]) else: - video_outputs["video_" + key] += single_video_outputs[key][0] + video_outputs["video_" + + key] += single_video_outputs[key][0] elif len(video_embeds): image_sizes = mm_data.pop("image_sizes", None) num_frames = mm_data.pop("num_frames", None) @@ -733,11 +726,8 @@ def process_videos( video_outputs = {} return video_outputs - def process_audios( - self, - mm_data: Mapping[str, object], - mm_kwargs: Mapping[str, object] - ): + def process_audios(self, mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object]): audios = mm_data.pop("audios", []) audio_embeds = mm_data.pop("audio_embeds", []) if isinstance(audios, (list, torch.Tensor)) and len(audios) > 0: @@ -751,33 +741,31 @@ def process_audios( single_audio_outputs = super()._call_hf_processor( prompt=self.info.audio_pattern, mm_data={ - "audios": audio, + "audios": audio, "chunk_input": True }, - mm_kwargs=mm_kwargs - ) + mm_kwargs=mm_kwargs) audio_outputs["audio_lens"].append(len(audio)) audio_outputs["audio_features"].append( - single_audio_outputs["audio_features"] - ) + single_audio_outputs["audio_features"]) audio_outputs["audio_num_segments"].append( - len(single_audio_outputs["audio_feature_lens"][0]) - ) + len(single_audio_outputs["audio_feature_lens"][0])) audio_outputs["audio_feature_lens"] += \ single_audio_outputs["audio_feature_lens"] - audio_outputs["audio_features"] = torch.cat(audio_outputs["audio_features"]) - audio_outputs["audio_feature_lens"] = torch.cat(audio_outputs["audio_feature_lens"]) + audio_outputs["audio_features"] = torch.cat( + audio_outputs["audio_features"]) + audio_outputs["audio_feature_lens"] = torch.cat( + audio_outputs["audio_feature_lens"]) elif len(audio_embeds): audio_outputs = { "audio_lens": [ self.info.get_audio_len_by_num_chunks( - sum(chunk_embeds.shape[0] for chunk_embeds in single_audio_embeds) - ) + sum(chunk_embeds.shape[0] + for chunk_embeds in single_audio_embeds)) for single_audio_embeds in audio_embeds ], "audio_embeds": [ - chunk_embeds - for single_audio_embeds in audio_embeds + chunk_embeds for single_audio_embeds in audio_embeds for chunk_embeds in single_audio_embeds ], "audio_num_segments": [ @@ -801,11 +789,7 @@ def _call_hf_processor( image_outputs = self.process_images(mm_data, mm_kwargs) video_outputs = self.process_videos(mm_data, mm_kwargs) audio_outputs = self.process_audios(mm_data, mm_kwargs) - counts = { - "image": 0, - "video": 0, - "audio": 0 - } + counts = {"image": 0, "video": 0, "audio": 0} num_image_slices = [] num_video_slices = [] num_audio_slices = [] @@ -813,37 +797,34 @@ def _call_hf_processor( image_orders_in_mm_data = [] audio_orders_in_mm_data = [] matches = re.findall(r"\(<(image|video|audio)>./\)", prompt) - chunks = re.split(r"\(<(?:image|video|audio)>./\)", prompt) + chunks = re.split( + r"\(<(?:image|video|audio)>./\)", prompt) new_prompt = chunks[0] for idx, item in enumerate(matches): if item == "image": image_orders_in_mm_data.append(idx) - num_image_slices.append(self.info.get_image_slice_nums( - image_outputs["image_sizes"][counts[item]], - self.info.get_max_slice_num() - )) + num_image_slices.append( + self.info.get_image_slice_nums( + image_outputs["image_sizes"][counts[item]], + self.info.get_max_slice_num())) new_prompt += self.get_image_prompt_texts( - image_outputs["image_sizes"][counts[item]], - counts[item] - ) + image_outputs["image_sizes"][counts[item]], counts[item]) elif item == "video": video_orders_in_mm_data.append(idx) - num_video_slices.append(self.info.get_image_slice_nums( - video_outputs["video_image_sizes"][counts[item]], - self.info.get_video_max_slice_num() - ) * video_outputs["num_frames"][counts[item]]) + num_video_slices.append( + self.info.get_image_slice_nums( + video_outputs["video_image_sizes"][counts[item]], + self.info.get_video_max_slice_num()) * + video_outputs["num_frames"][counts[item]]) new_prompt += self.get_video_prompt_texts( video_outputs["video_image_sizes"][counts[item]], - video_outputs["num_frames"][counts[item]] - ) - else: # audio + video_outputs["num_frames"][counts[item]]) + else: # audio audio_orders_in_mm_data.append(idx) num_audio_slices.append( - audio_outputs["audio_num_segments"][counts[item]] - ) + audio_outputs["audio_num_segments"][counts[item]]) new_prompt += self.get_audio_prompt_texts( - audio_outputs["audio_lens"][counts[item]] - ) + audio_outputs["audio_lens"][counts[item]]) counts[item] += 1 new_prompt += chunks[idx + 1] @@ -852,10 +833,8 @@ def _call_hf_processor( def get_slices(num_slices: List[int]): slice_idices = [0] + list(accumulate(num_slices)) - slices = [ - (slice_idices[i], slice_idices[i + 1]) - for i in range(len(num_slices)) - ] + slices = [(slice_idices[i], slice_idices[i + 1]) + for i in range(len(num_slices))] return slices return { @@ -872,11 +851,9 @@ def get_slices(num_slices: List[int]): } def _get_prompt_replacements( - self, - mm_items: MultiModalDataItems, - hf_processor_mm_kwargs: Mapping[str, Any], - out_mm_kwargs: MultiModalKwargs - ) -> List[PromptReplacement]: + self, mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, Any], + out_mm_kwargs: MultiModalKwargs) -> List[PromptReplacement]: placeholder = { "image": self.info.image_pattern, "video": self.info.video_pattern, @@ -886,32 +863,27 @@ def _get_prompt_replacements( def get_replacement_minicpmv(item_idx: int, modality: str): if modality == "image": return self.get_image_prompt_texts( - mm_items["image"].get_image_size(item_idx), - item_idx - ) + mm_items["image"].get_image_size(item_idx), item_idx) elif modality == "video": return self.get_video_prompt_texts( mm_items["video"].get_frame_size(item_idx), - mm_items["video"].get_num_frames(item_idx) - ) - else: # audio + mm_items["video"].get_num_frames(item_idx)) + else: # audio if isinstance(mm_items["audio"], MiniCPMVAudioEmbeddingItems): single_audio_embeds = mm_items["audio"].get(item_idx) audio_len = self.info.get_audio_len_by_num_chunks( - sum(chunk_embeds.shape[0] for chunk_embeds in single_audio_embeds) - ) + sum(chunk_embeds.shape[0] + for chunk_embeds in single_audio_embeds)) return self.get_audio_prompt_texts(audio_len) return self.get_audio_prompt_texts( - len(mm_items["audio"].get(item_idx)) - ) - + len(mm_items["audio"].get(item_idx))) + return [ - PromptReplacement( - modality=modality, - target=placeholder[modality], - replacement=partial(get_replacement_minicpmv, - modality=modality) - ) for modality in ("image", "video", "audio") + PromptReplacement(modality=modality, + target=placeholder[modality], + replacement=partial(get_replacement_minicpmv, + modality=modality)) + for modality in ("image", "video", "audio") ] def _get_mm_fields_config( @@ -919,11 +891,16 @@ def _get_mm_fields_config( hf_inputs, hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: + def get_slices(slices_indices: List[int]): return [slice(*slice_item) for slice_item in slices_indices] - image_slices = get_slices(hf_inputs.get("image_slices", torch.empty(0, 2))) - video_slices = get_slices(hf_inputs.get("video_slices", torch.empty(0, 2))) - audio_slices = get_slices(hf_inputs.get("audio_slices", torch.empty(0, 2))) + + image_slices = get_slices( + hf_inputs.get("image_slices", torch.empty(0, 2))) + video_slices = get_slices( + hf_inputs.get("video_slices", torch.empty(0, 2))) + audio_slices = get_slices( + hf_inputs.get("audio_slices", torch.empty(0, 2))) return dict( pixel_values=MultiModalFieldConfig.flat("image", image_slices), image_sizes=MultiModalFieldConfig.batched("image"), @@ -931,18 +908,19 @@ def get_slices(slices_indices: List[int]): image_slices=MultiModalFieldConfig.batched("image"), image_orders_in_mm_data=MultiModalFieldConfig.batched("image"), image_embeds=MultiModalFieldConfig.flat("image", image_slices), - video_pixel_values=MultiModalFieldConfig.flat("video", video_slices), + video_pixel_values=MultiModalFieldConfig.flat( + "video", video_slices), video_image_sizes=MultiModalFieldConfig.batched("video"), video_tgt_sizes=MultiModalFieldConfig.flat("video", video_slices), video_orders_in_mm_data=MultiModalFieldConfig.batched("video"), video_embeds=MultiModalFieldConfig.flat("video", video_slices), video_slices=MultiModalFieldConfig.batched("video"), audio_features=MultiModalFieldConfig.flat("audio", audio_slices), - audio_feature_lens=MultiModalFieldConfig.flat("audio", audio_slices), + audio_feature_lens=MultiModalFieldConfig.flat( + "audio", audio_slices), audio_slices=MultiModalFieldConfig.batched("audio"), audio_orders_in_mm_data=MultiModalFieldConfig.batched("audio"), - audio_embeds=MultiModalFieldConfig.flat("audio", audio_slices) - ) + audio_embeds=MultiModalFieldConfig.flat("audio", audio_slices)) def apply( self, @@ -956,15 +934,11 @@ def apply( if "image" in result["mm_placeholders"] and \ self.info.get_model_version() in [(2, 6), (2, "6O")]: result["mm_placeholders"]["image"] = [ - PlaceholderRange( - offset=p["offset"] + 3 + idx // 10, - length=p["length"] - 3 - idx // 10 - ) + PlaceholderRange(offset=p["offset"] + 3 + idx // 10, + length=p["length"] - 3 - idx // 10) for idx, p in enumerate(result["mm_placeholders"]["image"]) ] - result["mm_kwargs"].update( - **self.get_special_tokens() - ) + result["mm_kwargs"].update(**self.get_special_tokens()) return result @@ -1046,12 +1020,9 @@ def get_embedding_with_vision( return vlm_embedding, vision_hidden_states - def get_embedding_with_audios( - self, - vlm_embedding: torch.Tensor, - audio_inputs: Optional[MiniCPMVAudioInputs], - chunk_length: int - ) -> torch.Tensor: + def get_embedding_with_audios(self, vlm_embedding: torch.Tensor, + audio_inputs: Optional[MiniCPMVAudioInputs], + chunk_length: int) -> torch.Tensor: device, dtype = vlm_embedding.device, vlm_embedding.dtype if audio_inputs["type"] == "audio_embeds": audio_embeddings = audio_inputs["data"] @@ -1060,32 +1031,31 @@ def get_embedding_with_audios( for i in range(len(audio_embeddings)) ] else: - audio_embeddings = self.get_audio_hidden_states(audio_inputs, chunk_length)[0] + audio_embeddings = self.get_audio_hidden_states( + audio_inputs, chunk_length)[0] if audio_embeddings is None or len(audio_embeddings) == 0: return vlm_embedding audio_bounds = audio_inputs["audio_bounds"] if self.config.chunk_input: - audio_embs = torch.cat(audio_embeddings, dim=0).to( - device=device, dtype=dtype - ) + audio_embs = torch.cat(audio_embeddings, dim=0).to(device=device, + dtype=dtype) audio_start_pos = 0 for bound in audio_bounds: audio_len = bound[1] - bound[0] - vlm_embedding[bound[0] : bound[1]] = audio_embs[ - audio_start_pos : audio_start_pos + audio_len, : - ] + vlm_embedding[bound[0]:bound[1]] = audio_embs[ + audio_start_pos:audio_start_pos + audio_len, :] audio_start_pos += audio_len else: for embs, bound in zip(audio_embeddings, audio_bounds): - audio_indices = torch.arange(bound[0], bound[1], dtype=torch.long).to( - device - ) + audio_indices = torch.arange(bound[0], + bound[1], + dtype=torch.long).to(device) if embs.shape[0] != len(audio_indices): raise ValueError( - f"Shape mismatch: Trying to assign embeddings of shape {embs.shape} " - f"to input indices of length {len(audio_indices)}" - ) + "Shape mismatch: Trying to assign embeddings " + f"of shape {embs.shape} " + f"to input indices of length {len(audio_indices)}") vlm_embedding[audio_indices] = embs.to(dtype) return vlm_embedding @@ -1117,20 +1087,17 @@ def _get_image_bounds( image_end_tokens[:valid_image_nums].unsqueeze(-1), ]) - def _get_audio_bounds( - self, - input_ids: torch.Tensor, - audio_start_id: torch.Tensor, - audio_end_id: torch.Tensor - ) -> torch.Tensor: + def _get_audio_bounds(self, input_ids: torch.Tensor, + audio_start_id: torch.Tensor, + audio_end_id: torch.Tensor) -> torch.Tensor: audio_start_tokens, = torch.where(input_ids == audio_start_id[0]) audio_start_tokens += 1 audio_end_tokens, = torch.where(input_ids == audio_end_id[0]) valid_audio_nums = max(len(audio_start_tokens), len(audio_end_tokens)) return torch.hstack([ - audio_start_tokens[:valid_audio_nums].unsqueeze(-1), + audio_start_tokens[:valid_audio_nums].unsqueeze(-1), audio_end_tokens[:valid_audio_nums].unsqueeze(-1) - ]) + ]) def _parse_and_validate_image_inputs( self, @@ -1139,9 +1106,8 @@ def _parse_and_validate_image_inputs( ) -> Optional[MiniCPMVImageInputs]: mm_data = { "image": { - key: kwargs.pop(key, []) for key in [ - "pixel_values", "tgt_sizes", "image_slices" - ] + key: kwargs.pop(key, []) + for key in ["pixel_values", "tgt_sizes", "image_slices"] }, "video": { "pixel_values": kwargs.pop("video_pixel_values", []), @@ -1157,20 +1123,22 @@ def _parse_and_validate_image_inputs( modality: kwargs.pop(f"{modality}_orders_in_mm_data", None) for modality in ["image", "video", "audio"] } - batch_size = max(len(mm_data["image"]["pixel_values"]), + batch_size = max(len(mm_data["image"]["pixel_values"]), len(mm_data["video"]["pixel_values"])) image_embeds = kwargs.pop("image_embeds", None) video_embeds = kwargs.pop("video_embeds", None) if image_embeds is not None and video_embeds is not None: - raise ValueError("Incorrect inputs for vision embeddings. " - "Image embeds and video embeds can not exist simultaneously.") + raise ValueError( + "Incorrect inputs for vision embeddings. " + "Image embeds and video embeds can not exist simultaneously.") if video_embeds is not None: image_embeds = video_embeds if image_embeds is not None: if not isinstance(image_embeds, (torch.Tensor, list)): raise ValueError(f"Incorrect type of image embeds. " f"Got type: {type(image_embeds)}") - image_embeds = torch.concat([image_embeds[i] for i in range((len(image_embeds)))]) + image_embeds = torch.concat( + [image_embeds[i] for i in range(len(image_embeds))]) return MiniCPMVImageEmbeddingInputs( image_bounds=self._get_image_bounds(input_ids, im_start_id, @@ -1181,17 +1149,24 @@ def _parse_and_validate_image_inputs( ) for modality, modality_mm_data in mm_data.items(): - if not isinstance(modality_mm_data["pixel_values"], (torch.Tensor, list)): - raise ValueError("Incorrect type of pixel values. " - f"Got type: {type(modality_mm_data['pixel_values'])}") - - if not isinstance(modality_mm_data["tgt_sizes"], (torch.Tensor, list)): - raise ValueError("Incorrect type of target sizes. " - f"Got type: {type(modality_mm_data['tgt_sizes'])}") - - if len(modality_mm_data["pixel_values"]) != len(modality_mm_data["tgt_sizes"]): - raise ValueError("Inconsistent batch lengths, found: " - f"{len(modality_mm_data['pixel_values'])} vs. {len(modality_mm_data['tgt_sizes'])}") + if not isinstance(modality_mm_data["pixel_values"], + (torch.Tensor, list)): + raise ValueError( + "Incorrect type of pixel values. " + f"Got type: {type(modality_mm_data['pixel_values'])}") + + if not isinstance(modality_mm_data["tgt_sizes"], + (torch.Tensor, list)): + raise ValueError( + "Incorrect type of target sizes. " + f"Got type: {type(modality_mm_data['tgt_sizes'])}") + + if len(modality_mm_data["pixel_values"]) != len( + modality_mm_data["tgt_sizes"]): + raise ValueError( + "Inconsistent batch lengths, found: " + f"{len(modality_mm_data['pixel_values'])} vs. " + f"{len(modality_mm_data['tgt_sizes'])}") pixel_values_flat: List[torch.Tensor] = [] tgt_sizes_flat: List[torch.Tensor] = [] @@ -1205,19 +1180,21 @@ def _parse_and_validate_image_inputs( for media_type in ["image", "video", "audio"] for pos, index in enumerate(orders_in_mm_data_b[media_type]) ] - mm_data_indices = [ - (pos, modality) for index, (pos, modality) in - sorted(mm_data_indices, key=lambda x: x[0]) - ] + mm_data_indices = [(pos, modality) for index, ( + pos, modality) in sorted(mm_data_indices, key=lambda x: x[0])] for pos, modality in mm_data_indices: if modality == "image": slice_index = mm_data[modality]["image_slices"][b][pos] - pixel_values_flat += mm_data[modality]["pixel_values"][b][slice_index[0]:slice_index[1]] - tgt_sizes_flat += mm_data[modality]["tgt_sizes"][b][slice_index[0]:slice_index[1]] + pixel_values_flat += mm_data[modality]["pixel_values"][b][ + slice_index[0]:slice_index[1]] + tgt_sizes_flat += mm_data[modality]["tgt_sizes"][b][ + slice_index[0]:slice_index[1]] elif modality == "video": slice_index = mm_data[modality]["video_slices"][b][pos] - pixel_values_flat += mm_data[modality]["pixel_values"][b][slice_index[0]:slice_index[1]] - tgt_sizes_flat += mm_data[modality]["tgt_sizes"][b][slice_index[0]:slice_index[1]] + pixel_values_flat += mm_data[modality]["pixel_values"][b][ + slice_index[0]:slice_index[1]] + tgt_sizes_flat += mm_data[modality]["tgt_sizes"][b][ + slice_index[0]:slice_index[1]] # NOTE: Input IDs does not contain image tokens during memory profiling, # so we allow it to be empty @@ -1242,59 +1219,46 @@ def _parse_and_validate_image_inputs( ) def _parse_and_validate_audio_inputs( - self, - input_ids: torch.Tensor, - **kwargs: object - ) -> Tuple[MiniCPMVImageInputs]: + self, input_ids: torch.Tensor, + **kwargs: object) -> Tuple[MiniCPMVImageInputs]: audio_features = kwargs.pop("audio_features", []) audio_feature_lens = kwargs.pop("audio_feature_lens", []) audio_embeds = kwargs.pop("audio_embeds", None) audio_start_id = kwargs.pop("audio_start_id", None) audio_end_id = kwargs.pop("audio_end_id", None) if audio_embeds is not None: - audio_embeds = [audio_embeds[i][j] - for i in range(len(audio_embeds)) - for j in range(len(audio_embeds[i]))] + audio_embeds = [ + audio_embeds[i][j] for i in range(len(audio_embeds)) + for j in range(len(audio_embeds[i])) + ] return MiniCPMVAudioEmbeddingInputs( - audio_bounds=self._get_audio_bounds(input_ids, - audio_start_id, audio_end_id), + audio_bounds=self._get_audio_bounds(input_ids, audio_start_id, + audio_end_id), data=audio_embeds, - type="audio_embeds" - ) + type="audio_embeds") if len(audio_features) > 0: - audio_features = torch.cat([ - item for item in audio_features - ]) - audio_feature_lens = torch.cat([ - item for item in audio_feature_lens - ]) + audio_features = torch.cat([item for item in audio_features]) + audio_feature_lens = torch.cat( + [item for item in audio_feature_lens]) return MiniCPMVAudioFeatureInputs( - audio_bounds=self._get_audio_bounds(input_ids, - audio_start_id, audio_end_id), + audio_bounds=self._get_audio_bounds(input_ids, audio_start_id, + audio_end_id), data=audio_features, audio_feature_lens=audio_feature_lens, - type="audio_features" - ) + type="audio_features") return None - - def _parse_and_validate_inputs( - self, - input_ids: torch.Tensor, - **kwargs: object - ): + + def _parse_and_validate_inputs(self, input_ids: torch.Tensor, + **kwargs: object): image_inputs = self._parse_and_validate_image_inputs( - input_ids, - **kwargs - ) - if not any("audio" in key for key in kwargs.keys()): + input_ids, **kwargs) + if not any("audio" in key for key in kwargs): return image_inputs, None audio_inputs = self._parse_and_validate_audio_inputs( - input_ids, - **kwargs - ) + input_ids, **kwargs) return image_inputs, audio_inputs - + def forward( self, input_ids: torch.Tensor, @@ -1309,14 +1273,13 @@ def forward( else: image_inputs, audio_inputs = \ self._parse_and_validate_inputs(input_ids, **kwargs) - vlm_embeddings, _ = self.get_embedding_with_vision(input_ids, image_inputs) + vlm_embeddings, _ = self.get_embedding_with_vision( + input_ids, image_inputs) if audio_inputs is not None: vlm_embeddings = self.get_embedding_with_audios( - vlm_embeddings, - audio_inputs, - self.config.audio_chunk_length - ) + vlm_embeddings, audio_inputs, + self.config.audio_chunk_length) # always pass the input via `inputs_embeds` # to make sure the computation graph is consistent @@ -1395,8 +1358,7 @@ def get_vision_hidden_states(self, data: MiniCPMVImageInputs) -> torch.Tensor: raise NotImplementedError - def get_audio_hidden_states(self, - data: MiniCPMVAudioInputs, + def get_audio_hidden_states(self, data: MiniCPMVAudioInputs, chunk_length: int) -> torch.Tensor: raise NotImplementedError @@ -1725,11 +1687,16 @@ def get_vision_hidden_states(self, class MultiModalProjector(nn.Module): + def __init__(self, in_dim, out_dim): super().__init__() - self.linear1 = nn.Linear(in_features=in_dim, out_features=out_dim, bias=True) + self.linear1 = nn.Linear(in_features=in_dim, + out_features=out_dim, + bias=True) self.relu = nn.ReLU() - self.linear2 = nn.Linear(in_features=out_dim, out_features=out_dim, bias=True) + self.linear2 = nn.Linear(in_features=out_dim, + out_features=out_dim, + bias=True) def forward(self, audio_features): hidden_states = self.relu(self.linear1(audio_features)) @@ -1737,18 +1704,19 @@ def forward(self, audio_features): return hidden_states -# Copied from transformers.models.whisper.modeling_whisper.WhisperEncoderLayer and add use_cache for streaming inference class MiniCPMWhisperEncoderLayer(nn.Module): + def __init__(self, config: WhisperConfig, layer_idx: int = None): super().__init__() self.embed_dim = config.d_model - self.self_attn = WHISPER_ATTENTION_CLASSES[config._attn_implementation]( - embed_dim=self.embed_dim, - num_heads=config.encoder_attention_heads, - dropout=config.attention_dropout, - config=config, - layer_idx=layer_idx, - ) + self.self_attn = WHISPER_ATTENTION_CLASSES[ + config._attn_implementation]( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + config=config, + layer_idx=layer_idx, + ) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] @@ -1766,24 +1734,6 @@ def forward( past_key_values: Optional[EncoderDecoderCache] = None, use_cache: Optional[bool] = False, ) -> torch.Tensor: - r""" - Args: - hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, embed_dim)`): - Hidden states to be fed into the encoder layer. - attention_mask (`torch.FloatTensor` of shape `(batch_size, 1, tgt_len, src_len)`): - Attention mask where padding elements are indicated by large negative values. - layer_head_mask (`torch.FloatTensor` of shape `(encoder_attention_heads,)`): - Mask to nullify selected heads of the attention modules. - output_attentions (`bool`, *optional*): - Whether or not to return the attention weights. - past_key_values (`EncoderDecoderCache`, *optional*): - Past key-value pairs used for incremental decoding. - use_cache (`bool`, *optional*): - Whether or not to return updated `past_key_values` for caching. - - Returns: - A tuple of shape `(hidden_states, optional(attn_weights), optional(past_key_values))`. - """ residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) hidden_states, attn_weights, past_key_values = self.self_attn( @@ -1793,42 +1743,50 @@ def forward( output_attentions=output_attentions, past_key_value=past_key_values, ) - hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = nn.functional.dropout(hidden_states, + p=self.dropout, + training=self.training) hidden_states = residual + hidden_states residual = hidden_states hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.activation_fn(self.fc1(hidden_states)) - hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = nn.functional.dropout(hidden_states, + p=self.activation_dropout, + training=self.training) hidden_states = self.fc2(hidden_states) - hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = nn.functional.dropout(hidden_states, + p=self.dropout, + training=self.training) hidden_states = residual + hidden_states if hidden_states.dtype == torch.float16 and ( - torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() - ): + torch.isinf(hidden_states).any() + or torch.isnan(hidden_states).any()): clamp_value = torch.finfo(hidden_states.dtype).max - 1000 - hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + hidden_states = torch.clamp(hidden_states, + min=-clamp_value, + max=clamp_value) - outputs = (hidden_states,) + outputs = (hidden_states, ) if output_attentions: - outputs += (attn_weights,) + outputs += (attn_weights, ) if use_cache: - outputs += (past_key_values,) + outputs += (past_key_values, ) return outputs - -# Copied from from transformers.models.whisper.modeling_whisper.WhisperEncoder and add use_cache for streaming inference + class MiniCPMWhisperEncoder(WhisperEncoder): def __init__(self, config: WhisperConfig): super().__init__(config) - self.layers = nn.ModuleList( - [MiniCPMWhisperEncoderLayer(config, layer_idx=i) for i in range(config.encoder_layers)] - ) + self.layers = nn.ModuleList([ + MiniCPMWhisperEncoderLayer(config, layer_idx=i) + for i in range(config.encoder_layers) + ]) def forward( self, @@ -1841,115 +1799,17 @@ def forward( past_key_values: Optional[EncoderDecoderCache] = None, use_cache: Optional[bool] = None, ): - r""" - Forward pass of the Whisper encoder. - - Args: - input_features (`torch.FloatTensor` of shape `(batch_size, feature_size, sequence_length)`): - Float values of log-mel features extracted from the raw audio waveform. Typically generated - by a feature extractor (e.g., `WhisperFeatureExtractor`) that processes `.flac` or `.wav` - files into padded 2D mel spectrogram frames. These features are projected via convolution layers - (`conv1` and `conv2`) and then transformed into embeddings for the encoder. - - attention_mask (`torch.Tensor`, *optional*): - Not used by Whisper for masking `input_features`, but included for API compatibility with - other models. If provided, it is simply ignored within the model. By default, Whisper - effectively ignores silence in the input log-mel spectrogram. - - head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): - Mask to nullify selected attention heads. The elements should be either 1 or 0, where: - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked** (i.e., the attention head is dropped). - - output_attentions (`bool`, *optional*): - Whether or not to return the attention tensors of all encoder layers. If set to `True`, the - returned tuple (or `BaseModelOutputWithPast`) will contain an additional element with - attention weights for each encoder layer. - - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. If set to `True`, the returned - tuple (or `BaseModelOutputWithPast`) will contain a tuple of hidden states, including the - initial embedding output as well as the outputs of each layer. - - return_dict (`bool`, *optional*): - Whether or not to return a `BaseModelOutputWithPast` (a subclass of `ModelOutput`) instead - of a plain tuple. If set to `True`, the output will be a `BaseModelOutputWithPast` object, - otherwise it will be a tuple. - - past_key_values (`EncoderDecoderCache`, *optional*): - When using caching for faster inference, this is an object that stores the key-value pairs - for attention states. If provided, the model will append new states to the existing cache - and return the updated cache. This speeds up sequential decoding or chunked inference. - - - If `past_key_values` is `None`, no past states are used or returned. - - If `past_key_values` is not `None` and `use_cache=True`, the model will use the provided - cache and return the updated cache (as `next_encoder_cache`). - - use_cache (`bool`, *optional*): - Whether or not the model should use caching (`past_key_values`) to speed up processing - during inference. When set to `True`, the model will: - - Inspect and use `past_key_values` if provided. - - Return updated `past_key_values` (under the name `next_encoder_cache` in - `BaseModelOutputWithPast`). - - Returns: - `BaseModelOutputWithPast` or `tuple` (depending on `return_dict`): - If `return_dict=True`, a `BaseModelOutputWithPast` is returned, which contains: - - **last_hidden_state** (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): - The output of the final encoder layer. - - **hidden_states** (`tuple(torch.FloatTensor)`, *optional*, returned if `output_hidden_states=True`): - Hidden states of the model at each layer (including the initial projection). - - **attentions** (`tuple(torch.FloatTensor)`, *optional*, returned if `output_attentions=True`): - Attention weights from each encoder layer. - - **past_key_values** (an object of type `EncoderDecoderCache` or `None`, *optional*): - Updated cache of key-value pairs if `use_cache=True`. - - If `return_dict=False`, a tuple is returned, where the format is: - `(last_hidden_state, hidden_states, attentions)`, with `hidden_states` and `attentions` - only present if their respective `output_*` arguments are set to `True`. - - Example: - >>> from transformers import AutoFeatureExtractor, WhisperConfig, WhisperForConditionalGeneration - >>> import torch - - >>> # Load a feature extractor and a Whisper model - >>> feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-tiny.en") - >>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en") - - >>> # Assume you have audio (list of floats or numpy array) loaded from a file - >>> # Then extract the mel features: - >>> input_features = feature_extractor(audio, sampling_rate=16000, return_tensors="pt").input_features - - >>> # Forward pass - >>> outputs = model.encoder( - ... input_features=input_features, - ... output_hidden_states=True, - ... output_attentions=True, - ... use_cache=True - ... ) - - >>> # Retrieve the last hidden state - >>> last_hidden_state = outputs.last_hidden_state - >>> print(last_hidden_state.shape) - torch.Size([batch_size, seq_length, hidden_size]) - - >>> # Retrieve the intermediate hidden states if output_hidden_states=True - >>> all_encoder_hidden_states = outputs.hidden_states - - >>> # Retrieve attention weights if output_attentions=True - >>> all_encoder_attentions = outputs.attentions - - >>> # Retrieve updated past key values if use_cache=True - >>> encoder_cache = outputs.past_key_values - """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_attentions = output_attentions if output_attentions is not None \ + else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else + self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None \ + else self.config.use_return_dict # Ignore copy - input_features = input_features.to(dtype=self.conv1.weight.dtype, device=self.conv1.weight.device) + input_features = input_features.to(dtype=self.conv1.weight.dtype, + device=self.conv1.weight.device) inputs_embeds = nn.functional.gelu(self.conv1(input_features)) inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds)) @@ -1960,48 +1820,57 @@ def forward( past_key_values_length = 0 if use_cache: if past_key_values is None: - past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + past_key_values = EncoderDecoderCache(DynamicCache(), + DynamicCache()) elif isinstance(past_key_values, list): - past_key_values = EncoderDecoderCache(DynamicCache.from_legacy_cache(past_key_values), DynamicCache()) + past_key_values = EncoderDecoderCache( + DynamicCache.from_legacy_cache(past_key_values), + DynamicCache()) elif isinstance(past_key_values, DynamicCache): - past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) + past_key_values = EncoderDecoderCache(past_key_values, + DynamicCache()) else: pass - past_key_values_length = past_key_values.self_attention_cache.get_usable_length(inputs_embeds.shape[1]) - if inputs_embeds.shape[1] + past_key_values_length > embed_pos.shape[0]: - # logger.warning("seems the audio is longer than 30s. repeating the last part of the audio") - embed_pos_front = embed_pos[past_key_values_length:, :] - embed_pos = torch.cat( - ( - embed_pos_front, - torch.repeat_interleave( - embed_pos[-1, :].unsqueeze(0), - inputs_embeds.shape[1] - embed_pos.shape[0] + past_key_values_length, - dim=0, - ), - ) + past_key_values_length = \ + past_key_values.self_attention_cache.get_usable_length( + inputs_embeds.shape[1] ) + if inputs_embeds.shape[ + 1] + past_key_values_length > embed_pos.shape[0]: + embed_pos_front = embed_pos[past_key_values_length:, :] + embed_pos = torch.cat(( + embed_pos_front, + torch.repeat_interleave( + embed_pos[-1, :].unsqueeze(0), + inputs_embeds.shape[1] - embed_pos.shape[0] + + past_key_values_length, + dim=0, + ), + )) else: - embed_pos = embed_pos[past_key_values_length : inputs_embeds.shape[1] + past_key_values_length, :] + embed_pos = embed_pos[ + past_key_values_length:inputs_embeds.shape[1] + + past_key_values_length, :] else: - embed_pos = embed_pos[: inputs_embeds.shape[1], :] + embed_pos = embed_pos[:inputs_embeds.shape[1], :] hidden_states = inputs_embeds + embed_pos - hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = nn.functional.dropout(hidden_states, + p=self.dropout, + training=self.training) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None - # check if head_mask has a correct number of layers specified if desired if head_mask is not None: assert head_mask.size()[0] == ( len(self.layers) - ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." + ), \ + f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." # noqa for idx, encoder_layer in enumerate(self.layers): if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) - # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + encoder_states = encoder_states + (hidden_states, ) to_drop = False if self.training: dropout_probability = torch.rand([]) @@ -2026,7 +1895,8 @@ def forward( layer_outputs = encoder_layer( hidden_states, attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), + layer_head_mask=(head_mask[idx] + if head_mask is not None else None), output_attentions=output_attentions, past_key_values=past_key_values, use_cache=use_cache, @@ -2035,19 +1905,22 @@ def forward( hidden_states = layer_outputs[0] if use_cache: - next_encoder_cache = layer_outputs[2 if output_attentions else 1] + next_encoder_cache = layer_outputs[ + 2 if output_attentions else 1] else: next_encoder_cache = None if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) + all_attentions = all_attentions + (layer_outputs[1], ) hidden_states = self.layer_norm(hidden_states) if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) + encoder_states = encoder_states + (hidden_states, ) if not return_dict: - return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return tuple( + v for v in [hidden_states, encoder_states, all_attentions] + if v is not None) return BaseModelOutputWithPast( last_hidden_state=hidden_states, hidden_states=encoder_states, @@ -2057,8 +1930,6 @@ def forward( class MiniCPMO2_6(MiniCPMV2_6): - # apm_mapper_hf_pattern = ["apm.layers.{}.fc1", "apm.layers.{}.fc2"] - # apm_mapper_vllm_pattern = ["apm.layers.{}.mlp.fc1", "apm.layers.{}.mlp.fc2"] packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -2070,33 +1941,28 @@ class MiniCPMO2_6(MiniCPMV2_6): "up_proj", ], } - - def __init__(self, *, vllm_config: VllmConfig, prefix: str=""): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__(vllm_config=vllm_config, prefix=prefix) self.apm = self.init_audio_module(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "apm")) - def init_audio_module( - self, - *, - vllm_config: VllmConfig, - prefix: str="" - ): + def init_audio_module(self, *, vllm_config: VllmConfig, prefix: str = ""): # Do not use parameters temporarily audio_config = self.config.audio_config model = MiniCPMWhisperEncoder(audio_config) audio_output_dim = int(audio_config.encoder_ffn_dim // 4) - setattr(self, "audio_avg_pooler", - nn.AvgPool1d(self.config.audio_pool_step, stride=self.config.audio_pool_step)) - setattr(self, "audio_projection_layer", - MultiModalProjector(in_dim=audio_output_dim, out_dim=self.embed_dim)) - setattr(self, "audio_encoder_layer", -1) + self.audio_avg_pooler = \ + nn.AvgPool1d(self.config.audio_pool_step, + stride=self.config.audio_pool_step) + self.audio_projection_layer = \ + MultiModalProjector(in_dim=audio_output_dim,out_dim=self.embed_dim) + self.audio_encoder_layer = -1 return model - + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: - loader = AutoWeightsLoader(self, - skip_prefixes=["tts"]) + loader = AutoWeightsLoader(self, skip_prefixes=["tts"]) return loader.load_weights(weights) def subsequent_chunk_mask( @@ -2104,75 +1970,41 @@ def subsequent_chunk_mask( size: int, chunk_size: int, num_left_chunks: int = -1, - device: torch.device = torch.device("cpu"), + device: torch.device = CPU_DEVICE, num_lookhead: int = 0, ) -> torch.Tensor: - """Create mask for subsequent steps (size, size) with chunk size, - this is for streaming encoder - - Args: - size (int): size of mask - chunk_size (int): size of chunk - num_left_chunks (int): number of left chunks - <0: use full chunk - >=0: use num_left_chunks - device (torch.device): "cpu" or "cuda" or torch.Tensor.device - - Returns: - torch.Tensor: mask - - Examples: - >>> subsequent_chunk_mask(4, 2) - [[1, 1, 0, 0], - [1, 1, 0, 0], - [1, 1, 1, 1], - [1, 1, 1, 1]] - """ ret = torch.zeros(size, size, device=device, dtype=torch.bool) for i in range(size): if num_left_chunks < 0: start = 0 else: - start = max((i // chunk_size - num_left_chunks) * chunk_size, 0) - ending = min((i // chunk_size + 1) * chunk_size + num_lookhead, size) + start = max((i // chunk_size - num_left_chunks) * chunk_size, + 0) + ending = min((i // chunk_size + 1) * chunk_size + num_lookhead, + size) ret[i, start:ending] = True return ret - def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): - """ - Computes the output length of the convolutional layers and the output length of the audio encoder - """ + def _get_feat_extract_output_lengths(self, + input_lengths: torch.LongTensor): input_lengths_after_cnn = (input_lengths - 1) // 2 + 1 input_lengths_after_pooling = ( - input_lengths_after_cnn - self.config.audio_pool_step - ) // self.config.audio_pool_step + 1 - input_lengths_after_pooling = input_lengths_after_pooling.to(dtype=torch.int32) + input_lengths_after_cnn - + self.config.audio_pool_step) // self.config.audio_pool_step + 1 + input_lengths_after_pooling = input_lengths_after_pooling.to( + dtype=torch.int32) return input_lengths_after_cnn, input_lengths_after_pooling - # Copied from HF repo of MiniCPM-o-2_6, designed for batched inputs and outputs - def get_audio_hidden_states(self, - data: MiniCPMVAudioInputs, + # Copied from HF repo of MiniCPM-o-2_6, + # designed for batched inputs and outputs + def get_audio_hidden_states(self, data: MiniCPMVAudioInputs, chunk_length: int) -> torch.Tensor: - r""" - Extract full audio embeddings with optional chunk-based attention. - - This method computes embeddings for all audio frames at once, either using full attention (when - `chunk_length` is -1) or chunk-based attention (when `chunk_length` is a positive number). It does - not use key-value caching and is suitable for non-streaming inference. - - Args: - data (dict): - - **"audio_features"** (`torch.FloatTensor`): Input mel-spectrograms of shape `(batch_size, 80, frames)`. - - **"audio_feature_lens"** (List[int]): Lengths of each audio segment for each item. - chunk_length (int, optional): Determines whether to use full attention (-1) or chunk-based - attention (>0) during embedding computation. - - Returns: - List[List[torch.Tensor]]: audio embeddings - """ - wavforms = data.get("data", []) # (bs, 80, frames) or [], multi audios need filled in advance - audio_feature_lens_raw = [data.get("audio_feature_lens", [])] # list, [[x1, x2], [y1], [z1]] + wavforms = data.get( + "data", + []) # (bs, 80, frames) or [], multi audios need filled in advance + audio_feature_lens_raw = [data.get("audio_feature_lens", + [])] # list, [[x1, x2], [y1], [z1]] # exist audio if len(wavforms) > 0: @@ -2181,21 +2013,23 @@ def get_audio_hidden_states(self, max_seq_len = (max_mel_seq_len - 1) // 2 + 1 # Create a sequence tensor of shape (batch_size, max_seq_len) - seq_range = ( - torch.arange(0, max_seq_len, dtype=audio_feature_lens.dtype, device=audio_feature_lens.device) - .unsqueeze(0) - .expand(batch_size, max_seq_len) - ) - lengths_expand = audio_feature_lens.unsqueeze(1).expand(batch_size, max_seq_len) + seq_range = (torch.arange( + 0, + max_seq_len, + dtype=audio_feature_lens.dtype, + device=audio_feature_lens.device).unsqueeze(0).expand( + batch_size, max_seq_len)) + lengths_expand = audio_feature_lens.unsqueeze(1).expand( + batch_size, max_seq_len) # Create mask padding_mask = seq_range >= lengths_expand # 1 for padded values - audio_attention_mask_ = padding_mask.view(batch_size, 1, 1, max_seq_len).expand( - batch_size, 1, max_seq_len, max_seq_len - ) + audio_attention_mask_ = padding_mask.view( + batch_size, 1, 1, max_seq_len).expand(batch_size, 1, + max_seq_len, max_seq_len) audio_attention_mask = audio_attention_mask_.to( - dtype=self.apm.conv1.weight.dtype, device=self.apm.conv1.weight.device - ) + dtype=self.apm.conv1.weight.dtype, + device=self.apm.conv1.weight.device) if chunk_length > 0: chunk_num_frame = int(chunk_length * 50) @@ -2205,19 +2039,23 @@ def get_audio_hidden_states(self, num_left_chunks=-1, device=audio_attention_mask_.device, ) - audio_attention_mask_ = torch.logical_or(audio_attention_mask_, torch.logical_not(chunk_mask)) + audio_attention_mask_ = torch.logical_or( + audio_attention_mask_, torch.logical_not(chunk_mask)) audio_attention_mask[audio_attention_mask_] = float("-inf") audio_states = self.apm( - wavforms, output_hidden_states=True, attention_mask=audio_attention_mask - ).hidden_states[self.audio_encoder_layer] + wavforms, + output_hidden_states=True, + attention_mask=audio_attention_mask).hidden_states[ + self.audio_encoder_layer] audio_embeds = self.audio_projection_layer(audio_states) audio_embeds = audio_embeds.transpose(1, 2) audio_embeds = self.audio_avg_pooler(audio_embeds) audio_embeds = audio_embeds.transpose(1, 2) - _, feature_lens_after_pooling = self._get_feat_extract_output_lengths(audio_feature_lens) + _, feature_lens_after_pooling = \ + self._get_feat_extract_output_lengths(audio_feature_lens) num_audio_tokens = feature_lens_after_pooling @@ -2226,13 +2064,14 @@ def get_audio_hidden_states(self, for i in range(len(audio_feature_lens_raw)): target_audio_embeds = [] for _ in range(len(audio_feature_lens_raw[i])): - target_audio_embeds.append(audio_embeds[idx, : num_audio_tokens[idx], :]) + target_audio_embeds.append( + audio_embeds[idx, :num_audio_tokens[idx], :]) idx += 1 final_audio_embeds.append(target_audio_embeds) return final_audio_embeds else: return [] - + _SUPPORT_VERSION = { (2, 0): MiniCPMV2_0, @@ -2244,7 +2083,8 @@ def get_audio_hidden_states(self, @MULTIMODAL_REGISTRY.register_processor(MiniCPMVMultiModalProcessor, info=MiniCPMVProcessingInfo, - dummy_inputs=MiniCPMVDummyInputsBuilder) + dummy_inputs=MiniCPMVDummyInputsBuilder + ) class MiniCPMV(MiniCPMVBaseModel, SupportsMultiModal, SupportsLoRA): """ Different versions of MiniCPMV use different visual encoders and LLMs,