audio embedding inputs
Signed-off-by: hzh <[email protected]>
HwwwwwwwH committed Jan 22, 2025
1 parent c2d8dbb commit 1ba77eb
Showing 1 changed file with 93 additions and 17 deletions.
110 changes: 93 additions & 17 deletions vllm/model_executor/models/minicpmv.py
@@ -140,9 +140,22 @@ class MiniCPMVAudioFeatureInputs(TypedDict):
"""


class MiniCPMVAudioEmbeddingInputs(TypedDict):
type: Literal["audio_embeds"]
data: torch.Tensor
"""
Shape:
"""
audio_bounds: torch.Tensor
"""
Shape:
"""


MiniCPMVImageInputs = Union[MiniCPMVImagePixelInputs,
MiniCPMVImageEmbeddingInputs]
MiniCPMVAudioInputs = Union[MiniCPMVAudioFeatureInputs]
MiniCPMVAudioInputs = Union[MiniCPMVAudioFeatureInputs,
MiniCPMVAudioEmbeddingInputs]
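
For illustration, a minimal sketch of an audio_embeds input that conforms to the new TypedDict; the token count, hidden size, and bound positions below are assumptions for illustration, not values taken from this change:

    import torch

    num_tokens, hidden_size = 150, 3584  # assumed sizes, illustration only

    audio_embedding_inputs: MiniCPMVAudioEmbeddingInputs = {
        "type": "audio_embeds",
        # One row per audio token, already projected to the LLM hidden size.
        "data": torch.randn(num_tokens, hidden_size),
        # (num_audios, 2): start/end positions of each audio span in the
        # prompt, as produced by _get_audio_bounds further below.
        "audio_bounds": torch.tensor([[10, 10 + num_tokens]]),
    }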


class MiniCPMVEmbeddingItems(ModalityDataItems[dict[str, torch.Tensor],
@@ -227,6 +240,19 @@ def get_num_frames(self, index: int) -> int:
return self.data["num_frames"][index]


class MiniCPMVAudioEmbeddingItems(MiniCPMVEmbeddingItems):
def __init__(self, data: Dict) -> None:
super().__init__(data, "audio")
audio_embeds = self.data.get("audio_embeds", None)
if audio_embeds is None:
raise ValueError(f"In correct type of video_embeds",
f"Got type: None")
self.data["audio_embeds"] = audio_embeds

def get(self, index: int) -> object:
return self.data["audio_embeds"][index]
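
A minimal usage sketch of the new items class, assuming the base embedding-items class accepts a plain dict of per-clip chunk lists (shapes are illustrative):

    import torch

    items = MiniCPMVAudioEmbeddingItems({
        "audio_embeds": [
            [torch.randn(50, 3584)],                         # clip 0: one chunk
            [torch.randn(50, 3584), torch.randn(50, 3584)],  # clip 1: two chunks
        ]
    })
    first_clip_chunks = items.get(0)  # list of per-chunk embedding tensors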


DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)


Expand Down Expand Up @@ -346,7 +372,6 @@ def _parse_image_data(
) -> ModalityDataItems[Any, Any]:
if isinstance(data, dict):
return MiniCPMVImageEmbeddingItems(data)

return super()._parse_image_data(data)

def _parse_video_data(
@@ -355,9 +380,16 @@ def _parse_video_data(
) -> ModalityDataItems[Any, Any]:
if isinstance(data, dict):
return MiniCPMVVideoEmbeddingItems(data)

return super()._parse_video_data(data)

def _parse_audio_data(
self,
data: Union[dict[str, torch.Tensor], ModalityData[AudioItem]],
) -> ModalityDataItems[Any, Any]:
if isinstance(data, dict):
return MiniCPMVAudioEmbeddingItems(data)
return super()._parse_audio_data(data)


class MiniCPMVProcessingInfo(BaseProcessingInfo):
image_pattern = "(<image>./</image>)"
Expand Down Expand Up @@ -507,7 +539,7 @@ def get_default_audio_sampling_rate(self):

def get_chunk_length(self) -> int:
return self.get_hf_config().audio_chunk_length

def get_max_audio_tokens_per_chunk(self) -> int:
pool_step = self.get_default_audio_pool_step()
fbank_feat_in_chunk = 100
@@ -518,6 +550,12 @@ def get_max_audio_tokens_per_chunk(self) -> int:
def get_max_audio_chunks_with_most_features(self) -> int:
return 30

def get_audio_len_by_num_chunks(self, num_chunks: int) -> int:
sampling_rate = self.get_default_audio_sampling_rate()
# exclude <audio> </audio>
num_tokens_per_chunk = self.get_max_audio_tokens_per_chunk() - 2
return int(num_chunks * sampling_rate / num_tokens_per_chunk) + 1
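
A worked example of this inversion, using assumed values (a 16 kHz sampling rate and 52 tokens per chunk are not read from the config here). Note that callers pass the total number of embedding tokens for one clip, despite the parameter name:

    num_tokens_per_chunk = 52 - 2   # exclude <audio> and </audio>
    sampling_rate = 16000
    num_chunks = 150                # total embedding rows for one audio clip
    audio_len = int(num_chunks * sampling_rate / num_tokens_per_chunk) + 1
    # -> 48001 samples, roughly 3 seconds of audio at 16 kHz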


class MiniCPMVDummyInputsBuilder(BaseDummyInputsBuilder[MiniCPMVProcessingInfo]):
def get_dummy_processor_inputs(
Expand Down Expand Up @@ -723,14 +761,30 @@ def process_audios(
single_audio_outputs["audio_features"]
)
audio_outputs["audio_num_segments"].append(
len(single_audio_outputs["audio_feature_lens"])
len(single_audio_outputs["audio_feature_lens"][0])
)
audio_outputs["audio_feature_lens"] += \
single_audio_outputs["audio_feature_lens"]
audio_outputs["audio_features"] = torch.cat(audio_outputs["audio_features"])
audio_outputs["audio_feature_lens"] = torch.cat(audio_outputs["audio_feature_lens"])
elif len(audio_embeds):
pass
audio_outputs = {
"audio_lens": [
self.info.get_audio_len_by_num_chunks(
sum(chunk_embeds.shape[0] for chunk_embeds in single_audio_embeds)
)
for single_audio_embeds in audio_embeds
],
"audio_embeds": [
chunk_embeds
for single_audio_embeds in audio_embeds
for chunk_embeds in single_audio_embeds
],
"audio_num_segments": [
len(single_audio_embeds)
for single_audio_embeds in audio_embeds
]
}
else:
audio_outputs = {}
return audio_outputs
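
For the embedding branch, the resulting audio_outputs keeps per-clip bookkeeping next to the flattened chunk embeddings. An illustrative result for two clips with two and one chunks respectively, assuming 16 kHz audio and 50 usable tokens per chunk:

    import torch

    chunk = lambda: torch.randn(50, 3584)  # assumed per-chunk embedding shape
    audio_outputs = {
        # Approximate sample count per clip, recovered from its token count
        # via get_audio_len_by_num_chunks.
        "audio_lens": [32001, 16001],
        # Chunks from all clips, flattened into one list.
        "audio_embeds": [chunk(), chunk(), chunk()],
        # Number of chunks contributed by each clip.
        "audio_num_segments": [2, 1],
    }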
@@ -744,11 +798,9 @@ def _call_hf_processor(
# Combined inputs of images and videos are not supported for now
# Try to handle interleaved multimodal data
tokenizer = self.info.get_tokenizer()

image_outputs = self.process_images(mm_data, mm_kwargs)
video_outputs = self.process_videos(mm_data, mm_kwargs)
audio_outputs = self.process_audios(mm_data, mm_kwargs)

counts = {
"image": 0,
"video": 0,
Expand Down Expand Up @@ -809,12 +861,12 @@ def get_slices(num_slices: List[int]):
return {
"input_ids": np.array([input_ids]),
**image_outputs,
**video_outputs,
**audio_outputs,
"image_orders_in_mm_data": image_orders_in_mm_data,
"image_slices": get_slices(num_image_slices),
**video_outputs,
"video_orders_in_mm_data": video_orders_in_mm_data,
"video_slices": get_slices(num_video_slices),
**audio_outputs,
"audio_orders_in_mm_data": audio_orders_in_mm_data,
"audio_slices": get_slices(num_audio_slices),
}
Expand Down Expand Up @@ -843,6 +895,12 @@ def get_replacement_minicpmv(item_idx: int, modality: str):
mm_items["video"].get_num_frames(item_idx)
)
else: # audio
if isinstance(mm_items["audio"], MiniCPMVAudioEmbeddingItems):
single_audio_embeds = mm_items["audio"].get(item_idx)
audio_len = self.info.get_audio_len_by_num_chunks(
sum(chunk_embeds.shape[0] for chunk_embeds in single_audio_embeds)
)
return self.get_audio_prompt_texts(audio_len)
return self.get_audio_prompt_texts(
len(mm_items["audio"].get(item_idx))
)
Expand Down Expand Up @@ -881,9 +939,9 @@ def get_slices(slices_indices: List[int]):
video_slices=MultiModalFieldConfig.batched("video"),
audio_features=MultiModalFieldConfig.flat("audio", audio_slices),
audio_feature_lens=MultiModalFieldConfig.flat("audio", audio_slices),
audio_num_segments=MultiModalFieldConfig.batched("audio"),
audio_slices=MultiModalFieldConfig.batched("audio"),
audio_orders_in_mm_data=MultiModalFieldConfig.batched("audio")
audio_orders_in_mm_data=MultiModalFieldConfig.batched("audio"),
audio_embeds=MultiModalFieldConfig.flat("audio", audio_slices)
)
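
Conceptually, the flat audio fields store every clip's data concatenated along the first dimension and rely on audio_slices to split it back per item, while batched fields keep one entry per clip. A schematic sketch with assumed shapes and slice boundaries (not the exact slice representation used internally):

    import torch

    audio_embeds_flat = torch.randn(150, 3584)       # 3 chunks x 50 tokens, assumed
    audio_slices = [slice(0, 100), slice(100, 150)]  # clip 0: 2 chunks, clip 1: 1 chunk
    per_clip_embeds = [audio_embeds_flat[s] for s in audio_slices]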

def apply(
Expand Down Expand Up @@ -990,13 +1048,21 @@ def get_embedding_with_vision(

def get_embedding_with_audios(
self,
input_ids: torch.Tensor,
vlm_embedding: torch.Tensor,
audio_inputs: Optional[MiniCPMVAudioInputs],
chunk_length: int
) -> torch.Tensor:
device, dtype = vlm_embedding.device, vlm_embedding.dtype
audio_embeddings = self.get_audio_hidden_states(audio_inputs, chunk_length)[0]
if audio_inputs["type"] == "audio_embeds":
audio_embeddings = audio_inputs["data"]
audio_embeddings = [
audio_embeddings[i].to(device=device, dtype=dtype)
for i in range(len(audio_embeddings))
]
else:
audio_embeddings = self.get_audio_hidden_states(audio_inputs, chunk_length)[0]
if audio_embeddings is None or len(audio_embeddings) == 0:
return vlm_embedding
audio_bounds = audio_inputs["audio_bounds"]
if self.config.chunk_input:
audio_embs = torch.cat(audio_embeddings, dim=0).to(
Expand Down Expand Up @@ -1182,15 +1248,26 @@ def _parse_and_validate_audio_inputs(
) -> Tuple[MiniCPMVAudioInputs]:
audio_features = kwargs.pop("audio_features", [])
audio_feature_lens = kwargs.pop("audio_feature_lens", [])
audio_embeds = kwargs.pop("audio_embeds", None)
audio_start_id = kwargs.pop("audio_start_id", None)
audio_end_id = kwargs.pop("audio_end_id", None)
if audio_embeds is not None:
audio_embeds = [audio_embeds[i][j]
for i in range(len(audio_embeds))
for j in range(len(audio_embeds[i]))]
return MiniCPMVAudioEmbeddingInputs(
audio_bounds=self._get_audio_bounds(input_ids,
audio_start_id, audio_end_id),
data=audio_embeds,
type="audio_embeds"
)
if len(audio_features) > 0:
audio_features = torch.cat([
item for item in audio_features
])
audio_feature_lens = torch.cat([
item for item in audio_feature_lens
])
audio_start_id = kwargs.pop("audio_start_id", None)
audio_end_id = kwargs.pop("audio_end_id", None)

return MiniCPMVAudioFeatureInputs(
audio_bounds=self._get_audio_bounds(input_ids,
Expand Down Expand Up @@ -1236,7 +1313,6 @@ def forward(

if audio_inputs is not None:
vlm_embeddings = self.get_embedding_with_audios(
input_ids,
vlm_embeddings,
audio_inputs,
self.config.audio_chunk_length
