[Model] Add Ultravox support for multiple audio chunks (#7963)

This commit is contained in:
Peter Salas
2024-09-03 21:38:21 -07:00
committed by GitHub
parent e16fa99a6a
commit 2be8ec6e71
3 changed files with 196 additions and 113 deletions

View File

@ -11,25 +11,33 @@ from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset
from vllm.utils import FlexibleArgumentParser
# Input audio and question
audio_and_sample_rate = AudioAsset("mary_had_lamb").audio_and_sample_rate
question = "What is recited in the audio?"
audio_assets = [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")]
question_per_audio_count = [
"What is recited in the audio?",
"What sport and what nursery rhyme are referenced?"
]
# Ultravox 0.3
def run_ultravox(question):
def run_ultravox(question, audio_count):
model_name = "fixie-ai/ultravox-v0_3"
tokenizer = AutoTokenizer.from_pretrained(model_name)
messages = [{
'role': 'user',
'content': f"<|reserved_special_token_0|>\n{question}"
'role':
'user',
'content':
"<|reserved_special_token_0|>\n" * audio_count + question
}]
prompt = tokenizer.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)
llm = LLM(model=model_name)
llm = LLM(model=model_name,
enforce_eager=True,
enable_chunked_prefill=False,
max_model_len=8192,
limit_mm_per_prompt={"audio": audio_count})
stop_token_ids = None
return llm, prompt, stop_token_ids
@ -44,7 +52,9 @@ def main(args):
if model not in model_example_map:
raise ValueError(f"Model type {model} is not supported.")
llm, prompt, stop_token_ids = model_example_map[model](question)
audio_count = args.num_audios
llm, prompt, stop_token_ids = model_example_map[model](
question_per_audio_count[audio_count - 1], audio_count)
# We set temperature to 0.2 so that outputs can be different
# even when all prompts are identical when running batch inference.
@ -53,23 +63,18 @@ def main(args):
stop_token_ids=stop_token_ids)
assert args.num_prompts > 0
if args.num_prompts == 1:
# Single inference
inputs = {
"prompt": prompt,
"multi_modal_data": {
"audio": audio_and_sample_rate
},
}
else:
inputs = {
"prompt": prompt,
"multi_modal_data": {
"audio": [
asset.audio_and_sample_rate
for asset in audio_assets[:audio_count]
]
},
}
if args.num_prompts > 1:
# Batch inference
inputs = [{
"prompt": prompt,
"multi_modal_data": {
"audio": audio_and_sample_rate
},
} for _ in range(args.num_prompts)]
inputs = [inputs] * args.num_prompts
outputs = llm.generate(inputs, sampling_params=sampling_params)
@ -92,6 +97,11 @@ if __name__ == "__main__":
type=int,
default=1,
help='Number of prompts to run.')
parser.add_argument("--num-audios",
type=int,
default=1,
choices=[1, 2],
help="Number of audio items per prompt.")
args = parser.parse_args()
main(args)

View File

@ -16,37 +16,32 @@ MODEL_NAME = "fixie-ai/ultravox-v0_3"
AudioTuple = Tuple[np.ndarray, int]
VLLM_PLACEHOLDER = "<|reserved_special_token_0|>"
HF_PLACEHOLDER = "<|audio|>"
@pytest.fixture(scope="session")
def audio_and_sample_rate():
def audio_assets():
from vllm.assets.audio import AudioAsset
return AudioAsset("mary_had_lamb").audio_and_sample_rate
return [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")]
@pytest.fixture
def prompts_and_audios(audio_and_sample_rate):
@pytest.fixture(scope="module", params=("mary_had_lamb", "winning_call"))
def audio(request):
from vllm.assets.audio import AudioAsset
return AudioAsset(request.param)
def _get_prompt(audio_count, question, placeholder):
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
placeholder = f"{placeholder}\n" * audio_count
vllm_placeholder = "<|reserved_special_token_0|>"
hf_placeholder = "<|audio|>"
question = "What's in the audio?"
vllm_prompt = tokenizer.apply_chat_template(
[{
'role': 'user',
'content': f"{vllm_placeholder}\n{question}"
}],
tokenize=False,
add_generation_prompt=True)
hf_prompt = tokenizer.apply_chat_template(
[{
'role': 'user',
'content': f"{hf_placeholder}\n{question}"
}],
tokenize=False,
add_generation_prompt=True)
return [(vllm_prompt, hf_prompt, audio_and_sample_rate)]
return tokenizer.apply_chat_template([{
'role': 'user',
'content': f"{placeholder}{question}"
}],
tokenize=False,
add_generation_prompt=True)
def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
@ -134,15 +129,71 @@ def run_test(
)
def run_multi_audio_test(
vllm_runner: Type[VllmRunner],
prompts_and_audios: List[Tuple[str, List[AudioTuple]]],
model: str,
*,
dtype: str,
max_tokens: int,
num_logprobs: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
with vllm_runner(model,
dtype=dtype,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True,
limit_mm_per_prompt={
"audio":
max((len(audio) for _, audio in prompts_and_audios))
}) as vllm_model:
vllm_outputs = vllm_model.generate_greedy_logprobs(
[prompt for prompt, _ in prompts_and_audios],
max_tokens,
num_logprobs=num_logprobs,
audios=[audios for _, audios in prompts_and_audios])
# The HuggingFace model doesn't support multiple audios yet, so
# just assert that some tokens were generated.
assert all(tokens for tokens, *_ in vllm_outputs)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(hf_runner, vllm_runner, prompts_and_audios, dtype: str,
max_tokens: int, num_logprobs: int) -> None:
def test_models(hf_runner, vllm_runner, audio, dtype: str, max_tokens: int,
num_logprobs: int) -> None:
vllm_prompt = _get_prompt(1, "Describe the audio above.", VLLM_PLACEHOLDER)
hf_prompt = _get_prompt(1, "Describe the audio above.", HF_PLACEHOLDER)
run_test(
hf_runner,
vllm_runner,
prompts_and_audios,
[(vllm_prompt, hf_prompt, audio.audio_and_sample_rate)],
MODEL_NAME,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models_with_multiple_audios(vllm_runner, audio_assets, dtype: str,
max_tokens: int,
num_logprobs: int) -> None:
vllm_prompt = _get_prompt(len(audio_assets),
"Describe each of the audios above.",
VLLM_PLACEHOLDER)
run_multi_audio_test(
vllm_runner,
[(vllm_prompt, [audio.audio_and_sample_rate
for audio in audio_assets])],
MODEL_NAME,
dtype=dtype,
max_tokens=max_tokens,

View File

@ -29,12 +29,12 @@ from vllm.model_executor.layers.quantization.base_config import (
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.model_executor.models.utils import (filter_weights,
from vllm.model_executor.models.utils import (filter_weights, flatten_bn,
init_vllm_registered_model,
merge_multimodal_embeddings)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs
from vllm.multimodal.base import MultiModalInputs, NestedTensors
from vllm.multimodal.utils import (cached_get_tokenizer,
repeat_and_pad_placeholder_tokens)
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
@ -48,13 +48,14 @@ logger = init_logger(__name__)
class UltravoxAudioFeatureInputs(TypedDict):
type: Literal["audio_features"]
data: Union[torch.Tensor, List[torch.Tensor]]
"""Shape: `(batch_size * num_audios, 80, M)"""
data: NestedTensors
"""Shape: `(batch_size, num_audios, 80, M)"""
class UltravoxAudioEmbeddingInputs(TypedDict):
type: Literal["audio_embeds"]
data: torch.Tensor
data: NestedTensors
"""Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)"""
UltravoxAudioInputs = Union[UltravoxAudioFeatureInputs,
@ -85,24 +86,33 @@ def dummy_data_for_ultravox(
audio_count = mm_counts["audio"]
audio_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [
_AUDIO_PLACEHOLDER_TOKEN
]) * get_ultravox_max_audio_tokens(ctx) * audio_count
audio_placeholder = array(
VLLM_TOKEN_ID_ARRAY_TYPE,
[_AUDIO_PLACEHOLDER_TOKEN]) * get_ultravox_max_audio_tokens(ctx)
# Add a separator between each chunk.
audio_token_ids = (audio_placeholder +
array(VLLM_TOKEN_ID_ARRAY_TYPE, [0])) * audio_count
other_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
[0]) * (seq_len - len(audio_token_ids))
audio_and_sr = (np.array([0.0] * feature_extractor.chunk_length), 1)
mm_dict = {
"audio":
audio_and_sr if audio_count == 1 else [audio_and_sr] * audio_count
}
mm_dict = {"audio": [audio_and_sr] * audio_count}
return (SequenceData(audio_token_ids + other_token_ids), mm_dict)
def input_mapper_for_ultravox(ctx: InputContext, data: object):
if isinstance(data, tuple):
(audio, sr) = cast(Tuple[np.ndarray, Union[float, int]], data)
if not isinstance(data, list):
data = [data]
audio_features = []
for audio_input in data:
if not isinstance(audio_input, tuple):
raise NotImplementedError(
f"Unsupported data type: {type(audio_input)}")
(audio, sr) = cast(Tuple[np.ndarray, Union[float, int]], audio_input)
feature_extractor = whisper_feature_extractor(ctx)
if sr != feature_extractor.sampling_rate:
@ -121,15 +131,14 @@ def input_mapper_for_ultravox(ctx: InputContext, data: object):
# Not enough audio; pad it.
audio = np.pad(audio, (0, minimum_audio_length - len(audio)))
return MultiModalInputs({
"audio_features":
feature_extractor(audio,
sampling_rate=sr,
padding="longest",
return_tensors="pt")["input_features"]
})
single_audio_features = feature_extractor(
audio, sampling_rate=sr, padding="longest",
return_tensors="pt")["input_features"]
raise NotImplementedError(f"Unsupported data type: {type(data)}")
# Remove the batch dimension because we're wrapping it in a list.
audio_features.append(single_audio_features.squeeze(0))
return MultiModalInputs({"audio_features": audio_features})
def input_processor_for_ultravox(ctx: InputContext, llm_inputs: LLMInputs):
@ -138,25 +147,31 @@ def input_processor_for_ultravox(ctx: InputContext, llm_inputs: LLMInputs):
return llm_inputs
feature_extractor = whisper_feature_extractor(ctx)
audio_data, sample_rate = multi_modal_data["audio"]
audios = multi_modal_data["audio"]
if not isinstance(audios, list):
audios = [audios]
audio_length = audio_data.shape[0]
if sample_rate != feature_extractor.sampling_rate:
# Account for resampling.
adjustment = feature_extractor.sampling_rate / sample_rate
audio_length = math.ceil(adjustment * audio_length)
audio_token_counts = []
for audio_data, sample_rate in audios:
audio_length = audio_data.shape[0]
if sample_rate != feature_extractor.sampling_rate:
# Account for resampling.
adjustment = feature_extractor.sampling_rate / sample_rate
audio_length = math.ceil(adjustment * audio_length)
feature_extractor_output_length = math.ceil(
(audio_length -
(feature_extractor.hop_length - 1)) / feature_extractor.hop_length)
feature_extractor_output_length = math.ceil(
(audio_length - (feature_extractor.hop_length - 1)) /
feature_extractor.hop_length)
uv_config = ctx.get_hf_config(UltravoxConfig)
audio_num_tokens = min(
max(
1,
math.ceil(feature_extractor_output_length /
(uv_config.stack_factor * 2))),
get_ultravox_max_audio_tokens(ctx))
audio_token_counts.append(audio_num_tokens)
uv_config = ctx.get_hf_config(UltravoxConfig)
audio_num_tokens = min(
max(
1,
math.ceil(feature_extractor_output_length /
(uv_config.stack_factor * 2))),
get_ultravox_max_audio_tokens(ctx))
tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer)
new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
@ -164,7 +179,7 @@ def input_processor_for_ultravox(ctx: InputContext, llm_inputs: LLMInputs):
llm_inputs.get("prompt"),
llm_inputs["prompt_token_ids"],
placeholder_token_id=_AUDIO_PLACEHOLDER_TOKEN,
repeat_count=audio_num_tokens,
repeat_count=audio_token_counts,
)
# NOTE: Create a defensive copy of the original inputs
@ -338,45 +353,52 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
raise ValueError("Incorrect type of audio features. "
f"Got type: {type(audio_features)}")
# Remove the N dimension until multiple audios are supported.
if isinstance(audio_features, torch.Tensor):
audio_features = audio_features.squeeze(1)
else:
audio_features = [t.squeeze(0) for t in audio_features]
return UltravoxAudioFeatureInputs(type="audio_features",
data=audio_features)
if audio_embeds is not None:
if not isinstance(audio_embeds, torch.Tensor):
if not isinstance(audio_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of audio embeds. "
f"Got type: {type(audio_embeds)}")
# Remove the N dimension until multiple audios are supported.
audio_embeds = audio_embeds.squeeze(1)
return UltravoxAudioEmbeddingInputs(type="audio_embeds",
data=audio_embeds)
raise AssertionError("This line should be unreachable.")
def _process_audio_input(
self, audio_input: UltravoxAudioInputs
) -> Union[torch.Tensor, List[torch.Tensor]]:
self, audio_input: UltravoxAudioInputs) -> NestedTensors:
if audio_input["type"] == "audio_embeds":
return audio_input["data"]
audio_features = audio_input["data"]
if isinstance(audio_features, list):
# TODO: Batch these through the encoder/projector instead of
# serializing them.
return [
self._audio_features_to_embeddings(
features.unsqueeze(0)).squeeze(0)
for features in audio_features
]
else:
return self._audio_features_to_embeddings(audio_features)
if isinstance(audio_features, torch.Tensor):
# Combine the B and N dimensions for the encoder/projector
flattened = flatten_bn(audio_features)
flattened_embeddings = self._audio_features_to_embeddings(
flattened)
# Restore the original dimensions
embeddings = flattened_embeddings.unflatten(
0, audio_features.shape[:2])
return embeddings
result = []
# TODO: Batch heterogeneous tensors through the encoder/projector
for audio_features_item in audio_features:
if isinstance(audio_features_item, torch.Tensor):
result.append(
self._audio_features_to_embeddings(audio_features_item))
else:
embeddings = [
# Add a batch dimension to embed it, then remove it.
self._audio_features_to_embeddings(tensor.unsqueeze(0)
).squeeze(0)
for tensor in audio_features_item
]
result.append(embeddings)
return result
def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
@ -393,7 +415,7 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
with the `input_ids`.
Args:
input_features: A batch of audio inputs, [1, 80, M].
audio_features: A batch of audio inputs [B, N, 80, M].
"""
audio_input = self._parse_and_validate_audio_input(**kwargs)
if audio_input is not None: