[Bugfix][Perf] Misc fixes for Qwen3 VL (#25238)

Signed-off-by: Roger Wang <hey@rogerw.io>
Author: Roger Wang
Date: 2025-09-19 03:46:16 -07:00
Committed by: GitHub
Parent: cea91a32f2
Commit: 1dfea5f4a9

2 changed files with 12 additions and 13 deletions

vllm/model_executor/models/qwen3_vl.py

@@ -1075,6 +1075,8 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
                         config.text_config.hidden_size)
             for _ in range(self.deepstack_num_level)
         ] if self.use_deepstack else None
+        self.visual_dim = config.vision_config.out_hidden_size
+        self.multiscale_dim = self.visual_dim * self.deepstack_num_level
 
     def _get_deepstack_input_embeds(self,
                                     num_tokens: int) -> IntermediateTensors:
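Note on the added attributes: they cache a fixed invariant of the vision tower output, namely that each multimodal token embedding is the main visual feature concatenated with one extra feature per deepstack level. A minimal sketch of that layout, with toy sizes standing in for the real config values:

    # Toy sizes; in the model these come from config.vision_config.out_hidden_size
    # and deepstack_num_level.
    visual_dim = 8
    deepstack_num_level = 3
    multiscale_dim = visual_dim * deepstack_num_level
    # Each concatenated row is [main | level_0 | ... | level_{n-1}], so its total
    # width is always visual_dim * (deepstack_num_level + 1).
    assert visual_dim + multiscale_dim == visual_dim * (deepstack_num_level + 1)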
@@ -1313,12 +1315,8 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
         ]
         multimodal_embeddings_cat = torch.cat(multimodal_embeddings, dim=0)
-        visual_dim = multimodal_embeddings_cat.shape[-1] // (
-            self.deepstack_num_level + 1)
-        main_dim, multi_dim = visual_dim, visual_dim * self.deepstack_num_level
         multimodal_embeddings_main, multimodal_embeddings_multiscale = torch.split(  # noqa:E501
-            multimodal_embeddings_cat, [main_dim, multi_dim],
+            multimodal_embeddings_cat, [self.visual_dim, self.multiscale_dim],
             dim=-1)
 
         multimodal_embeddings = torch.split(multimodal_embeddings_main,
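The deleted lines recomputed visual_dim from the concatenated tensor's shape on every forward pass; the split sizes now come from the attributes cached in __init__. A runnable sketch of the split itself, again with toy sizes:

    import torch

    visual_dim, num_levels, num_tokens = 8, 3, 5    # toy sizes
    cat = torch.randn(num_tokens, visual_dim * (num_levels + 1))
    main, multiscale = torch.split(
        cat, [visual_dim, visual_dim * num_levels], dim=-1)
    assert main.shape == (num_tokens, visual_dim)    # main embeddings for the LM
    assert multiscale.shape == (num_tokens, visual_dim * num_levels)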
@@ -1340,10 +1338,8 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
             ],
         )
         deepstack_input_embeds = deepstack_input_embeds.view(
-            inputs_embeds.shape[0], self.deepstack_num_level,
-            visual_dim).contiguous()
-        deepstack_input_embeds = deepstack_input_embeds.permute(
-            1, 0, 2).contiguous()
+            inputs_embeds.shape[0], self.deepstack_num_level, self.visual_dim)
+        deepstack_input_embeds = deepstack_input_embeds.permute(1, 0, 2)
         return deepstack_input_embeds, multimodal_embeddings
 
     def get_input_embeddings(
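The perf half of this hunk drops two .contiguous() calls: permute returns a view, and forcing it contiguous copied the whole deepstack buffer. This is safe under the assumption the diff implies, that downstream consumers only index the per-level dimension, which works fine on a non-contiguous view:

    import torch

    x = torch.randn(10, 3, 8)    # (num_tokens, num_level, visual_dim), toy sizes
    y = x.permute(1, 0, 2)       # (num_level, num_tokens, visual_dim): a view, no copy
    assert not y.is_contiguous()
    level0 = y[0]                # per-level slicing still works on the view
    assert level0.shape == (10, 8)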
@@ -1353,9 +1349,10 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
     ) -> torch.Tensor:
         deepstack_input_embeds = None
         inputs_embeds = self.language_model.get_input_embeddings(input_ids)
-        if multimodal_embeddings is not None and self.use_deepstack:
-            deepstack_input_embeds, multimodal_embeddings = self._compute_deepstack_embeds(  # noqa:E501
-                input_ids, inputs_embeds, multimodal_embeddings)
+        if multimodal_embeddings is not None:
+            if self.use_deepstack:
+                deepstack_input_embeds, multimodal_embeddings = self._compute_deepstack_embeds(  # noqa:E501
+                    input_ids, inputs_embeds, multimodal_embeddings)
             inputs_embeds = merge_multimodal_embeddings(
                 input_ids, inputs_embeds, multimodal_embeddings,
                 [self.config.image_token_id, self.config.video_token_id])
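This hunk is the bugfix proper: in the old code, merge_multimodal_embeddings sat inside `if multimodal_embeddings is not None and self.use_deepstack:`, so a model running without deepstack silently skipped merging its image/video embeddings. A simplified, runnable sketch of the restructured control flow, with stub values in place of the real tensors:

    def get_input_embeddings_sketch(multimodal_embeddings, use_deepstack):
        deepstack_input_embeds = None
        inputs_embeds = "text-only"
        if multimodal_embeddings is not None:
            if use_deepstack:            # deepstack is optional ...
                deepstack_input_embeds = "deepstack"
            inputs_embeds = "merged"     # ... but merging no longer depends on it
        return inputs_embeds, deepstack_input_embeds

    # Under the old combined condition, this case returned "text-only":
    assert get_input_embeddings_sketch("mm", use_deepstack=False)[0] == "merged"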
@@ -1531,4 +1528,4 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
             language_model="language_model",
             connector="model.visual.merger",
             tower_model="model.visual.",
         )

vllm/model_executor/models/qwen3_vl_moe.py

@@ -344,3 +344,5 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
                 config.text_config.hidden_size)
             for _ in range(self.deepstack_num_level)
         ] if self.use_deepstack else None
+        self.visual_dim = config.vision_config.out_hidden_size
+        self.multiscale_dim = self.visual_dim * self.deepstack_num_level