From 1ba77eb46d30a1664f4f3d1c315e4e39e7e7fa38 Mon Sep 17 00:00:00 2001
From: hzh
Date: Wed, 22 Jan 2025 13:12:14 +0000
Subject: [PATCH] audio embedding inputs

Signed-off-by: hzh
---
 vllm/model_executor/models/minicpmv.py | 110 +++++++++++++++++++++----
 1 file changed, 93 insertions(+), 17 deletions(-)

diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py
index e23f2fafae08a..89ab5d02995be 100644
--- a/vllm/model_executor/models/minicpmv.py
+++ b/vllm/model_executor/models/minicpmv.py
@@ -140,9 +140,22 @@ class MiniCPMVAudioFeatureInputs(TypedDict):
     """
 
 
+class MiniCPMVAudioEmbeddingInputs(TypedDict):
+    type: Literal["audio_embeds"]
+    data: torch.Tensor
+    """
+    Shape: `(num_chunks, num_tokens_per_chunk, hidden_size)` (kept as a list of per-chunk 2-D tensors when chunk lengths differ)
+    """
+    audio_bounds: torch.Tensor
+    """
+    Shape: `(num_audios, 2)` (start/end token index of each audio segment in `input_ids`)
+    """
+
+
 MiniCPMVImageInputs = Union[MiniCPMVImagePixelInputs,
                             MiniCPMVImageEmbeddingInputs]
-MiniCPMVAudioInputs = Union[MiniCPMVAudioFeatureInputs]
+MiniCPMVAudioInputs = Union[MiniCPMVAudioFeatureInputs,
+                            MiniCPMVAudioEmbeddingInputs]
 
 
 class MiniCPMVEmbeddingItems(ModalityDataItems[dict[str, torch.Tensor],
@@ -227,6 +240,19 @@ def get_num_frames(self, index: int) -> int:
         return self.data["num_frames"][index]
 
 
+class MiniCPMVAudioEmbeddingItems(MiniCPMVEmbeddingItems):
+    def __init__(self, data: Dict) -> None:
+        super().__init__(data, "audio")
+        audio_embeds = self.data.get("audio_embeds", None)
+        if audio_embeds is None:
+            raise ValueError("Incorrect type of audio_embeds. "
+                             "Got type: None")
+        self.data["audio_embeds"] = audio_embeds
+
+    def get(self, index: int) -> object:
+        return self.data["audio_embeds"][index]
+
+
 DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)
 
 
@@ -346,7 +372,6 @@ def _parse_image_data(
     ) -> ModalityDataItems[Any, Any]:
         if isinstance(data, dict):
             return MiniCPMVImageEmbeddingItems(data)
-
         return super()._parse_image_data(data)
 
     def _parse_video_data(
@@ -355,9 +380,16 @@ def _parse_video_data(
     ) -> ModalityDataItems[Any, Any]:
         if isinstance(data, dict):
             return MiniCPMVVideoEmbeddingItems(data)
-
         return super()._parse_video_data(data)
 
+    def _parse_audio_data(
+        self,
+        data: Union[dict[str, torch.Tensor], ModalityData[AudioItem]],
+    ) -> ModalityDataItems[Any, Any]:
+        if isinstance(data, dict):
+            return MiniCPMVAudioEmbeddingItems(data)
+        return super()._parse_audio_data(data)
+
 
 class MiniCPMVProcessingInfo(BaseProcessingInfo):
     image_pattern = "(<image>./</image>)"
@@ -507,7 +539,7 @@ def get_default_audio_sampling_rate(self):
     def get_chunk_length(self) -> int:
         return self.get_hf_config().audio_chunk_length
-    
+
     def get_max_audio_tokens_per_chunk(self) -> int:
         pool_step = self.get_default_audio_pool_step()
         fbank_feat_in_chunk = 100
@@ -518,6 +550,12 @@ def get_max_audio_tokens_per_chunk(self) -> int:
     def get_max_audio_chunks_with_most_features(self) -> int:
         return 30
 
+    def get_audio_len_by_num_chunks(self, num_chunks: int) -> int:
+        sampling_rate = self.get_default_audio_sampling_rate()
+        # exclude the audio start/end boundary tokens of each chunk
+        num_tokens_per_chunk = self.get_max_audio_tokens_per_chunk() - 2
+        return int(num_chunks * sampling_rate / num_tokens_per_chunk) + 1
+
 
 class MiniCPMVDummyInputsBuilder(BaseDummyInputsBuilder[MiniCPMVProcessingInfo]):
     def get_dummy_processor_inputs(
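
Note: `get_audio_len_by_num_chunks` above inverts the token-count arithmetic; despite the parameter name, its call sites pass a total audio-token count, and it recovers an approximate audio length in samples. A standalone sanity check of that arithmetic, assuming pool_step = 2, 100 fbank frames per one-second chunk, and a 16 kHz sampling rate (example values; the part of `get_max_audio_tokens_per_chunk` that the hunk above truncates is reconstructed here under the assumption of a stride-2 downsampling followed by pooling):

# Sanity check of the audio chunk/token arithmetic (assumed example values).
fbank_feat_in_chunk = 100                                # fbank frames per 1 s chunk
pool_step = 2                                            # assumed audio pool step
cnn_feat_in_chunk = (fbank_feat_in_chunk - 1) // 2 + 1   # assumed stride-2 downsampling
num_audio_tokens = (cnn_feat_in_chunk - pool_step) // pool_step + 1
max_audio_tokens_per_chunk = num_audio_tokens + 2        # plus audio start/end tokens

sampling_rate = 16_000                                   # assumed 16 kHz

def get_audio_len_by_num_chunks(num_chunks: int) -> int:
    # tokens / tokens-per-chunk = chunk count (seconds of audio);
    # multiplying by the sampling rate yields a length in samples.
    num_tokens_per_chunk = max_audio_tokens_per_chunk - 2
    return int(num_chunks * sampling_rate / num_tokens_per_chunk) + 1

assert max_audio_tokens_per_chunk == 27           # 25 pooled tokens + 2 boundaries
assert get_audio_len_by_num_chunks(50) == 32001   # 50 tokens ~ 2 s at 16 kHz
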
audio_outputs["audio_feature_lens"] = torch.cat(audio_outputs["audio_feature_lens"]) elif len(audio_embeds): - pass + audio_outputs = { + "audio_lens": [ + self.info.get_audio_len_by_num_chunks( + sum(chunk_embeds.shape[0] for chunk_embeds in single_audio_embeds) + ) + for single_audio_embeds in audio_embeds + ], + "audio_embeds": [ + chunk_embeds + for single_audio_embeds in audio_embeds + for chunk_embeds in single_audio_embeds + ], + "audio_num_segments": [ + len(single_audio_embeds) + for single_audio_embeds in audio_embeds + ] + } else: audio_outputs = {} return audio_outputs @@ -744,11 +798,9 @@ def _call_hf_processor( # Do not support combination inputs of images and videos for now # Try to handle interleaved multimodal data tokenizer = self.info.get_tokenizer() - image_outputs = self.process_images(mm_data, mm_kwargs) video_outputs = self.process_videos(mm_data, mm_kwargs) audio_outputs = self.process_audios(mm_data, mm_kwargs) - counts = { "image": 0, "video": 0, @@ -809,12 +861,12 @@ def get_slices(num_slices: List[int]): return { "input_ids": np.array([input_ids]), **image_outputs, + **video_outputs, + **audio_outputs, "image_orders_in_mm_data": image_orders_in_mm_data, "image_slices": get_slices(num_image_slices), - **video_outputs, "video_orders_in_mm_data": video_orders_in_mm_data, "video_slices": get_slices(num_video_slices), - **audio_outputs, "audio_orders_in_mm_data": audio_orders_in_mm_data, "audio_slices": get_slices(num_audio_slices), } @@ -843,6 +895,12 @@ def get_replacement_minicpmv(item_idx: int, modality: str): mm_items["video"].get_num_frames(item_idx) ) else: # audio + if isinstance(mm_items["audio"], MiniCPMVAudioEmbeddingItems): + single_audio_embeds = mm_items["audio"].get(item_idx) + audio_len = self.info.get_audio_len_by_num_chunks( + sum(chunk_embeds.shape[0] for chunk_embeds in single_audio_embeds) + ) + return self.get_audio_prompt_texts(audio_len) return self.get_audio_prompt_texts( len(mm_items["audio"].get(item_idx)) ) @@ -881,9 +939,9 @@ def get_slices(slices_indices: List[int]): video_slices=MultiModalFieldConfig.batched("video"), audio_features=MultiModalFieldConfig.flat("audio", audio_slices), audio_feature_lens=MultiModalFieldConfig.flat("audio", audio_slices), - audio_num_segments=MultiModalFieldConfig.batched("audio"), audio_slices=MultiModalFieldConfig.batched("audio"), - audio_orders_in_mm_data=MultiModalFieldConfig.batched("audio") + audio_orders_in_mm_data=MultiModalFieldConfig.batched("audio"), + audio_embeds=MultiModalFieldConfig.flat("audio", audio_slices) ) def apply( @@ -990,13 +1048,21 @@ def get_embedding_with_vision( def get_embedding_with_audios( self, - input_ids: torch.Tensor, vlm_embedding: torch.Tensor, audio_inputs: Optional[MiniCPMVAudioInputs], chunk_length: int ) -> torch.Tensor: device, dtype = vlm_embedding.device, vlm_embedding.dtype - audio_embeddings = self.get_audio_hidden_states(audio_inputs, chunk_length)[0] + if audio_inputs["type"] == "audio_embeds": + audio_embeddings = audio_inputs["data"] + audio_embeddings = [ + audio_embeddings[i].to(device=device, dtype=dtype) + for i in range(len(audio_embeddings)) + ] + else: + audio_embeddings = self.get_audio_hidden_states(audio_inputs, chunk_length)[0] + if audio_embeddings is None or len(audio_embeddings) == 0: + return vlm_embedding audio_bounds = audio_inputs["audio_bounds"] if self.config.chunk_input: audio_embs = torch.cat(audio_embeddings, dim=0).to( @@ -1182,6 +1248,19 @@ def _parse_and_validate_audio_inputs( ) -> Tuple[MiniCPMVImageInputs]: audio_features = 
kwargs.pop("audio_features", []) audio_feature_lens = kwargs.pop("audio_feature_lens", []) + audio_embeds = kwargs.pop("audio_embeds", None) + audio_start_id = kwargs.pop("audio_start_id", None) + audio_end_id = kwargs.pop("audio_end_id", None) + if audio_embeds is not None: + audio_embeds = [audio_embeds[i][j] + for i in range(len(audio_embeds)) + for j in range(len(audio_embeds[i]))] + return MiniCPMVAudioEmbeddingInputs( + audio_bounds=self._get_audio_bounds(input_ids, + audio_start_id, audio_end_id), + data=audio_embeds, + type="audio_embeds" + ) if len(audio_features) > 0: audio_features = torch.cat([ item for item in audio_features @@ -1189,8 +1268,6 @@ def _parse_and_validate_audio_inputs( audio_feature_lens = torch.cat([ item for item in audio_feature_lens ]) - audio_start_id = kwargs.pop("audio_start_id", None) - audio_end_id = kwargs.pop("audio_end_id", None) return MiniCPMVAudioFeatureInputs( audio_bounds=self._get_audio_bounds(input_ids, @@ -1236,7 +1313,6 @@ def forward( if audio_inputs is not None: vlm_embeddings = self.get_embedding_with_audios( - input_ids, vlm_embeddings, audio_inputs, self.config.audio_chunk_length