[Bugfix] Fix qwen3-omni audio truncation issue (#26815)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
Isotr0py
2025-10-15 13:38:36 +08:00
committed by GitHub
parent 7cfa420f49
commit 8c851f6d04

View File

@ -30,7 +30,9 @@ import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from packaging.version import Version
from transformers import PretrainedConfig
from transformers import __version__ as TRANSFORMERS_VERSION
from transformers.feature_extraction_utils import BatchFeature
from transformers.models.qwen3_omni_moe.configuration_qwen3_omni_moe import (
Qwen3OmniMoeConfig,
@ -711,11 +713,12 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(
return x
# NOTE: WhisperFeatureExtractor cannot handle an empty list of audios
feature_extractor = self.info.get_feature_extractor()
hop_length = feature_extractor.hop_length
if audios:
# NOTE: The Qwen3-Omni processor accepts "audio".
# To make sure the cache works with padding=True, we pre-pad
# the audio to a multiple of hop_length.
hop_length = self.info.get_feature_extractor().hop_length
mm_data["audio"] = [
pad_to_hop_length(audio, hop_length)
if isinstance(audio, np.ndarray)
@ -725,6 +728,14 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(
mm_kwargs = dict(
**mm_kwargs,
)
# TODO(Isotr0py): Remove this patch once the upstream fix PR is
# released and the Transformers version is updated:
# https://github.com/huggingface/transformers/pull/41473
if (
Version(TRANSFORMERS_VERSION) < Version("4.58.0")
and "truncation" not in mm_kwargs
):
mm_kwargs["truncation"] = False
hf_inputs = super()._call_hf_processor(
prompt=prompt,
@ -738,7 +749,6 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(
and "feature_attention_mask" in hf_inputs
and (audios := mm_data.get("audio", []))
):
hop_length = self.info.get_feature_extractor().hop_length
audio_num_frames = []
for _, audio in enumerate(audios):
audio_length = len(audio[0]) if isinstance(audio, tuple) else len(audio)
@ -747,6 +757,10 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(
if audio_length % hop_length == 0
else (audio_length // hop_length - 1)
)
if mm_kwargs.get("truncation", False):
num_frame = min(
num_frame, feature_extractor.n_samples // hop_length
)
audio_num_frames.append(num_frame)
hf_inputs["feature_attention_mask"] = [
torch.ones(num_frame) for num_frame in audio_num_frames