mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 23:03:52 +08:00
[Bugfix][Perf] Misc fixes for Qwen3 VL (#25238)
Signed-off-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
@ -1075,6 +1075,8 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
config.text_config.hidden_size)
|
config.text_config.hidden_size)
|
||||||
for _ in range(self.deepstack_num_level)
|
for _ in range(self.deepstack_num_level)
|
||||||
] if self.use_deepstack else None
|
] if self.use_deepstack else None
|
||||||
|
self.visual_dim = config.vision_config.out_hidden_size
|
||||||
|
self.multiscale_dim = self.visual_dim * self.deepstack_num_level
|
||||||
|
|
||||||
def _get_deepstack_input_embeds(self,
|
def _get_deepstack_input_embeds(self,
|
||||||
num_tokens: int) -> IntermediateTensors:
|
num_tokens: int) -> IntermediateTensors:
|
||||||
@ -1313,12 +1315,8 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
]
|
]
|
||||||
multimodal_embeddings_cat = torch.cat(multimodal_embeddings, dim=0)
|
multimodal_embeddings_cat = torch.cat(multimodal_embeddings, dim=0)
|
||||||
|
|
||||||
visual_dim = multimodal_embeddings_cat.shape[-1] // (
|
|
||||||
self.deepstack_num_level + 1)
|
|
||||||
|
|
||||||
main_dim, multi_dim = visual_dim, visual_dim * self.deepstack_num_level
|
|
||||||
multimodal_embeddings_main, multimodal_embeddings_multiscale = torch.split( # noqa:E501
|
multimodal_embeddings_main, multimodal_embeddings_multiscale = torch.split( # noqa:E501
|
||||||
multimodal_embeddings_cat, [main_dim, multi_dim],
|
multimodal_embeddings_cat, [self.visual_dim, self.multiscale_dim],
|
||||||
dim=-1)
|
dim=-1)
|
||||||
|
|
||||||
multimodal_embeddings = torch.split(multimodal_embeddings_main,
|
multimodal_embeddings = torch.split(multimodal_embeddings_main,
|
||||||
@ -1340,10 +1338,8 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
deepstack_input_embeds = deepstack_input_embeds.view(
|
deepstack_input_embeds = deepstack_input_embeds.view(
|
||||||
inputs_embeds.shape[0], self.deepstack_num_level,
|
inputs_embeds.shape[0], self.deepstack_num_level, self.visual_dim)
|
||||||
visual_dim).contiguous()
|
deepstack_input_embeds = deepstack_input_embeds.permute(1, 0, 2)
|
||||||
deepstack_input_embeds = deepstack_input_embeds.permute(
|
|
||||||
1, 0, 2).contiguous()
|
|
||||||
return deepstack_input_embeds, multimodal_embeddings
|
return deepstack_input_embeds, multimodal_embeddings
|
||||||
|
|
||||||
def get_input_embeddings(
|
def get_input_embeddings(
|
||||||
@ -1353,9 +1349,10 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
deepstack_input_embeds = None
|
deepstack_input_embeds = None
|
||||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
||||||
if multimodal_embeddings is not None and self.use_deepstack:
|
if multimodal_embeddings is not None:
|
||||||
deepstack_input_embeds, multimodal_embeddings = self._compute_deepstack_embeds( # noqa:E501
|
if self.use_deepstack:
|
||||||
input_ids, inputs_embeds, multimodal_embeddings)
|
deepstack_input_embeds, multimodal_embeddings = self._compute_deepstack_embeds( # noqa:E501
|
||||||
|
input_ids, inputs_embeds, multimodal_embeddings)
|
||||||
inputs_embeds = merge_multimodal_embeddings(
|
inputs_embeds = merge_multimodal_embeddings(
|
||||||
input_ids, inputs_embeds, multimodal_embeddings,
|
input_ids, inputs_embeds, multimodal_embeddings,
|
||||||
[self.config.image_token_id, self.config.video_token_id])
|
[self.config.image_token_id, self.config.video_token_id])
|
||||||
@ -1531,4 +1528,4 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
language_model="language_model",
|
language_model="language_model",
|
||||||
connector="model.visual.merger",
|
connector="model.visual.merger",
|
||||||
tower_model="model.visual.",
|
tower_model="model.visual.",
|
||||||
)
|
)
|
@ -344,3 +344,5 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
|
|||||||
config.text_config.hidden_size)
|
config.text_config.hidden_size)
|
||||||
for _ in range(self.deepstack_num_level)
|
for _ in range(self.deepstack_num_level)
|
||||||
] if self.use_deepstack else None
|
] if self.use_deepstack else None
|
||||||
|
self.visual_dim = config.vision_config.out_hidden_size
|
||||||
|
self.multiscale_dim = self.visual_dim * self.deepstack_num_level
|
Reference in New Issue
Block a user