From c2d8dbb89925322e889c70dfe5d9d925e305754b Mon Sep 17 00:00:00 2001 From: hzh Date: Wed, 22 Jan 2025 07:47:21 +0000 Subject: [PATCH] audio language Signed-off-by: hzh --- vllm/model_executor/models/minicpmv.py | 907 ++++++++++++++++++++++--- 1 file changed, 823 insertions(+), 84 deletions(-) diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 096320eacd0c7..e23f2fafae08a 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -31,8 +31,13 @@ from PIL import Image from torch import nn import numpy as np -from transformers import BatchFeature -from transformers import PretrainedConfig +from transformers import BatchFeature, PretrainedConfig +from transformers.modeling_outputs import BaseModelOutputWithPast +from transformers.cache_utils import EncoderDecoderCache, DynamicCache +from transformers.models.whisper.modeling_whisper import (ACT2FN, + WHISPER_ATTENTION_CLASSES, + WhisperConfig, + WhisperEncoder) from typing_extensions import NotRequired from itertools import accumulate @@ -43,6 +48,7 @@ get_2d_sincos_pos_embed) from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.model_loader.utils import set_default_torch_dtype +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.model_executor.models.minicpm import MiniCPMForCausalLM from vllm.model_executor.models.module_mapping import MultiModelKeys @@ -68,7 +74,7 @@ from .idefics2_vision_model import Idefics2VisionTransformer from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP -from .utils import AutoWeightsLoader, maybe_prefix +from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix RawImageType = Union[Image.Image, torch.Tensor] @@ -116,6 +122,29 @@ class MiniCPMVImageEmbeddingInputs(TypedDict): """ +class MiniCPMVAudioFeatureInputs(TypedDict): + type: Literal["audio_features"] + data: List[torch.Tensor] + """ + Shape: + """ + + audio_feature_lens: torch.Tensor + """ + Shape: + """ + + audio_bounds: torch.Tensor + """ + Shape: + """ + + +MiniCPMVImageInputs = Union[MiniCPMVImagePixelInputs, + MiniCPMVImageEmbeddingInputs] +MiniCPMVAudioInputs = Union[MiniCPMVAudioFeatureInputs] + + class MiniCPMVEmbeddingItems(ModalityDataItems[dict[str, torch.Tensor], dict[str, torch.Tensor]]): def __init__(self, data: Dict, modality: str) -> None: @@ -198,9 +227,6 @@ def get_num_frames(self, index: int) -> int: return self.data["num_frames"][index] -MiniCPMVImageInputs = Union[MiniCPMVImagePixelInputs, - MiniCPMVImageEmbeddingInputs] - DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6) @@ -306,7 +332,8 @@ def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]: if config.hidden_size == 2304 and config.query_num == 64: return (2, 0) return (2, 5) - + elif "MiniCPMO" in config.architectures: + return (2, "6O") version_str = str(version_float) return tuple(int(x) for x in version_str.split(".")) @@ -359,29 +386,41 @@ def get_model_version(self): return get_version_by_config(self.get_hf_config()) def get_supported_mm_modalities(self) -> List[str]: - if self.get_model_version() == (2, 6): + if self.get_model_version() == (2, "6O"): + return ["image", "video", "audio"] + elif self.get_model_version() == (2, 6): return ["image", "video"] else: return ["image"] def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + if self.get_model_version() == (2, "6O"): + return {"image": None, "video": None, "audio": None} if self.get_model_version() == (2, 6): return {"image": None, "video": None} else: return {"image": None} def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: - if self.get_model_version() == (2, 6): - return { - "image": self.get_max_image_tokens(), - "video": self.get_max_video_tokens(seq_len), - # "audio": self.get_max_audio_tokens() - } - else: - return { - "image": self.get_max_image_tokens() - } + mm_max_tokens = { + "image": self.get_max_image_tokens() + } + if self.get_model_version() == (2, "6O"): + mm_max_tokens["audio"] = self.get_max_audio_tokens() + if self.get_model_version() in [(2, 6), (2, "6O")]: + mm_max_tokens["video"] = self.get_max_video_tokens(seq_len) + return mm_max_tokens + + def get_max_video_frame_tokens(self) -> int: + frame_size = self.get_video_frame_size_with_most_features() + return self.get_num_image_tokens(frame_size, self.get_video_max_slice_num()) + + def get_max_video_tokens(self, seq_len: int) -> int: + return self.get_max_video_frame_tokens() * self.get_num_frames_with_most_features(seq_len) + def get_max_audio_tokens(self) -> int: + return self.get_max_audio_tokens_per_chunk() * self.get_max_audio_chunks_with_most_features() + def get_slice_query_num(self) -> int: hf_config = self.get_hf_config() query_num = getattr(hf_config, "query_num", 64) @@ -393,7 +432,7 @@ def get_max_slice_num(self) -> int: return max_slice_num def get_sliced_grid(self, image_size, max_slice_num) -> Tuple[int, int]: - if self.get_model_version() == (2, 6): + if self.get_model_version() in [(2, 6), (2, "6O")]: slice_grid = self.get_image_processor().get_sliced_grid(image_size, max_slice_num) else: slice_grid = self.get_image_processor().get_sliced_grid(image_size) @@ -403,7 +442,7 @@ def get_num_image_tokens(self, image_size: ImageSize, max_slice_num: int) -> int slice_grid = self.get_sliced_grid(image_size, max_slice_num) num_tokens = self.get_slice_query_num() + 2 # ( * query_num) if slice_grid is not None: - if self.get_model_version() == (2, 6): + if self.get_model_version() in [(2, 6), (2, "6O")]: num_additional_tokens = 0 # ( * query_num) else: num_additional_tokens = 2 # ( * query_num) @@ -419,33 +458,21 @@ def get_max_image_tokens(self) -> int: image_size = self.get_image_size_with_most_features() return self.get_num_image_tokens(image_size, self.get_max_slice_num()) - def get_max_video_frame_tokens(self) -> int: - frame_size = self.get_video_frame_size_with_most_features() - return self.get_num_image_tokens(frame_size, self.get_video_max_slice_num()) - - def get_max_video_tokens(self, seq_len: int) -> int: - return self.get_max_video_frame_tokens() * self.get_num_frames_with_most_features(seq_len) - - def get_max_audio_tokens(self) -> int: - pass - def get_image_size_with_most_features(self) -> ImageSize: # Result in the max possible feature size (h:w = 9:1) - hf_config = self.get_hf_config() - image_size = getattr(hf_config, "image_size", 448) - target_width = image_size - target_height = image_size * 9 - return ImageSize(width=target_width, height=target_height) + return self.get_defaul_image_sizes( + self.get_max_slice_num() + ) def get_video_max_slice_num(self) -> int: return 1 def get_video_frame_size_with_most_features(self) -> ImageSize: - hf_config = self.get_hf_config() - image_size = getattr(hf_config, "image_size", 448) - return ImageSize(width=image_size, height=image_size) + return self.get_defaul_image_sizes( + self.get_video_max_slice_num() + ) - def _get_max_video_frames(self, max_tokens: int) -> int: + def get_max_video_frames(self, max_tokens: int) -> int: num_frame_tokens = self.get_max_video_frame_tokens() num_frames = max_tokens // num_frame_tokens return num_frames @@ -454,11 +481,15 @@ def get_num_frames_with_most_features(self, seq_len: int) -> int: mm_config = self.ctx.get_mm_config() max_images = mm_config.limit_per_prompt.get("image", 1) max_videos = mm_config.limit_per_prompt.get("video", 1) + max_audios = mm_config.limit_per_prompt.get("audio", 1) # count tokens which are not in get_max_image_tokens max_image_tokens = self.get_max_image_tokens() * max_images + 4 * max_images - max_total_frames = self._get_max_video_frames(seq_len - - max_image_tokens) + seq_len = seq_len - max_image_tokens + if "audio" in self.get_supported_mm_modalities(): + max_audio_tokens = self.get_max_audio_tokens() * max_audios + 2 * max_audios + seq_len = seq_len - max_audio_tokens + max_total_frames = self.get_max_video_frames(seq_len) num_frames = max(max_total_frames // max(max_videos, 1), 1) @@ -466,7 +497,26 @@ def get_num_frames_with_most_features(self, seq_len: int) -> int: def get_defaul_image_sizes(self, num_slices: int) -> ImageSize: image_size = getattr(self.get_hf_config(), "image_size", 448) - return ImageSize(image_size, image_size * num_slices) + return ImageSize(width=image_size, height=image_size * num_slices) + + def get_default_audio_pool_step(self): + return 2 + + def get_default_audio_sampling_rate(self): + return 16000 + + def get_chunk_length(self) -> int: + return self.get_hf_config().audio_chunk_length + + def get_max_audio_tokens_per_chunk(self) -> int: + pool_step = self.get_default_audio_pool_step() + fbank_feat_in_chunk = 100 + cnn_feat_in_chunk = (fbank_feat_in_chunk - 1) // 2 + 1 + num_audio_tokens = (cnn_feat_in_chunk - pool_step) // pool_step + 1 + return num_audio_tokens + 2 # + + def get_max_audio_chunks_with_most_features(self) -> int: + return 30 class MiniCPMVDummyInputsBuilder(BaseDummyInputsBuilder[MiniCPMVProcessingInfo]): @@ -479,10 +529,6 @@ def get_dummy_processor_inputs( num_videos = mm_counts.get("video", 0) num_audios = mm_counts.get("audio", 0) - # image_token: str = hf_processor.image_token - # video_token: str = hf_processor.video_token - # audio_token: str = hf_processor.audio_token - image_width, image_height = \ self.info.get_image_size_with_most_features() video_width, video_height = \ @@ -490,6 +536,8 @@ def get_dummy_processor_inputs( num_video_frames = \ self.info.get_num_frames_with_most_features(seq_len) + audio_len = self.info.get_max_audio_chunks_with_most_features() * \ + self.info.get_default_audio_sampling_rate() mm_data = { "image": self._get_dummy_images( @@ -501,14 +549,19 @@ def get_dummy_processor_inputs( width=video_width, height=video_height, num_images=num_video_frames - )] * num_videos + )] * num_videos, + "audio": self._get_dummy_audios( + length=audio_len, + num_audios=num_audios + ) } image_prompt_texts = self.info.image_pattern * num_images video_prompt_texts = self.info.video_pattern * num_videos + audio_prompt_texts = self.info.audio_pattern * num_audios return ProcessorInputs( - prompt_text=image_prompt_texts + video_prompt_texts, + prompt_text=image_prompt_texts + video_prompt_texts + audio_prompt_texts, mm_data=mm_data ) @@ -546,6 +599,13 @@ def get_video_prompt_texts(self, image_size: ImageSize, num_frames: int) -> str: ) for image_idx in range(num_frames) ]) return prompt_texts + + def get_audio_prompt_texts(self, audio_lens: int, chunk_input: bool = True, chunk_length: int = 1): + return self.info.get_hf_processor().get_audio_placeholder( + audio_lens, + chunk_input, + chunk_length + ) def get_special_tokens(self): tokenizer = self.info.get_tokenizer() @@ -556,6 +616,9 @@ def get_special_tokens(self): if hasattr(tokenizer, "slice_start_id"): special_tokens["slice_start_id"] = torch.tensor(tokenizer.slice_start_id) special_tokens["slice_end_id"] = torch.tensor(tokenizer.slice_end_id) + if hasattr(tokenizer, "audio_start_id"): + special_tokens["audio_start_id"] = torch.tensor(tokenizer.audio_start_id) + special_tokens["audio_end_id"] = torch.tensor(tokenizer.audio_end_id) return special_tokens @staticmethod @@ -566,7 +629,6 @@ def repack_processor_outputs(outputs: Any) -> BatchFeature: def process_images( self, - prompt: str, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object] ) -> Dict[str, object]: @@ -574,7 +636,7 @@ def process_images( image_embeds = mm_data.pop("image_embeds", []) if isinstance(images, (list, torch.Tensor)) and len(images) > 0: image_outputs = super()._call_hf_processor( - prompt=prompt, + prompt=self.info.image_pattern * len(images), mm_data={"images": images}, mm_kwargs=mm_kwargs ) @@ -632,7 +694,47 @@ def process_videos( else: video_outputs = {} return video_outputs - + + def process_audios( + self, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object] + ): + audios = mm_data.pop("audios", []) + audio_embeds = mm_data.pop("audio_embeds", []) + if isinstance(audios, (list, torch.Tensor)) and len(audios) > 0: + audio_outputs = { + "audio_lens": [], + "audio_features": [], + "audio_feature_lens": [], + "audio_num_segments": [] + } + for audio in audios: + single_audio_outputs = super()._call_hf_processor( + prompt=self.info.audio_pattern, + mm_data={ + "audios": audio, + "chunk_input": True + }, + mm_kwargs=mm_kwargs + ) + audio_outputs["audio_lens"].append(len(audio)) + audio_outputs["audio_features"].append( + single_audio_outputs["audio_features"] + ) + audio_outputs["audio_num_segments"].append( + len(single_audio_outputs["audio_feature_lens"]) + ) + audio_outputs["audio_feature_lens"] += \ + single_audio_outputs["audio_feature_lens"] + audio_outputs["audio_features"] = torch.cat(audio_outputs["audio_features"]) + audio_outputs["audio_feature_lens"] = torch.cat(audio_outputs["audio_feature_lens"]) + elif len(audio_embeds): + pass + else: + audio_outputs = {} + return audio_outputs + def _call_hf_processor( self, prompt: str, @@ -643,23 +745,23 @@ def _call_hf_processor( # Try to handle interleaved multimodal data tokenizer = self.info.get_tokenizer() - audios = mm_data.pop("audios", []) - - image_outputs = self.process_images(prompt, mm_data, mm_kwargs) + image_outputs = self.process_images(mm_data, mm_kwargs) video_outputs = self.process_videos(mm_data, mm_kwargs) - audio_outputs = {} + audio_outputs = self.process_audios(mm_data, mm_kwargs) counts = { "image": 0, "video": 0, "audio": 0 } - image_orders_in_mm_data = [] num_image_slices = [] num_video_slices = [] + num_audio_slices = [] video_orders_in_mm_data = [] - matches = re.findall(r"\(<(image|video)>./\)", prompt) - chunks = re.split(r"\(<(?:image|video)>./\)", prompt) + image_orders_in_mm_data = [] + audio_orders_in_mm_data = [] + matches = re.findall(r"\(<(image|video|audio)>./\)", prompt) + chunks = re.split(r"\(<(?:image|video|audio)>./\)", prompt) new_prompt = chunks[0] for idx, item in enumerate(matches): if item == "image": @@ -672,7 +774,7 @@ def _call_hf_processor( image_outputs["image_sizes"][counts[item]], counts[item] ) - else: + elif item == "video": video_orders_in_mm_data.append(idx) num_video_slices.append(self.info.get_image_slice_nums( video_outputs["video_image_sizes"][counts[item]], @@ -682,6 +784,15 @@ def _call_hf_processor( video_outputs["video_image_sizes"][counts[item]], video_outputs["num_frames"][counts[item]] ) + else: # audio + audio_orders_in_mm_data.append(idx) + num_audio_slices.append( + audio_outputs["audio_num_segments"][counts[item]] + ) + new_prompt += self.get_audio_prompt_texts( + audio_outputs["audio_lens"][counts[item]] + ) + counts[item] += 1 new_prompt += chunks[idx + 1] @@ -702,7 +813,10 @@ def get_slices(num_slices: List[int]): "image_slices": get_slices(num_image_slices), **video_outputs, "video_orders_in_mm_data": video_orders_in_mm_data, - "video_slices": get_slices(num_video_slices) + "video_slices": get_slices(num_video_slices), + **audio_outputs, + "audio_orders_in_mm_data": audio_orders_in_mm_data, + "audio_slices": get_slices(num_audio_slices), } def _get_prompt_replacements( @@ -711,10 +825,6 @@ def _get_prompt_replacements( hf_processor_mm_kwargs: Mapping[str, Any], out_mm_kwargs: MultiModalKwargs ) -> List[PromptReplacement]: - image_processor = self.info.get_image_processor( - **hf_processor_mm_kwargs - ) - placeholder = { "image": self.info.image_pattern, "video": self.info.video_pattern, @@ -732,6 +842,10 @@ def get_replacement_minicpmv(item_idx: int, modality: str): mm_items["video"].get_frame_size(item_idx), mm_items["video"].get_num_frames(item_idx) ) + else: # audio + return self.get_audio_prompt_texts( + len(mm_items["audio"].get(item_idx)) + ) return [ PromptReplacement( @@ -739,7 +853,7 @@ def get_replacement_minicpmv(item_idx: int, modality: str): target=placeholder[modality], replacement=partial(get_replacement_minicpmv, modality=modality) - ) for modality in ("image", "video") + ) for modality in ("image", "video", "audio") ] def _get_mm_fields_config( @@ -751,6 +865,7 @@ def get_slices(slices_indices: List[int]): return [slice(*slice_item) for slice_item in slices_indices] image_slices = get_slices(hf_inputs.get("image_slices", torch.empty(0, 2))) video_slices = get_slices(hf_inputs.get("video_slices", torch.empty(0, 2))) + audio_slices = get_slices(hf_inputs.get("audio_slices", torch.empty(0, 2))) return dict( pixel_values=MultiModalFieldConfig.flat("image", image_slices), image_sizes=MultiModalFieldConfig.batched("image"), @@ -764,6 +879,11 @@ def get_slices(slices_indices: List[int]): video_orders_in_mm_data=MultiModalFieldConfig.batched("video"), video_embeds=MultiModalFieldConfig.flat("video", video_slices), video_slices=MultiModalFieldConfig.batched("video"), + audio_features=MultiModalFieldConfig.flat("audio", audio_slices), + audio_feature_lens=MultiModalFieldConfig.flat("audio", audio_slices), + audio_num_segments=MultiModalFieldConfig.batched("audio"), + audio_slices=MultiModalFieldConfig.batched("audio"), + audio_orders_in_mm_data=MultiModalFieldConfig.batched("audio") ) def apply( @@ -775,7 +895,8 @@ def apply( result = super().apply(prompt, mm_data, hf_processor_mm_kwargs) # Exclude x from placeholders - if "image" in result["mm_placeholders"] and self.info.get_model_version() == (2, 6): + if "image" in result["mm_placeholders"] and \ + self.info.get_model_version() in [(2, 6), (2, "6O")]: result["mm_placeholders"]["image"] = [ PlaceholderRange( offset=p["offset"] + 3 + idx // 10, @@ -833,7 +954,7 @@ def sampler(self): return get_sampler() - def get_embedding( + def get_embedding_with_vision( self, input_ids: torch.Tensor, image_inputs: Optional[MiniCPMVImageInputs], @@ -867,6 +988,41 @@ def get_embedding( return vlm_embedding, vision_hidden_states + def get_embedding_with_audios( + self, + input_ids: torch.Tensor, + vlm_embedding: torch.Tensor, + audio_inputs: Optional[MiniCPMVAudioInputs], + chunk_length: int + ) -> torch.Tensor: + device, dtype = vlm_embedding.device, vlm_embedding.dtype + audio_embeddings = self.get_audio_hidden_states(audio_inputs, chunk_length)[0] + audio_bounds = audio_inputs["audio_bounds"] + if self.config.chunk_input: + audio_embs = torch.cat(audio_embeddings, dim=0).to( + device=device, dtype=dtype + ) + audio_start_pos = 0 + for bound in audio_bounds: + audio_len = bound[1] - bound[0] + vlm_embedding[bound[0] : bound[1]] = audio_embs[ + audio_start_pos : audio_start_pos + audio_len, : + ] + audio_start_pos += audio_len + else: + for embs, bound in zip(audio_embeddings, audio_bounds): + audio_indices = torch.arange(bound[0], bound[1], dtype=torch.long).to( + device + ) + + if embs.shape[0] != len(audio_indices): + raise ValueError( + f"Shape mismatch: Trying to assign embeddings of shape {embs.shape} " + f"to input indices of length {len(audio_indices)}" + ) + vlm_embedding[audio_indices] = embs.to(dtype) + return vlm_embedding + def _get_image_bounds( self, input_ids: torch.Tensor, @@ -895,21 +1051,36 @@ def _get_image_bounds( image_end_tokens[:valid_image_nums].unsqueeze(-1), ]) - def _parse_and_validate_inputs( + def _get_audio_bounds( + self, + input_ids: torch.Tensor, + audio_start_id: torch.Tensor, + audio_end_id: torch.Tensor + ) -> torch.Tensor: + audio_start_tokens, = torch.where(input_ids == audio_start_id[0]) + audio_start_tokens += 1 + audio_end_tokens, = torch.where(input_ids == audio_end_id[0]) + valid_audio_nums = max(len(audio_start_tokens), len(audio_end_tokens)) + return torch.hstack([ + audio_start_tokens[:valid_audio_nums].unsqueeze(-1), + audio_end_tokens[:valid_audio_nums].unsqueeze(-1) + ]) + + def _parse_and_validate_image_inputs( self, input_ids: torch.Tensor, **kwargs: object, ) -> Optional[MiniCPMVImageInputs]: mm_data = { "image": { - "pixel_values": kwargs.pop("pixel_values", []), - "tgt_sizes": kwargs.pop("tgt_sizes", []), - "image_slices": kwargs.pop("image_slices", []), + key: kwargs.pop(key, []) for key in [ + "pixel_values", "tgt_sizes", "image_slices" + ] }, "video": { "pixel_values": kwargs.pop("video_pixel_values", []), "tgt_sizes": kwargs.pop("video_tgt_sizes", []), - "video_slices": kwargs.pop("video_slices", []), + "video_slices": kwargs.pop("video_slices", []) } } im_start_id = kwargs.pop("im_start_id", None) @@ -918,7 +1089,7 @@ def _parse_and_validate_inputs( slice_end_id = kwargs.pop("slice_end_id", None) orders_in_mm_data = { modality: kwargs.pop(f"{modality}_orders_in_mm_data", None) - for modality in ["image", "video"] + for modality in ["image", "video", "audio"] } batch_size = max(len(mm_data["image"]["pixel_values"]), len(mm_data["video"]["pixel_values"])) @@ -964,9 +1135,9 @@ def _parse_and_validate_inputs( for modality, orders in orders_in_mm_data.items() } mm_data_indices = [ - (index, (pos, "image")) for pos, index in enumerate(orders_in_mm_data_b["image"]) - ] + [ - (index, (pos, "video")) for pos, index in enumerate(orders_in_mm_data_b["video"]) + (index, (pos, media_type)) + for media_type in ["image", "video", "audio"] + for pos, index in enumerate(orders_in_mm_data_b[media_type]) ] mm_data_indices = [ (pos, modality) for index, (pos, modality) in @@ -977,7 +1148,7 @@ def _parse_and_validate_inputs( slice_index = mm_data[modality]["image_slices"][b][pos] pixel_values_flat += mm_data[modality]["pixel_values"][b][slice_index[0]:slice_index[1]] tgt_sizes_flat += mm_data[modality]["tgt_sizes"][b][slice_index[0]:slice_index[1]] - else: + elif modality == "video": slice_index = mm_data[modality]["video_slices"][b][pos] pixel_values_flat += mm_data[modality]["pixel_values"][b][slice_index[0]:slice_index[1]] tgt_sizes_flat += mm_data[modality]["tgt_sizes"][b][slice_index[0]:slice_index[1]] @@ -1004,6 +1175,49 @@ def _parse_and_validate_inputs( type="pixel_values", ) + def _parse_and_validate_audio_inputs( + self, + input_ids: torch.Tensor, + **kwargs: object + ) -> Tuple[MiniCPMVImageInputs]: + audio_features = kwargs.pop("audio_features", []) + audio_feature_lens = kwargs.pop("audio_feature_lens", []) + if len(audio_features) > 0: + audio_features = torch.cat([ + item for item in audio_features + ]) + audio_feature_lens = torch.cat([ + item for item in audio_feature_lens + ]) + audio_start_id = kwargs.pop("audio_start_id", None) + audio_end_id = kwargs.pop("audio_end_id", None) + + return MiniCPMVAudioFeatureInputs( + audio_bounds=self._get_audio_bounds(input_ids, + audio_start_id, audio_end_id), + data=audio_features, + audio_feature_lens=audio_feature_lens, + type="audio_features" + ) + return None + + def _parse_and_validate_inputs( + self, + input_ids: torch.Tensor, + **kwargs: object + ): + image_inputs = self._parse_and_validate_image_inputs( + input_ids, + **kwargs + ) + if not any("audio" in key for key in kwargs.keys()): + return image_inputs, None + audio_inputs = self._parse_and_validate_audio_inputs( + input_ids, + **kwargs + ) + return image_inputs, audio_inputs + def forward( self, input_ids: torch.Tensor, @@ -1016,9 +1230,17 @@ def forward( if intermediate_tensors is not None: vlm_embeddings = None else: - image_inputs = self._parse_and_validate_inputs(input_ids, **kwargs) - - vlm_embeddings, _ = self.get_embedding(input_ids, image_inputs) + image_inputs, audio_inputs = \ + self._parse_and_validate_inputs(input_ids, **kwargs) + vlm_embeddings, _ = self.get_embedding_with_vision(input_ids, image_inputs) + + if audio_inputs is not None: + vlm_embeddings = self.get_embedding_with_audios( + input_ids, + vlm_embeddings, + audio_inputs, + self.config.audio_chunk_length + ) # always pass the input via `inputs_embeds` # to make sure the computation graph is consistent @@ -1052,8 +1274,7 @@ def sample( def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: - loader = AutoWeightsLoader(self, - skip_prefixes=["apm", "tts", "audio_projection_layer"]) + loader = AutoWeightsLoader(self) return loader.load_weights(weights) def get_mm_mapping(self) -> MultiModelKeys: @@ -1098,6 +1319,11 @@ def get_vision_hidden_states(self, data: MiniCPMVImageInputs) -> torch.Tensor: raise NotImplementedError + def get_audio_hidden_states(self, + data: MiniCPMVAudioInputs, + chunk_length: int) -> torch.Tensor: + raise NotImplementedError + class MiniCPMV2_0(MiniCPMVBaseModel): @@ -1337,7 +1563,7 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__(vllm_config=vllm_config, prefix=prefix) - assert self.version == (2, 6) + assert self.version in [(2, 6), (2, "6O")] def init_llm( self, @@ -1422,10 +1648,521 @@ def get_vision_hidden_states(self, return self.resampler(vision_embedding, tgt_sizes) +class MultiModalProjector(nn.Module): + def __init__(self, in_dim, out_dim): + super().__init__() + self.linear1 = nn.Linear(in_features=in_dim, out_features=out_dim, bias=True) + self.relu = nn.ReLU() + self.linear2 = nn.Linear(in_features=out_dim, out_features=out_dim, bias=True) + + def forward(self, audio_features): + hidden_states = self.relu(self.linear1(audio_features)) + hidden_states = self.linear2(hidden_states) + return hidden_states + + +# Copied from transformers.models.whisper.modeling_whisper.WhisperEncoderLayer and add use_cache for streaming inference +class MiniCPMWhisperEncoderLayer(nn.Module): + def __init__(self, config: WhisperConfig, layer_idx: int = None): + super().__init__() + self.embed_dim = config.d_model + self.self_attn = WHISPER_ATTENTION_CLASSES[config._attn_implementation]( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + config=config, + layer_idx=layer_idx, + ) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) + self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + layer_head_mask: torch.Tensor, + output_attentions: bool = False, + past_key_values: Optional[EncoderDecoderCache] = None, + use_cache: Optional[bool] = False, + ) -> torch.Tensor: + r""" + Args: + hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, embed_dim)`): + Hidden states to be fed into the encoder layer. + attention_mask (`torch.FloatTensor` of shape `(batch_size, 1, tgt_len, src_len)`): + Attention mask where padding elements are indicated by large negative values. + layer_head_mask (`torch.FloatTensor` of shape `(encoder_attention_heads,)`): + Mask to nullify selected heads of the attention modules. + output_attentions (`bool`, *optional*): + Whether or not to return the attention weights. + past_key_values (`EncoderDecoderCache`, *optional*): + Past key-value pairs used for incremental decoding. + use_cache (`bool`, *optional*): + Whether or not to return updated `past_key_values` for caching. + + Returns: + A tuple of shape `(hidden_states, optional(attn_weights), optional(past_key_values))`. + """ + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states, attn_weights, past_key_values = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + past_key_value=past_key_values, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + if hidden_states.dtype == torch.float16 and ( + torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() + ): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + if use_cache: + outputs += (past_key_values,) + + return outputs + + +# Copied from from transformers.models.whisper.modeling_whisper.WhisperEncoder and add use_cache for streaming inference +class MiniCPMWhisperEncoder(WhisperEncoder): + + def __init__(self, config: WhisperConfig): + super().__init__(config) + self.layers = nn.ModuleList( + [MiniCPMWhisperEncoderLayer(config, layer_idx=i) for i in range(config.encoder_layers)] + ) + + def forward( + self, + input_features, + attention_mask=None, + head_mask=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + past_key_values: Optional[EncoderDecoderCache] = None, + use_cache: Optional[bool] = None, + ): + r""" + Forward pass of the Whisper encoder. + + Args: + input_features (`torch.FloatTensor` of shape `(batch_size, feature_size, sequence_length)`): + Float values of log-mel features extracted from the raw audio waveform. Typically generated + by a feature extractor (e.g., `WhisperFeatureExtractor`) that processes `.flac` or `.wav` + files into padded 2D mel spectrogram frames. These features are projected via convolution layers + (`conv1` and `conv2`) and then transformed into embeddings for the encoder. + + attention_mask (`torch.Tensor`, *optional*): + Not used by Whisper for masking `input_features`, but included for API compatibility with + other models. If provided, it is simply ignored within the model. By default, Whisper + effectively ignores silence in the input log-mel spectrogram. + + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected attention heads. The elements should be either 1 or 0, where: + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked** (i.e., the attention head is dropped). + + output_attentions (`bool`, *optional*): + Whether or not to return the attention tensors of all encoder layers. If set to `True`, the + returned tuple (or `BaseModelOutputWithPast`) will contain an additional element with + attention weights for each encoder layer. + + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. If set to `True`, the returned + tuple (or `BaseModelOutputWithPast`) will contain a tuple of hidden states, including the + initial embedding output as well as the outputs of each layer. + + return_dict (`bool`, *optional*): + Whether or not to return a `BaseModelOutputWithPast` (a subclass of `ModelOutput`) instead + of a plain tuple. If set to `True`, the output will be a `BaseModelOutputWithPast` object, + otherwise it will be a tuple. + + past_key_values (`EncoderDecoderCache`, *optional*): + When using caching for faster inference, this is an object that stores the key-value pairs + for attention states. If provided, the model will append new states to the existing cache + and return the updated cache. This speeds up sequential decoding or chunked inference. + + - If `past_key_values` is `None`, no past states are used or returned. + - If `past_key_values` is not `None` and `use_cache=True`, the model will use the provided + cache and return the updated cache (as `next_encoder_cache`). + + use_cache (`bool`, *optional*): + Whether or not the model should use caching (`past_key_values`) to speed up processing + during inference. When set to `True`, the model will: + - Inspect and use `past_key_values` if provided. + - Return updated `past_key_values` (under the name `next_encoder_cache` in + `BaseModelOutputWithPast`). + + Returns: + `BaseModelOutputWithPast` or `tuple` (depending on `return_dict`): + If `return_dict=True`, a `BaseModelOutputWithPast` is returned, which contains: + - **last_hidden_state** (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + The output of the final encoder layer. + - **hidden_states** (`tuple(torch.FloatTensor)`, *optional*, returned if `output_hidden_states=True`): + Hidden states of the model at each layer (including the initial projection). + - **attentions** (`tuple(torch.FloatTensor)`, *optional*, returned if `output_attentions=True`): + Attention weights from each encoder layer. + - **past_key_values** (an object of type `EncoderDecoderCache` or `None`, *optional*): + Updated cache of key-value pairs if `use_cache=True`. + + If `return_dict=False`, a tuple is returned, where the format is: + `(last_hidden_state, hidden_states, attentions)`, with `hidden_states` and `attentions` + only present if their respective `output_*` arguments are set to `True`. + + Example: + >>> from transformers import AutoFeatureExtractor, WhisperConfig, WhisperForConditionalGeneration + >>> import torch + + >>> # Load a feature extractor and a Whisper model + >>> feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-tiny.en") + >>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en") + + >>> # Assume you have audio (list of floats or numpy array) loaded from a file + >>> # Then extract the mel features: + >>> input_features = feature_extractor(audio, sampling_rate=16000, return_tensors="pt").input_features + + >>> # Forward pass + >>> outputs = model.encoder( + ... input_features=input_features, + ... output_hidden_states=True, + ... output_attentions=True, + ... use_cache=True + ... ) + + >>> # Retrieve the last hidden state + >>> last_hidden_state = outputs.last_hidden_state + >>> print(last_hidden_state.shape) + torch.Size([batch_size, seq_length, hidden_size]) + + >>> # Retrieve the intermediate hidden states if output_hidden_states=True + >>> all_encoder_hidden_states = outputs.hidden_states + + >>> # Retrieve attention weights if output_attentions=True + >>> all_encoder_attentions = outputs.attentions + + >>> # Retrieve updated past key values if use_cache=True + >>> encoder_cache = outputs.past_key_values + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # Ignore copy + input_features = input_features.to(dtype=self.conv1.weight.dtype, device=self.conv1.weight.device) + + inputs_embeds = nn.functional.gelu(self.conv1(input_features)) + inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds)) + + inputs_embeds = inputs_embeds.permute(0, 2, 1) + + embed_pos = self.embed_positions.weight + past_key_values_length = 0 + if use_cache: + if past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + elif isinstance(past_key_values, list): + past_key_values = EncoderDecoderCache(DynamicCache.from_legacy_cache(past_key_values), DynamicCache()) + elif isinstance(past_key_values, DynamicCache): + past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) + else: + pass + past_key_values_length = past_key_values.self_attention_cache.get_usable_length(inputs_embeds.shape[1]) + if inputs_embeds.shape[1] + past_key_values_length > embed_pos.shape[0]: + # logger.warning("seems the audio is longer than 30s. repeating the last part of the audio") + embed_pos_front = embed_pos[past_key_values_length:, :] + embed_pos = torch.cat( + ( + embed_pos_front, + torch.repeat_interleave( + embed_pos[-1, :].unsqueeze(0), + inputs_embeds.shape[1] - embed_pos.shape[0] + past_key_values_length, + dim=0, + ), + ) + ) + else: + embed_pos = embed_pos[past_key_values_length : inputs_embeds.shape[1] + past_key_values_length, :] + else: + embed_pos = embed_pos[: inputs_embeds.shape[1], :] + + hidden_states = inputs_embeds + embed_pos + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + assert head_mask.size()[0] == ( + len(self.layers) + ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." + + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + to_drop = False + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: # skip the layer + to_drop = True + + # Ignore copy + if to_drop: + layer_outputs = (None, None) + else: + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + encoder_layer.__call__, + hidden_states, + attention_mask, + (head_mask[idx] if head_mask is not None else None), + output_attentions, + past_key_values, + use_cache, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + past_key_values=past_key_values, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_encoder_cache = layer_outputs[2 if output_attentions else 1] + else: + next_encoder_cache = None + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + hidden_states = self.layer_norm(hidden_states) + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + hidden_states=encoder_states, + attentions=all_attentions, + past_key_values=next_encoder_cache, + ) + + +class MiniCPMO2_6(MiniCPMV2_6): + # apm_mapper_hf_pattern = ["apm.layers.{}.fc1", "apm.layers.{}.fc2"] + # apm_mapper_vllm_pattern = ["apm.layers.{}.mlp.fc1", "apm.layers.{}.mlp.fc2"] + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str=""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + self.apm = self.init_audio_module(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "apm")) + + def init_audio_module( + self, + *, + vllm_config: VllmConfig, + prefix: str="" + ): + # Do not use parameters temporarily + audio_config = self.config.audio_config + model = MiniCPMWhisperEncoder(audio_config) + audio_output_dim = int(audio_config.encoder_ffn_dim // 4) + setattr(self, "audio_avg_pooler", + nn.AvgPool1d(self.config.audio_pool_step, stride=self.config.audio_pool_step)) + setattr(self, "audio_projection_layer", + MultiModalProjector(in_dim=audio_output_dim, out_dim=self.embed_dim)) + setattr(self, "audio_encoder_layer", -1) + return model + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + loader = AutoWeightsLoader(self, + skip_prefixes=["tts"]) + return loader.load_weights(weights) + + def subsequent_chunk_mask( + self, + size: int, + chunk_size: int, + num_left_chunks: int = -1, + device: torch.device = torch.device("cpu"), + num_lookhead: int = 0, + ) -> torch.Tensor: + """Create mask for subsequent steps (size, size) with chunk size, + this is for streaming encoder + + Args: + size (int): size of mask + chunk_size (int): size of chunk + num_left_chunks (int): number of left chunks + <0: use full chunk + >=0: use num_left_chunks + device (torch.device): "cpu" or "cuda" or torch.Tensor.device + + Returns: + torch.Tensor: mask + + Examples: + >>> subsequent_chunk_mask(4, 2) + [[1, 1, 0, 0], + [1, 1, 0, 0], + [1, 1, 1, 1], + [1, 1, 1, 1]] + """ + ret = torch.zeros(size, size, device=device, dtype=torch.bool) + for i in range(size): + if num_left_chunks < 0: + start = 0 + else: + start = max((i // chunk_size - num_left_chunks) * chunk_size, 0) + ending = min((i // chunk_size + 1) * chunk_size + num_lookhead, size) + ret[i, start:ending] = True + return ret + + def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): + """ + Computes the output length of the convolutional layers and the output length of the audio encoder + """ + input_lengths_after_cnn = (input_lengths - 1) // 2 + 1 + input_lengths_after_pooling = ( + input_lengths_after_cnn - self.config.audio_pool_step + ) // self.config.audio_pool_step + 1 + input_lengths_after_pooling = input_lengths_after_pooling.to(dtype=torch.int32) + + return input_lengths_after_cnn, input_lengths_after_pooling + + # Copied from HF repo of MiniCPM-o-2_6, designed for batched inputs and outputs + def get_audio_hidden_states(self, + data: MiniCPMVAudioInputs, + chunk_length: int) -> torch.Tensor: + r""" + Extract full audio embeddings with optional chunk-based attention. + + This method computes embeddings for all audio frames at once, either using full attention (when + `chunk_length` is -1) or chunk-based attention (when `chunk_length` is a positive number). It does + not use key-value caching and is suitable for non-streaming inference. + + Args: + data (dict): + - **"audio_features"** (`torch.FloatTensor`): Input mel-spectrograms of shape `(batch_size, 80, frames)`. + - **"audio_feature_lens"** (List[int]): Lengths of each audio segment for each item. + chunk_length (int, optional): Determines whether to use full attention (-1) or chunk-based + attention (>0) during embedding computation. + + Returns: + List[List[torch.Tensor]]: audio embeddings + """ + wavforms = data.get("data", []) # (bs, 80, frames) or [], multi audios need filled in advance + audio_feature_lens_raw = [data.get("audio_feature_lens", [])] # list, [[x1, x2], [y1], [z1]] + + # exist audio + if len(wavforms) > 0: + audio_feature_lens = torch.hstack(audio_feature_lens_raw) + batch_size, _, max_mel_seq_len = wavforms.shape + max_seq_len = (max_mel_seq_len - 1) // 2 + 1 + + # Create a sequence tensor of shape (batch_size, max_seq_len) + seq_range = ( + torch.arange(0, max_seq_len, dtype=audio_feature_lens.dtype, device=audio_feature_lens.device) + .unsqueeze(0) + .expand(batch_size, max_seq_len) + ) + lengths_expand = audio_feature_lens.unsqueeze(1).expand(batch_size, max_seq_len) + # Create mask + padding_mask = seq_range >= lengths_expand # 1 for padded values + + audio_attention_mask_ = padding_mask.view(batch_size, 1, 1, max_seq_len).expand( + batch_size, 1, max_seq_len, max_seq_len + ) + audio_attention_mask = audio_attention_mask_.to( + dtype=self.apm.conv1.weight.dtype, device=self.apm.conv1.weight.device + ) + + if chunk_length > 0: + chunk_num_frame = int(chunk_length * 50) + chunk_mask = self.subsequent_chunk_mask( + size=max_seq_len, + chunk_size=chunk_num_frame, + num_left_chunks=-1, + device=audio_attention_mask_.device, + ) + audio_attention_mask_ = torch.logical_or(audio_attention_mask_, torch.logical_not(chunk_mask)) + + audio_attention_mask[audio_attention_mask_] = float("-inf") + audio_states = self.apm( + wavforms, output_hidden_states=True, attention_mask=audio_attention_mask + ).hidden_states[self.audio_encoder_layer] + audio_embeds = self.audio_projection_layer(audio_states) + + audio_embeds = audio_embeds.transpose(1, 2) + audio_embeds = self.audio_avg_pooler(audio_embeds) + audio_embeds = audio_embeds.transpose(1, 2) + + _, feature_lens_after_pooling = self._get_feat_extract_output_lengths(audio_feature_lens) + + num_audio_tokens = feature_lens_after_pooling + + final_audio_embeds = [] + idx = 0 + for i in range(len(audio_feature_lens_raw)): + target_audio_embeds = [] + for _ in range(len(audio_feature_lens_raw[i])): + target_audio_embeds.append(audio_embeds[idx, : num_audio_tokens[idx], :]) + idx += 1 + final_audio_embeds.append(target_audio_embeds) + return final_audio_embeds + else: + return [] + + _SUPPORT_VERSION = { (2, 0): MiniCPMV2_0, (2, 5): MiniCPMV2_5, - (2, 6): MiniCPMV2_6 + (2, 6): MiniCPMV2_6, + (2, "6O"): MiniCPMO2_6, } @@ -1452,6 +2189,8 @@ def __new__(cls, *, vllm_config: VllmConfig, prefix: str = ""): version = (2, 0) else: version = (2, 5) + elif "MiniCPMO" in config.architectures: + version = (2, "6O") else: version = str(config.version).split(".") version = tuple([int(x) for x in version])