mirror of https://github.com/vllm-project/vllm.git, synced 2025-10-20 14:53:52 +08:00
Add Qwen3-Omni moe thinker (#25550)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Roger Wang <hey@rogerw.io>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Xiong Wang <feizi.wx@alibaba-inc.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
@@ -714,6 +714,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `Qwen2_5OmniThinkerForConditionalGeneration` | Qwen2.5-Omni | T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>+</sup> | `Qwen/Qwen2.5-Omni-3B`, `Qwen/Qwen2.5-Omni-7B` | ✅︎ | ✅︎ | ✅︎ |
| `Qwen3VLForConditionalGeneration` | Qwen3-VL | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/Qwen3-VL-4B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Qwen3VLMoeForConditionalGeneration` | Qwen3-VL-MOE | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/Qwen3-VL-30B-A3B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Qwen3OmniMoeThinkerForConditionalGeneration` | Qwen3-Omni | T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>+</sup> | `Qwen/Qwen3-Omni-30B-A3B-Instruct`, `Qwen/Qwen3-Omni-30B-A3B-Thinking` | ✅︎ | ✅︎ | ✅︎ |
| `RForConditionalGeneration` | R-VL-4B | T + I<sup>E+</sup> | `YannQi/R-4B` | | ✅︎ | ✅︎ |
| `SkyworkR1VChatModel` | Skywork-R1V-38B | T + I | `Skywork/Skywork-R1V-38B` | | ✅︎ | ✅︎ |
| `SmolVLMForConditionalGeneration` | SmolVLM2 | T + I | `SmolVLM2-2.2B-Instruct` | ✅︎ | | ✅︎ |
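As orientation for the new table row, here is a minimal, hedged offline-inference sketch using vLLM's public `LLM.generate` multimodal API. The checkpoint name comes from the table above; the prompt string and image file are illustrative placeholders, and in practice the chat template shipped with the Hugging Face processor should be used to build the prompt.

```python
from PIL import Image

from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen3-Omni-30B-A3B-Instruct",
    max_model_len=4096,                # small context, just for a smoke test
    limit_mm_per_prompt={"image": 1},  # one image per request in this sketch
)

outputs = llm.generate(
    {
        # Illustrative prompt only; build the real one with the model's chat template.
        "prompt": "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|>\n<|im_start|>assistant\n",
        "multi_modal_data": {"image": Image.open("example.jpg")},
    },
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)
```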
@@ -804,8 +805,7 @@ Some models are supported only via the [Transformers backend](#transformers). Th

Our PaliGemma implementations have the same problem as Gemma 3 (see above) for both V0 and V1.

!!! note
-    For Qwen2.5-Omni, reading audio from video pre-processing (`--mm-processor-kwargs '{"use_audio_in_video": true}'`)
-    is currently supported on V0 (but not V1), because overlapping modalities is not yet supported in V1.
+    For Qwen2.5-Omni and Qwen3-Omni, reading audio from video pre-processing (`--mm-processor-kwargs '{"use_audio_in_video": true}'`) is currently work in progress and not yet supported.
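For context, a hedged sketch of the offline equivalent of the CLI flag mentioned in the note: the `mm_processor_kwargs` argument itself is standard vLLM plumbing, but whether `use_audio_in_video` actually takes effect for the Omni models is exactly the work-in-progress item described above.

```python
from vllm import LLM

# Offline equivalent of --mm-processor-kwargs '{"use_audio_in_video": true}'.
# Treat this purely as an illustration of how processor kwargs are forwarded;
# the flag is not yet supported for Qwen2.5-Omni / Qwen3-Omni per the note.
llm = LLM(
    model="Qwen/Qwen2.5-Omni-3B",
    mm_processor_kwargs={"use_audio_in_video": True},
)
```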
#### Transcription
@@ -384,6 +384,7 @@ def _test_processing_correctness_one(
        "Qwen/Qwen2.5-Omni-3B",
        "Qwen/Qwen3-VL-4B-Instruct",
        "Qwen/Qwen3-VL-30B-A3B-Instruct",
        "Qwen/Qwen3-Omni-30B-A3B-Instruct",
        "YannQi/R-4B",
        "Skywork/Skywork-R1V-38B",
        "HuggingFaceTB/SmolVLM2-2.2B-Instruct",
@@ -773,6 +773,11 @@ _MULTIMODAL_EXAMPLE_MODELS = {
        min_transformers_version="4.57",
        is_available_online=False,
    ),
    "Qwen3OmniMoeForConditionalGeneration": _HfExamplesInfo(
        "Qwen/Qwen3-Omni-30B-A3B-Instruct",
        max_model_len=4096,
        min_transformers_version="4.57",
    ),
    "RForConditionalGeneration": _HfExamplesInfo("YannQi/R-4B", trust_remote_code=True),
    "SkyworkR1VChatModel": _HfExamplesInfo(
        "Skywork/Skywork-R1V-38B", trust_remote_code=True
@@ -971,17 +971,9 @@ class MRotaryEmbedding(RotaryEmbedding):
            llm_pos_ids_list.append(
                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
            )

            t_index = (
-                (
-                    torch.arange(llm_grid_t)
-                    .view(-1, 1)
-                    .expand(-1, llm_grid_h * llm_grid_w)
-                )
-                .long()
-                .flatten()
-            )
+                torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w)
+            ).flatten()
            h_index = (
                torch.arange(llm_grid_h)
                .view(1, -1, 1)
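To make the refactored index construction above easier to follow, here is a small, self-contained sketch for a toy vision grid. The variable names mirror the diff, but the snippet is illustrative only; the real code additionally offsets these rows by `st_idx`.

```python
import torch

# Toy grid: 2 temporal steps, 2 rows, 3 columns of vision tokens.
llm_grid_t, llm_grid_h, llm_grid_w = 2, 2, 3

t_index = (
    torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w)
).flatten()
h_index = (
    torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w)
).flatten()
w_index = (
    torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1)
).flatten()

# M-RoPE stacks the three rows into a (3, t*h*w) position tensor.
positions = torch.stack([t_index, h_index, w_index])
print(positions)
# tensor([[0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1],
#         [0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1],
#         [0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2]])
```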
@@ -1042,7 +1034,6 @@ class MRotaryEmbedding(RotaryEmbedding):

        st = 0
        remain_images, remain_videos = image_nums, video_nums

        image_index, video_index = 0, 0
        for _ in range(image_nums + video_nums):
            video_second_per_grid_t = 0.0
@@ -1093,19 +1084,11 @@ class MRotaryEmbedding(RotaryEmbedding):
                llm_pos_ids_list.append(
                    torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
                )

                t_index = (
-                    (
-                        torch.arange(llm_grid_t)
-                        .view(-1, 1)
-                        .expand(-1, llm_grid_h * llm_grid_w)
-                        * video_second_per_grid_t
-                        * tokens_per_second
-                    )
-                    .long()
-                    .flatten()
-                )
+                    torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w)
+                    * video_second_per_grid_t
+                    * tokens_per_second
+                ).flatten()
                h_index = (
                    torch.arange(llm_grid_h)
                    .view(1, -1, 1)
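A hedged, standalone illustration of the temporal scaling in the hunk above: with 25 position ids per second and one temporal grid step every 2 seconds of video, consecutive video frames are spaced 50 temporal position ids apart. The numbers are illustrative only.

```python
import torch

llm_grid_t, llm_grid_h, llm_grid_w = 3, 1, 1
video_second_per_grid_t = 2.0   # seconds of video per temporal grid step
tokens_per_second = 25          # position ids per second

t_index = (
    torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w)
    * video_second_per_grid_t
    * tokens_per_second
).flatten()
print(t_index)  # tensor([  0.,  50., 100.])
```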
@@ -1136,6 +1119,339 @@ class MRotaryEmbedding(RotaryEmbedding):

        return llm_positions, mrope_position_delta

    @classmethod
    def _omni3_get_input_positions_tensor(
        cls,
        config,
        input_ids: torch.Tensor,
        image_grid_thw: torch.Tensor,
        video_grid_thw: torch.Tensor,
        use_audio_in_video: bool = False,
        audio_seqlens: Optional[torch.Tensor] = None,
        second_per_grids: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        def _get_feat_extract_output_lengths(input_lengths: torch.LongTensor):
            input_lengths_leave = input_lengths % 100
            feat_lengths = (input_lengths_leave - 1) // 2 + 1
            output_lengths = (
                ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
            )
            return output_lengths

        if input_ids is None or input_ids.ndim != 1:
            raise ValueError("_omni3_get_input_positions_tensor expects 1D input_ids")

        seq_len = input_ids.shape[0]
        device = input_ids.device
        dtype = input_ids.dtype

        if image_grid_thw is not None:
            image_grid_thw = image_grid_thw.to(device=device, dtype=torch.long)
        if video_grid_thw is not None:
            video_grid_thw = video_grid_thw.to(device=device, dtype=torch.long)

        if second_per_grids is None:
            if video_grid_thw is not None and video_grid_thw.numel() > 0:
                second_per_grids = torch.ones(
                    video_grid_thw.shape[0], dtype=torch.float32, device=device
                )
            else:
                second_per_grids = torch.tensor([], dtype=torch.float32, device=device)
        else:
            second_per_grids = second_per_grids.to(device=device, dtype=torch.float32)

        if audio_seqlens is not None:
            audio_seqlens = audio_seqlens.to(device=device, dtype=torch.long)

        spatial_merge_size = config.vision_config.spatial_merge_size
        image_token_id = config.image_token_id
        video_token_id = config.video_token_id
        audio_token_id = config.audio_token_id
        vision_start_token_id = config.vision_start_token_id
        audio_start_token_id = config.audio_start_token_id
        position_id_per_seconds = config.position_id_per_seconds

        vision_start_indices = torch.argwhere(
            input_ids == vision_start_token_id
        ).squeeze(1)
        if vision_start_indices.numel() > 0:
            vision_tokens = input_ids[vision_start_indices + 1]
        else:
            vision_tokens = input_ids.new_empty((0,), dtype=input_ids.dtype)
        audio_nums = torch.sum(input_ids == audio_start_token_id)
        image_nums = (vision_tokens == image_token_id).sum()
        video_nums = (
            (vision_tokens == audio_start_token_id).sum()
            if use_audio_in_video
            else (vision_tokens == video_token_id).sum()
        )

        input_tokens = input_ids.tolist()
        llm_pos_ids_list: list[torch.Tensor] = []
        st = 0
        image_idx = 0
        video_idx = 0
        audio_idx = 0
        remain_images, remain_videos, remain_audios = image_nums, video_nums, audio_nums  # noqa: E501
        multimodal_nums = (
            image_nums + audio_nums
            if use_audio_in_video
            else image_nums + video_nums + audio_nums
        )  # noqa: E501

        for _ in range(multimodal_nums):
            st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
            if (image_token_id in input_tokens or video_token_id in input_tokens) and (
                remain_videos > 0 or remain_images > 0
            ):
                ed_vision_start = input_tokens.index(vision_start_token_id, st)
            else:
                ed_vision_start = len(input_tokens) + 1
            if audio_token_id in input_tokens and remain_audios > 0:
                ed_audio_start = input_tokens.index(audio_start_token_id, st)
            else:
                ed_audio_start = len(input_tokens) + 1
            min_ed = min(ed_vision_start, ed_audio_start)

            if min_ed == ed_audio_start:
                text_len = min_ed - st
                if text_len != 0:
                    st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
                    llm_pos_ids_list.append(
                        torch.arange(text_len, device=device, dtype=torch.long)
                        .view(1, -1)
                        .expand(3, -1)
                        + st_idx
                    )
                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
                bos_len = 1
                llm_pos_ids_list.append(
                    torch.arange(bos_len, device=device, dtype=torch.long)
                    .view(1, -1)
                    .expand(3, -1)
                    + st_idx
                )
                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
                audio_len = _get_feat_extract_output_lengths(audio_seqlens[audio_idx])
                llm_pos_ids = (
                    torch.arange(audio_len, device=device, dtype=torch.long)
                    .view(1, -1)
                    .expand(3, -1)
                    + st_idx
                )
                llm_pos_ids_list.append(llm_pos_ids)
                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
                eos_len = 1
                llm_pos_ids_list.append(
                    torch.arange(eos_len, device=device, dtype=torch.long)
                    .view(1, -1)
                    .expand(3, -1)
                    + st_idx
                )
                st += text_len + bos_len + audio_len + eos_len
                audio_idx += 1
                remain_audios -= 1
            elif (
                min_ed == ed_vision_start
                and input_ids[ed_vision_start + 1] == image_token_id
            ):
                text_len = min_ed - st
                if text_len != 0:
                    st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
                    llm_pos_ids_list.append(
                        torch.arange(text_len, device=device, dtype=torch.long)
                        .view(1, -1)
                        .expand(3, -1)
                        + st_idx
                    )
                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
                bos_len = 1
                llm_pos_ids_list.append(
                    torch.arange(bos_len, device=device, dtype=torch.long)
                    .view(1, -1)
                    .expand(3, -1)
                    + st_idx
                )
                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
                grid_t = image_grid_thw[image_idx][0]
                grid_hs = image_grid_thw[:, 1]
                grid_ws = image_grid_thw[:, 2]
                t_index = torch.arange(grid_t, device=device) * position_id_per_seconds
                llm_pos_ids = cls._get_llm_pos_ids_for_vision(
                    st_idx, image_idx, spatial_merge_size, t_index, grid_hs, grid_ws
                )
                image_len = image_grid_thw[image_idx].prod() // (spatial_merge_size**2)
                llm_pos_ids_list.append(llm_pos_ids)
                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
                eos_len = 1
                llm_pos_ids_list.append(
                    torch.arange(eos_len, device=device, dtype=torch.long)
                    .view(1, -1)
                    .expand(3, -1)
                    + st_idx
                )
                st += text_len + bos_len + image_len + eos_len
                image_idx += 1
                remain_images -= 1
            elif (
                min_ed == ed_vision_start
                and input_ids[ed_vision_start + 1] == video_token_id
                and not use_audio_in_video
            ):
                text_len = min_ed - st
                if text_len != 0:
                    st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
                    llm_pos_ids_list.append(
                        torch.arange(text_len, device=device, dtype=torch.long)
                        .view(1, -1)
                        .expand(3, -1)
                        + st_idx
                    )
                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
                bos_len = 1
                llm_pos_ids_list.append(
                    torch.arange(bos_len, device=device, dtype=torch.long)
                    .view(1, -1)
                    .expand(3, -1)
                    + st_idx
                )
                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
                grid_t = video_grid_thw[video_idx][0]
                grid_hs = video_grid_thw[:, 1]
                grid_ws = video_grid_thw[:, 2]
                t_index = (
                    torch.arange(grid_t, device=device)
                    * float(second_per_grids[video_idx].item())
                    * position_id_per_seconds
                )
                llm_pos_ids = cls._get_llm_pos_ids_for_vision(
                    st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws
                )
                video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2)
                llm_pos_ids_list.append(llm_pos_ids)
                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
                eos_len = 1
                llm_pos_ids_list.append(
                    torch.arange(eos_len, device=device, dtype=torch.long)
                    .view(1, -1)
                    .expand(3, -1)
                    + st_idx
                )
                st += text_len + bos_len + video_len + eos_len
                video_idx += 1
                remain_videos -= 1
            elif (
                min_ed == ed_vision_start
                and ed_vision_start + 1 == ed_audio_start
                and use_audio_in_video
            ):
                text_len = min_ed - st
                if text_len != 0:
                    st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
                    llm_pos_ids_list.append(
                        torch.arange(text_len, device=device, dtype=torch.long)
                        .view(1, -1)
                        .expand(3, -1)
                        + st_idx
                    )
                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
                bos_len = 1
                bos_block = (
                    torch.arange(bos_len, device=device, dtype=torch.long)
                    .view(1, -1)
                    .expand(3, -1)
                    + st_idx
                )
                llm_pos_ids_list.append(bos_block)
                llm_pos_ids_list.append(bos_block)
                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
                audio_len = _get_feat_extract_output_lengths(audio_seqlens[audio_idx])
                audio_llm_pos_ids = (
                    torch.arange(audio_len, device=device, dtype=torch.long)
                    .view(1, -1)
                    .expand(3, -1)
                    + st_idx
                )
                grid_t = video_grid_thw[video_idx][0]
                grid_hs = video_grid_thw[:, 1]
                grid_ws = video_grid_thw[:, 2]
                t_index = (
                    torch.arange(grid_t, device=device)
                    * float(second_per_grids[video_idx].item())
                    * position_id_per_seconds
                )
                video_llm_pos_ids = cls._get_llm_pos_ids_for_vision(
                    st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws
                )
                video_data_index, audio_data_index = 0, 0
                while (
                    video_data_index < video_llm_pos_ids.shape[-1]
                    and audio_data_index < audio_llm_pos_ids.shape[-1]
                ):
                    if (
                        video_llm_pos_ids[0][video_data_index]
                        <= audio_llm_pos_ids[0][audio_data_index]
                    ):
                        llm_pos_ids_list.append(
                            video_llm_pos_ids[
                                :, video_data_index : video_data_index + 1
                            ]
                        )
                        video_data_index += 1
                    else:
                        llm_pos_ids_list.append(
                            audio_llm_pos_ids[
                                :, audio_data_index : audio_data_index + 1
                            ]
                        )
                        audio_data_index += 1
                if video_data_index < video_llm_pos_ids.shape[-1]:
                    llm_pos_ids_list.append(
                        video_llm_pos_ids[
                            :, video_data_index : video_llm_pos_ids.shape[-1]
                        ]
                    )
                if audio_data_index < audio_llm_pos_ids.shape[-1]:
                    llm_pos_ids_list.append(
                        audio_llm_pos_ids[
                            :, audio_data_index : audio_llm_pos_ids.shape[-1]
                        ]
                    )
                video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2)
                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
                eos_len = 1
                eos_block = (
                    torch.arange(eos_len, device=device, dtype=torch.long)
                    .view(1, -1)
                    .expand(3, -1)
                    + st_idx
                )
                llm_pos_ids_list.append(eos_block)
                llm_pos_ids_list.append(eos_block)
                st += text_len + bos_len * 2 + audio_len + video_len + eos_len * 2  # noqa: E501
                audio_idx += 1
                video_idx += 1
                remain_videos -= 1
                remain_audios -= 1

        if st < len(input_tokens):
            st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
            text_len = len(input_tokens) - st
            llm_pos_ids_list.append(
                torch.arange(text_len, device=device, dtype=torch.long)
                .view(1, -1)
                .expand(3, -1)
                + st_idx
            )

        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
        if llm_positions.shape[1] != seq_len:
            raise RuntimeError("Position ids length mismatch with input ids length")

        position_ids = llm_positions.to(device=device, dtype=dtype)
        mrope_position_delta = llm_positions.max() + 1 - seq_len
        return position_ids, mrope_position_delta

    @classmethod
    def _omni_get_input_positions_tensor(
        cls,
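The nested `_get_feat_extract_output_lengths` helper added above maps raw audio feature frames to the number of audio position ids. A hedged, standalone reproduction to show the arithmetic (not part of the diff's runtime path): every full block of 100 input frames contributes 13 output positions, and the remainder goes through three stride-2 reductions, roughly a division by eight.

```python
import torch

def feat_extract_output_lengths(input_lengths: torch.Tensor) -> torch.Tensor:
    # Same formula as the helper in the diff, copied here only for illustration.
    input_lengths_leave = input_lengths % 100
    feat_lengths = (input_lengths_leave - 1) // 2 + 1
    return ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13

for n in (80, 100, 250):
    print(n, int(feat_extract_output_lengths(torch.tensor(n))))
# 80 10
# 100 13
# 250 33
```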
@@ -1168,7 +1484,38 @@ class MRotaryEmbedding(RotaryEmbedding):
        # TODO(fyabc): refactor and share more code with
        # _vl_get_input_positions_tensor.

        model_type = hf_config.model_type
        thinker_config = hf_config.thinker_config

        if isinstance(image_grid_thw, list):
            image_grid_thw = torch.tensor(image_grid_thw)
        if isinstance(video_grid_thw, list):
            video_grid_thw = torch.tensor(video_grid_thw)

        if "qwen3_omni" in model_type:
            input_tensor = torch.tensor(input_tokens)
            audio_lengths_tensor = audio_feature_lengths
            if audio_lengths_tensor is not None and not isinstance(
                audio_lengths_tensor, torch.Tensor
            ):
                audio_lengths_tensor = torch.as_tensor(
                    audio_lengths_tensor, dtype=torch.long
                )
            second_per_grids_tensor = (
                torch.tensor(second_per_grid_ts) if second_per_grid_ts else None
            )

            llm_positions, mrope_position_delta = cls._omni3_get_input_positions_tensor(  # noqa: E501
                thinker_config,
                input_tensor,
                image_grid_thw,
                video_grid_thw,
                use_audio_in_video,
                audio_lengths_tensor,
                second_per_grids_tensor,
            )
            return llm_positions, mrope_position_delta

        audio_token_id = thinker_config.audio_token_index
        image_token_id = thinker_config.image_token_index
        video_token_id = thinker_config.video_token_index
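A hedged sketch of how the returned `mrope_position_delta` is meant to be consumed during decoding. This mirrors the general M-RoPE convention in spirit rather than copying any vLLM API: tokens generated after the multimodal prompt advance all three rotary axes together, offset by the delta so they stay contiguous with `llm_positions.max()`.

```python
import torch

def next_mrope_positions(context_len: int, num_new_tokens: int,
                         mrope_position_delta: int) -> torch.Tensor:
    """(3, num_new_tokens) positions for tokens at sequence indices
    context_len .. context_len + num_new_tokens - 1 (illustrative helper)."""
    idx = torch.arange(context_len, context_len + num_new_tokens)
    return (idx + mrope_position_delta).view(1, -1).expand(3, -1)

# E.g. a 32-token prompt whose multimodal positions ended at 19:
# delta = 19 + 1 - 32 = -12, so the first decoded token sits at position 20
# on the temporal, height, and width axes alike.
print(next_mrope_positions(32, 1, -12))
```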
@@ -1182,11 +1529,6 @@ class MRotaryEmbedding(RotaryEmbedding):
            thinker_config.vision_config, "tokens_per_second", 25
        )

-        if isinstance(image_grid_thw, list):
-            image_grid_thw = torch.tensor(image_grid_thw)
-        if isinstance(video_grid_thw, list):
-            video_grid_thw = torch.tensor(video_grid_thw)
-
        src_item = input_tokens
        audio_seqlens = audio_feature_lengths
        if not second_per_grid_ts:
@@ -1232,7 +1574,7 @@ class MRotaryEmbedding(RotaryEmbedding):
                grid_t = image_grid_thw[image_idx][0]
                grid_hs = image_grid_thw[:, 1]
                grid_ws = image_grid_thw[:, 2]
-                t_index = (torch.arange(grid_t) * 1 * tokens_per_second).long()
+                t_index = torch.arange(grid_t) * 1 * tokens_per_second
                llm_pos_ids = cls._get_llm_pos_ids_for_vision(
                    start_idx, image_idx, spatial_merge_size, t_index, grid_hs, grid_ws
                )
@@ -1250,7 +1592,7 @@ class MRotaryEmbedding(RotaryEmbedding):
                    torch.arange(grid_t)
                    * second_per_grid_ts[video_idx]
                    * tokens_per_second
-                ).long()
+                )
                llm_pos_ids = cls._get_llm_pos_ids_for_vision(
                    start_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws
                )
@@ -1277,7 +1619,7 @@ class MRotaryEmbedding(RotaryEmbedding):
                    torch.arange(grid_t)
                    * second_per_grid_ts[video_idx]
                    * tokens_per_second
-                ).long()
+                )
                t_index_split_chunk = cls._split_list_into_ranges(
                    t_index, t_ntoken_per_chunk
                )
@@ -1452,9 +1794,7 @@ class MRotaryEmbedding(RotaryEmbedding):
        grid_h = video_grid_thw[1]
        grid_w = video_grid_thw[2]
        t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk)
-        t_index = (
-            torch.arange(grid_t) * video_second_per_grid_t * tokens_per_second
-        ).long()
+        t_index = torch.arange(grid_t) * video_second_per_grid_t * tokens_per_second
        t_index_split_chunk = cls._split_list_into_ranges(t_index, t_ntoken_per_chunk)

        updates = [audio_start_token_id]
vllm/model_executor/models/qwen3_omni_moe_thinker.py (1409 lines, executable file)
File diff suppressed because it is too large.
@@ -355,6 +355,10 @@ _MULTIMODAL_MODELS = {
        "qwen2_5_omni_thinker",
        "Qwen2_5OmniThinkerForConditionalGeneration",
    ),
    "Qwen3OmniMoeForConditionalGeneration": (
        "qwen3_omni_moe_thinker",
        "Qwen3OmniMoeThinkerForConditionalGeneration",
    ),
    "Qwen3VLForConditionalGeneration": ("qwen3_vl", "Qwen3VLForConditionalGeneration"),  # noqa: E501
    "Qwen3VLMoeForConditionalGeneration": (
        "qwen3_vl_moe",
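For readers unfamiliar with this table: it pairs a Hugging Face architecture name with a (module name, class name) tuple so the implementation can be imported lazily. A hedged sketch of that resolution pattern, not vLLM's actual ModelRegistry code:

```python
from importlib import import_module

# Same shape of mapping as the diff above; only this one entry is shown here.
_MULTIMODAL_MODELS = {
    "Qwen3OmniMoeForConditionalGeneration": (
        "qwen3_omni_moe_thinker",
        "Qwen3OmniMoeThinkerForConditionalGeneration",
    ),
}

def resolve(arch: str):
    # Look up the module/class pair and import the model class on demand.
    module_name, class_name = _MULTIMODAL_MODELS[arch]
    module = import_module(f"vllm.model_executor.models.{module_name}")
    return getattr(module, class_name)
```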