Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
[Model] Extend Ultravox to accept audio longer than 30s (#13631)
Signed-off-by: Farzad Abdolhosseini <farzad@fixie.ai>
committed by GitHub
parent 4a42b9f5d6
commit 80e78d02ac
@@ -15,7 +15,7 @@ from ....conftest import HfRunner, VllmRunner
 from ....utils import RemoteOpenAIServer
 from ...utils import check_logprobs_close

-MODEL_NAME = "fixie-ai/ultravox-v0_4"
+MODEL_NAME = "fixie-ai/ultravox-v0_5-llama-3_2-1b"

 AudioTuple = tuple[np.ndarray, int]

@@ -1,6 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0

+import copy
 from functools import partial
+from typing import Optional

 import numpy as np
 import pytest
@@ -21,6 +23,7 @@ def _test_processing_correctness(
     hit_rate: float,
     num_batches: int,
     simplify_rate: float,
+    ignore_mm_keys: Optional[list[str]] = None,
 ):
     model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
     model_info.check_available_online(on_fail="skip")
@@ -123,8 +126,10 @@ def _test_processing_correctness(
         hf_processor_mm_kwargs={},
     )

-    assert baseline_result == cached_result, (
-        f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")
+    assert _drop_mm_kwargs_keys(
+        baseline_result, ignore_mm_keys) == _drop_mm_kwargs_keys(
+            cached_result, ignore_mm_keys), (
+                f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")

     baseline_tokenized_result = baseline_processor.apply(
         tokenizer.encode(prompt, **tokenizer_encode_kwargs),
@@ -132,8 +137,10 @@ def _test_processing_correctness(
         hf_processor_mm_kwargs={},
     )

-    assert baseline_result == baseline_tokenized_result, (
-        f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")
+    assert _drop_mm_kwargs_keys(
+        baseline_result, ignore_mm_keys) == _drop_mm_kwargs_keys(
+            baseline_tokenized_result, ignore_mm_keys), (
+                f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")

     cached_tokenized_result = cached_processor.apply(
         tokenizer.encode(prompt, **tokenizer_encode_kwargs),
@@ -141,8 +148,10 @@ def _test_processing_correctness(
         hf_processor_mm_kwargs={},
     )

-    assert cached_result == cached_tokenized_result, (
-        f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")
+    assert _drop_mm_kwargs_keys(
+        cached_result, ignore_mm_keys) == _drop_mm_kwargs_keys(
+            cached_tokenized_result, ignore_mm_keys), (
+                f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")


 # yapf: disable
@@ -173,7 +182,7 @@ def _test_processing_correctness(
     "Qwen/Qwen2-VL-2B-Instruct",
     "Qwen/Qwen2.5-VL-3B-Instruct",
     "Qwen/Qwen2-Audio-7B-Instruct",
-    "fixie-ai/ultravox-v0_4",
+    "fixie-ai/ultravox-v0_5-llama-3_2-1b",
     "openai/whisper-large-v3",
     "google/paligemma-3b-mix-224",
     "google/paligemma2-3b-ft-docci-448",
@@ -188,11 +197,19 @@ def test_processing_correctness(
     num_batches: int,
     simplify_rate: float,
 ):
+    ignore_mm_keys = None
+    if 'ultravox' in model_id:
+        # In Ultravox, the audio_features can be different depending on padding
+        # The slight difference should not be a problem though, since
+        # attention_mask lets us ignore the difference.
+        ignore_mm_keys = ['audio_features']
+
     _test_processing_correctness(
         model_id,
         hit_rate=hit_rate,
         num_batches=num_batches,
         simplify_rate=simplify_rate,
+        ignore_mm_keys=ignore_mm_keys,
     )


@@ -221,3 +238,29 @@ def test_processing_correctness_phi3v(
         num_batches=num_batches,
         simplify_rate=simplify_rate,
     )
+
+
+def _drop_mm_kwargs_keys(result: dict,
+                         ignore_mm_keys: Optional[list[str]] = None) -> dict:
+    """Drop specified keys from result['mm_kwargs'].
+
+    This is mainly to avoid doing exact match of audio_features in ultravox.
+
+    Args:
+        result: Result to drop keys from
+        ignore_mm_keys: List of keys to ignore, e.g. ['audio_features']
+    """
+    if not ignore_mm_keys:
+        return result
+
+    if 'mm_kwargs' in result:
+        result = copy.deepcopy(result)
+        mm_kwargs = result['mm_kwargs']
+        for key in ignore_mm_keys:
+            mm_kwargs.pop(key, None)
+        for items in mm_kwargs._items_by_modality.values():
+            for item in items:
+                for key in ignore_mm_keys:
+                    item.pop(key, None)
+
+    return result
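A minimal, self-contained sketch of the key-dropping idea behind `_drop_mm_kwargs_keys`, assuming a plain dict in place of vLLM's `MultiModalKwargs` (the real object also nests items per modality, which the helper above handles separately):

import copy
from typing import Optional


def drop_mm_kwargs_keys(result: dict,
                        ignore_mm_keys: Optional[list[str]] = None) -> dict:
    """Return a copy of `result` with the given keys removed from mm_kwargs."""
    if not ignore_mm_keys:
        return result
    result = copy.deepcopy(result)
    for key in ignore_mm_keys:
        result.get("mm_kwargs", {}).pop(key, None)
    return result


# Two results that differ only in a padding-sensitive field compare equal
# once that field is ignored.
a = {"mm_kwargs": {"audio_features": [0.0, 0.1], "audio_token_len": [4]}}
b = {"mm_kwargs": {"audio_features": [0.0, 0.0], "audio_token_len": [4]}}
assert (drop_mm_kwargs_keys(a, ["audio_features"])
        == drop_mm_kwargs_keys(b, ["audio_features"]))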
@@ -284,8 +284,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
     "Qwen2VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-VL-2B-Instruct"),  # noqa: E501
     "Qwen2_5_VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-VL-3B-Instruct",  # noqa: E501
                                                           min_transformers_version="4.49"),  # noqa: E501
-    "UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_4",
-                                     extras={"v0.5": "fixie-ai/ultravox-v0_5-llama-3_2-1b"},  # noqa: E501
+    "UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_5-llama-3_2-1b",  # noqa: E501
                                      trust_remote_code=True),
     # [Encoder-decoder]
     # Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer
@@ -5,7 +5,7 @@
 import math
 from collections.abc import Iterable, Mapping, Sequence
 from functools import cached_property
-from typing import Any, Literal, Optional, Set, Tuple, TypedDict, Union
+from typing import Any, List, Literal, Optional, Set, Tuple, TypedDict, Union

 import torch
 import torch.utils.checkpoint
@@ -44,12 +44,23 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
 _AUDIO_PLACEHOLDER_OVERRIDE = "<|reserved_special_token_0|>"
+_AUDIO_PLACEHOLDER_TOKEN = 128002
 _AUDIO_TOKENS_PER_SECOND = 6.25
+_MAX_ENCODER_BATCH_SIZE = 16


 class UltravoxAudioFeatureInputs(TypedDict):
     type: Literal["audio_features"]
     data: NestedTensors
-    """Shape: `(batch_size, num_audios, 80, M)`"""
+    """Shape: `(batch_size, num_chunks, 80, M)`"""
+    lens: NestedTensors
+    """
+    Length of the audio frames. Used for attention mask in WhisperEncoder.
+    Shape: `(batch_size, num_chunks)`
+    """
+    token_len: NestedTensors
+    """
+    Length of the audio tokens. Used for flattening the audio features.
+    Shape: `(batch_size, num_chunks)`
+    """


 class UltravoxAudioEmbeddingInputs(TypedDict):
@@ -78,6 +89,7 @@ class UltravoxProcessingInfo(BaseProcessingInfo):
         # token, thus we override placeholder with a reserved special
         # token.
         hf_processor.audio_token_replacement = _AUDIO_PLACEHOLDER_OVERRIDE
+        hf_processor.audio_replacement_token_id = _AUDIO_PLACEHOLDER_TOKEN
         return hf_processor

     def get_feature_extractor(
@@ -104,7 +116,7 @@ class UltravoxProcessingInfo(BaseProcessingInfo):
         max_audio_tokens = math.ceil(feature_extractor.chunk_length *
                                      _AUDIO_TOKENS_PER_SECOND)

-        return {"audio": max_audio_tokens}
+        return {"audio": max_audio_tokens * _MAX_ENCODER_BATCH_SIZE}


 class UltravoxDummyInputsBuilder(BaseDummyInputsBuilder[UltravoxProcessingInfo]
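For intuition about the new token budget: Ultravox maps audio to tokens at `_AUDIO_TOKENS_PER_SECOND = 6.25`, so one encoder chunk is worth ceil(chunk_length * 6.25) placeholder tokens, and the budget is multiplied by the maximum number of chunks per item. A small sketch of that arithmetic, where the 30 s chunk length is an assumption taken from the standard Whisper feature extractor rather than from this diff:

import math

_AUDIO_TOKENS_PER_SECOND = 6.25
_MAX_ENCODER_BATCH_SIZE = 16
chunk_length_s = 30  # assumed WhisperFeatureExtractor chunk length in seconds

tokens_per_chunk = math.ceil(chunk_length_s * _AUDIO_TOKENS_PER_SECOND)
print(tokens_per_chunk)                            # 188 tokens per chunk
print(tokens_per_chunk * _MAX_ENCODER_BATCH_SIZE)  # 3008 tokens per audio item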
@@ -118,7 +130,8 @@ class UltravoxDummyInputsBuilder(BaseDummyInputsBuilder[UltravoxProcessingInfo]
         feature_extractor = self.info.get_feature_extractor()

         sampling_rate = feature_extractor.sampling_rate
-        audio_len = feature_extractor.chunk_length * sampling_rate
+        audio_len = (feature_extractor.chunk_length * sampling_rate *
+                     _MAX_ENCODER_BATCH_SIZE)
         num_audios = mm_counts.get("audio", 0)

         mm_data = {
@@ -160,41 +173,38 @@ class UltravoxMultiModalProcessor(
         mm_kwargs = dict(
             **mm_kwargs,
             sampling_rate=feature_extractor.sampling_rate,
+            include_audio_num_chunks=True,
         )

-        # Ultravox processor doesn't support multiple inputs,
-        # therefore we need to input text and audio one by one
-        audio_features, audio_token_len = [], []
-        shared_outputs = {}
-        for audio in audios:
-            # NOTE: Ultravox processor accepts "audio" instead of "audios"
-            item_processor_data = dict(**mm_data, audio=audio)
+        item_processor_data = dict(**mm_data, audios=audios)

-            item_outputs = super()._call_hf_processor(
-                prompt=prompt,
-                mm_data=item_processor_data,
-                mm_kwargs=mm_kwargs,
-            )
-
-            audio_features.append(item_outputs.pop("audio_values")[0])
-            audio_token_len.append(item_outputs.pop("audio_token_len").item())
-            shared_outputs = item_outputs
-
-        combined_outputs = dict(
-            **shared_outputs,
-            audio_features=audio_features,
-            audio_token_len=audio_token_len,
+        output = super()._call_hf_processor(
+            prompt=prompt,
+            mm_data=item_processor_data,
+            mm_kwargs=mm_kwargs,
         )
-        return BatchFeature(combined_outputs)
+        output['audio_features'] = output.pop('audio_values')
+
+        return output

     def _get_mm_fields_config(
         self,
         hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
+        num_chunks = hf_inputs.get('audio_num_chunks', torch.zeros(0))
         return dict(
-            audio_features=MultiModalFieldConfig.batched("audio"),
-            audio_token_len=MultiModalFieldConfig.batched("audio"),
+            # to handle longer than 30s audio, each audio might be split
+            # into multiple chunks as such, their batch dimension can be
+            # higher than the number of audio samples
+            audio_features=MultiModalFieldConfig.flat_from_sizes(
+                "audio", num_chunks),
+            audio_token_len=MultiModalFieldConfig.flat_from_sizes(
+                "audio", num_chunks),
+            audio_lens=MultiModalFieldConfig.flat_from_sizes(
+                "audio", num_chunks),
+            # num_chunks can convert audio_chunked to audio batch dimension
+            audio_num_chunks=MultiModalFieldConfig.batched("audio"),
             audio_embeds=MultiModalFieldConfig.batched("audio"),
         )
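After chunking, the chunk-level tensors have a leading dimension equal to the total number of chunks across all audios, and `audio_num_chunks` records how many chunks belong to each audio so that flat dimension can be split back per item. A rough sketch of that bookkeeping with plain tensors, using illustrative shapes rather than the model's real ones:

import torch

# Three audios split into 1, 3 and 2 chunks respectively.
num_chunks = torch.tensor([1, 3, 2])
total_chunks = int(num_chunks.sum())                  # 6 chunks in total
audio_features = torch.randn(total_chunks, 80, 300)   # chunk-level batch dim

# Splitting the flat chunk dimension back into per-audio groups is what
# flat_from_sizes("audio", num_chunks) expresses for the multimodal fields.
per_audio = torch.split(audio_features, num_chunks.tolist(), dim=0)
assert [t.shape[0] for t in per_audio] == [1, 3, 2]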
@@ -205,14 +215,23 @@ class UltravoxMultiModalProcessor(
         out_mm_kwargs: MultiModalKwargs,
     ) -> Sequence[PromptUpdate]:
         hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
-        tokenizer = self.info.get_tokenizer()
-        vocab = tokenizer.get_vocab()

-        replacement_id = vocab[
-            hf_processor.audio_token_replacement]  # type: ignore
+        replacement_id = hf_processor.audio_replacement_token_id  # type: ignore
+
+        # Each audio can be split into multiple chunks.
+        # chunks_start_idx[i] indicates the start index of the chunks
+        # belonging to the i-th audio.
+        num_chunks = out_mm_kwargs.get("audio_num_chunks", torch.zeros(0))
+        chunks_start_idx: torch.Tensor = torch.cumsum(num_chunks,
+                                                      dim=0,
+                                                      dtype=torch.int32)
+        chunks_start_idx = torch.cat(
+            [torch.tensor([0], dtype=torch.int32), chunks_start_idx])

         def get_replacement_ultravox(item_idx: int):
-            audio_token_len = out_mm_kwargs["audio_token_len"][item_idx]
+            start = chunks_start_idx[item_idx]
+            end = chunks_start_idx[item_idx + 1]
+            audio_token_len = out_mm_kwargs["audio_token_len"][start:end].sum()
             return [replacement_id] * int(audio_token_len)  # type: ignore

         return [
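The per-audio placeholder length is now the sum of `audio_token_len` over that audio's chunks, and the exclusive cumulative sum of `audio_num_chunks` gives each audio's chunk range. A small sketch of the indexing, assuming hypothetical chunk counts and token lengths:

import torch

# Hypothetical chunk counts for three audios and per-chunk token lengths.
num_chunks = torch.tensor([1, 3, 2], dtype=torch.int32)
audio_token_len = torch.tensor([188, 188, 188, 63, 188, 94])

# Exclusive cumulative sum gives each audio's chunk range [start, end).
starts = torch.cat([torch.tensor([0], dtype=torch.int32),
                    torch.cumsum(num_chunks, dim=0, dtype=torch.int32)])

def tokens_for_audio(item_idx: int) -> int:
    start, end = int(starts[item_idx]), int(starts[item_idx + 1])
    return int(audio_token_len[start:end].sum())

assert [tokens_for_audio(i) for i in range(3)] == [188, 439, 282]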
@@ -304,12 +323,49 @@ class ModifiedWhisperEncoder(WhisperEncoder):

     base_model_prefix = "model.encoder"

+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.config.is_decoder = False
+
+    @property
+    def max_context_length(self):
+        return (self.config.max_source_positions * self.conv1.stride[0] *
+                self.conv2.stride[0])
+
+    def get_attention_mask_by_audio_len(self,
+                                        audio_lens: Optional[torch.Tensor],
+                                        hidden_states: torch.Tensor):
+        """
+        Create attention mask based on audio lengths to mask out padding tokens
+        For each sample in batch:
+        - Convert raw audio length to feature length after convolutions
+        - Create bool mask: True for valid positions and False for padding
+        - Convert to attention mask format expected by transformer layers
+          (1.0 for positions to attend to, large negative for positions to ignore)
+        This masking ensures consistent behavior between training and inference
+        by preventing the model from attending to padding tokens in both cases
+        """
+        if audio_lens is None:
+            return None
+
+        audio_feature_len = self._get_feat_extract_output_lengths(audio_lens)
+        max_seq_len = hidden_states.shape[1]
+        attention_mask = torch.arange(max_seq_len,
+                                      device=hidden_states.device)[None, :].lt(
+                                          audio_feature_len.view(-1, 1))
+        attention_mask = self.get_extended_attention_mask(
+            attention_mask,
+            None,
+            dtype=hidden_states.dtype,
+        )
+        return attention_mask
+
     def forward(
         self,
-        input_features,
+        input_features: torch.Tensor,
+        audio_lens: Optional[torch.Tensor] = None,
     ):
-        expected_seq_length = (self.config.max_source_positions *
-                               self.conv1.stride[0] * self.conv2.stride[0])
+        expected_seq_length = self.max_context_length
         if input_features.shape[-1] > expected_seq_length:
             raise ValueError(
                 f"Whisper expects the mel input features to be of length "
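The new mask marks as valid only the positions below each sample's feature length, then converts that boolean mask into the additive form the encoder layers consume. A toy version of those two steps, assuming feature lengths are given directly instead of being derived via `_get_feat_extract_output_lengths`:

import torch

# Per-sample valid feature lengths after the conv downsampling (illustrative).
audio_feature_len = torch.tensor([5, 3])
max_seq_len = 6

# True for real positions, False for padding.
valid = torch.arange(max_seq_len)[None, :].lt(audio_feature_len.view(-1, 1))

# Additive form consumed by the encoder layers: 0 where attending, a large
# negative value where padded (the real code obtains this via
# get_extended_attention_mask).
additive = (~valid).to(torch.float32) * torch.finfo(torch.float32).min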
@@ -328,10 +384,13 @@ class ModifiedWhisperEncoder(WhisperEncoder):
                                                    p=self.dropout,
                                                    training=self.training)

+        attention_mask = self.get_attention_mask_by_audio_len(
+            audio_lens, hidden_states)
+
         for encoder_layer in self.layers:
             layer_outputs = encoder_layer(
                 hidden_states,
-                None,
+                attention_mask,
                 layer_head_mask=None,
             )

@@ -409,17 +468,34 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
         )

     def _audio_features_to_embeddings(
-            self, input_features: torch.Tensor) -> torch.Tensor:
-        audio_input = input_features.to(self.audio_tower.dtype)
-        audio_features = self.audio_tower(audio_input)
-        audio_features = audio_features.to(self.audio_tower.dtype)
-        audio_embeddings = self.multi_modal_projector(audio_features)
+            self, input_features: torch.Tensor,
+            audio_lens: torch.Tensor) -> torch.Tensor:
+        audio_features = input_features.to(self.audio_tower.dtype)
+        batch_size = audio_features.size(0)
+        audio_embeddings = []
+
+        # Process audio features in batches to keep memory usage predictable
+        for start in range(0, batch_size, _MAX_ENCODER_BATCH_SIZE):
+            end = min(start + _MAX_ENCODER_BATCH_SIZE, batch_size)
+            # Process through audio tower
+            batch_features = self.audio_tower(audio_features[start:end],
+                                              audio_lens[start:end])
+            batch_features = batch_features.to(self.audio_tower.dtype)
+
+            # Process through projector
+            batch_embeddings = self.multi_modal_projector(batch_features)
+            audio_embeddings.append(batch_embeddings)
+
+        # Concatenate results
+        audio_embeddings = torch.cat(audio_embeddings, dim=0)
         return audio_embeddings

     def _parse_and_validate_audio_input(
             self, **kwargs: object) -> Optional[UltravoxAudioInputs]:
         audio_features = kwargs.pop("audio_features", None)
         audio_embeds = kwargs.pop("audio_embeds", None)
+        audio_lens = kwargs.pop("audio_lens", None)
+        audio_token_len = kwargs.pop("audio_token_len", None)

         if audio_features is None and audio_embeds is None:
             return None
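The encoder now sees at most `_MAX_ENCODER_BATCH_SIZE` chunks per call and the per-slice outputs are concatenated afterwards. A generic sketch of that slicing pattern with a stand-in encoder (the function and the fake encoder are illustrative, not part of the diff):

import torch

_MAX_ENCODER_BATCH_SIZE = 16

def encode_in_slices(features: torch.Tensor, lens: torch.Tensor,
                     encoder) -> torch.Tensor:
    """Run `encoder` over at most _MAX_ENCODER_BATCH_SIZE chunks at a time."""
    outputs = []
    for start in range(0, features.size(0), _MAX_ENCODER_BATCH_SIZE):
        end = min(start + _MAX_ENCODER_BATCH_SIZE, features.size(0))
        outputs.append(encoder(features[start:end], lens[start:end]))
    return torch.cat(outputs, dim=0)

# Stand-in encoder: mean over the mel-bin dimension, keeping (chunks, frames).
fake_encoder = lambda feats, lens: feats.mean(dim=1)
out = encode_in_slices(torch.randn(40, 80, 300), torch.full((40,), 300),
                       fake_encoder)
assert out.shape == (40, 300)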
@@ -430,7 +506,9 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
                     f"Got type: {type(audio_features)}")

             return UltravoxAudioFeatureInputs(type="audio_features",
-                                              data=audio_features)
+                                              data=audio_features,
+                                              lens=audio_lens,
+                                              token_len=audio_token_len)

         if audio_embeds is not None:
             if not isinstance(audio_embeds, (torch.Tensor, list)):
@@ -447,34 +525,34 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
         if audio_input["type"] == "audio_embeds":
             return audio_input["data"]

-        audio_features = audio_input["data"]
-        if isinstance(audio_features, torch.Tensor):
-            # Combine the B and N dimensions for the encoder/projector
-            flattened = flatten_bn(audio_features)
-            flattened_embeddings = self._audio_features_to_embeddings(
-                flattened)
+        # Pad and concatenate audio features
+        # [[B1, 80, M1], [B2, 80, M2]] -> [B1+B2, 80, max(M1, M2)]
+        audio_features = pad_and_concat_to_dim3(audio_input["data"])

-            # Restore the original dimensions
-            embeddings = flattened_embeddings.unflatten(
-                0, audio_features.shape[:2])
-            return embeddings
+        if isinstance(audio_input['lens'], list):
+            # [B1, B2] -> [B1+B2]
+            audio_lens = torch.cat(audio_input['lens'])
+            audio_token_len = torch.cat(audio_input['token_len'])
+        else:
+            audio_lens = flatten_bn(audio_input['lens'])
+            audio_token_len = flatten_bn(audio_input['token_len'])

-        result = []
-        # TODO: Batch heterogeneous tensors through the encoder/projector
-        for audio_features_item in audio_features:
-            if isinstance(audio_features_item, torch.Tensor):
-                result.append(
-                    self._audio_features_to_embeddings(audio_features_item))
-            else:
-                embeddings = [
-                    # Add a batch dimension to embed it, then remove it.
-                    self._audio_features_to_embeddings(tensor.unsqueeze(0)
-                                                       ).squeeze(0)
-                    for tensor in audio_features_item
-                ]
-                result.append(embeddings)
+        embeddings = self._audio_features_to_embeddings(
+            audio_features, audio_lens)

-        return result
+        # We should flatten and concatenate embeddings based on token lengths
+        # For example, with token_len = [4, 2, 3], flattened_embeddings will be
+        # concat(embeddings[0][:4], embeddings[1][:2], embeddings[2][:3])
+
+        # Create a mask of valid indices based on token lengths
+        max_len = embeddings.shape[1]
+        indices = torch.arange(max_len, device=embeddings.device).expand(
+            embeddings.shape[0], -1)
+        mask = indices < audio_token_len[:, None]
+        # Apply mask and flatten
+        flattened_embeddings = embeddings[mask]
+
+        return flattened_embeddings

     def get_multimodal_embeddings(
         self, **kwargs
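Because every chunk's embeddings are padded to a common length, only the first `token_len[i]` rows of each chunk are real; the boolean mask keeps exactly those rows and indexing with it yields one flat sequence. A toy version of the masking step with token_len = [4, 2, 3], matching the comment in the diff:

import torch

embeddings = torch.arange(3 * 5 * 2, dtype=torch.float32).reshape(3, 5, 2)
audio_token_len = torch.tensor([4, 2, 3])

indices = torch.arange(embeddings.shape[1]).expand(embeddings.shape[0], -1)
mask = indices < audio_token_len[:, None]
flattened = embeddings[mask]

# Equivalent to concat(embeddings[0][:4], embeddings[1][:2], embeddings[2][:3])
expected = torch.cat([embeddings[0, :4], embeddings[1, :2], embeddings[2, :3]])
assert torch.equal(flattened, expected)   # shape: (4 + 2 + 3, 2) = (9, 2)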
@@ -521,7 +599,11 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
         with the `input_ids`.

         Args:
-            audio_features: A batch of audio inputs [B, N, 80, M].
+            audio_features: A batch of audio input chunks [B, N, 80, M].
+            audio_lens: Length of audio frames for each audio chunk [B].
+            audio_token_len: Length of audio tokens for each audio chunk [B'].
+                Note: batch dim is different from batch dim in audio chunks.
+
         """

         if intermediate_tensors is not None:
@@ -560,3 +642,31 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
         loader = AutoWeightsLoader(self,
                                    ignore_unexpected_prefixes=["audio_tower."])
         return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
+
+
+def pad_and_concat_to_dim3(
+    features: Union[torch.Tensor, List[torch.Tensor], List[List[torch.Tensor]]]
+) -> torch.Tensor:
+    """
+    Pad and concatenate a list of tensors.
+
+    output:
+        Tensor of shape [B, C, M] where M is the maximum length of the input
+        tensors, B is the sum of the batch sizes of the input tensors.
+        C must be the same for all input tensors.
+    """
+    if isinstance(features, torch.Tensor):
+        if features.ndim > 3:
+            # Flatten [B, N, 80, M] -> [B * N, 80, M]
+            features = flatten_bn(features)
+        return features
+
+    features = [pad_and_concat_to_dim3(f) for f in features]
+
+    max_len = max(f.shape[-1] for f in features)
+    # Ensure all features have dim=3
+    features = [f.view(-1, *f.shape[-2:]) for f in features]
+    # Pad and concatenate:
+    # [[B1, 80, M1], [B2, 80, M2]] -> [B1+B2, 80, max(M1, M2)]
+    features = [F.pad(f, (0, max_len - f.shape[-1])) for f in features]
+    return torch.cat(features)
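A trimmed-down usage sketch of the padding helper above, without the nested-list recursion, showing two batches with different mel lengths being right-padded to a common M and stacked (shapes are illustrative):

import torch
import torch.nn.functional as F

def pad_and_concat(features: list[torch.Tensor]) -> torch.Tensor:
    # [[B1, 80, M1], [B2, 80, M2]] -> [B1+B2, 80, max(M1, M2)]
    max_len = max(f.shape[-1] for f in features)
    features = [F.pad(f, (0, max_len - f.shape[-1])) for f in features]
    return torch.cat(features)

out = pad_and_concat([torch.randn(2, 80, 1500), torch.randn(3, 80, 3000)])
assert out.shape == (5, 80, 3000)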