[Bugfix] Allow Only SDPA Backend for ViT on B200 for Qwen3-VL (#25788)

This commit is contained in:
Wentao Ye
2025-09-26 23:44:52 -04:00
committed by GitHub
parent f1d53d150c
commit c242c98031
2 changed files with 75 additions and 51 deletions

View File

@ -274,6 +274,8 @@ class Qwen2_5_VisionAttention(nn.Module):
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend: _Backend = _Backend.TORCH_SDPA,
use_upstream_fa: bool = False,
) -> None:
super().__init__()
# Per attention head and per partition values.
@ -300,25 +302,8 @@ class Qwen2_5_VisionAttention(nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.proj",
disable_tp=use_data_parallel)
# Detect attention implementation.
self.attn_backend = get_vit_attn_backend(
head_size=self.hidden_size_per_attention_head,
dtype=torch.get_default_dtype())
self.use_upstream_fa = False
if self.attn_backend != _Backend.FLASH_ATTN and \
check_upstream_fa_availability(
torch.get_default_dtype()):
self.attn_backend = _Backend.FLASH_ATTN
self.use_upstream_fa = True
if self.attn_backend not in {
_Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
_Backend.ROCM_AITER_FA
}:
raise RuntimeError(
f"Qwen2.5-VL does not support {self.attn_backend} backend now."
)
self.attn_backend = attn_backend
self.use_upstream_fa = use_upstream_fa
self.is_flash_attn_backend = self.attn_backend in {
_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
}
@ -443,6 +428,8 @@ class Qwen2_5_VisionBlock(nn.Module):
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend: _Backend = _Backend.TORCH_SDPA,
use_upstream_fa: bool = False,
) -> None:
super().__init__()
if norm_layer is None:
@ -455,7 +442,9 @@ class Qwen2_5_VisionBlock(nn.Module):
projection_size=dim,
quant_config=quant_config,
prefix=f"{prefix}.attn",
use_data_parallel=use_data_parallel)
use_data_parallel=use_data_parallel,
attn_backend=attn_backend,
use_upstream_fa=use_upstream_fa)
self.mlp = Qwen2_5_VisionMLP(dim,
mlp_hidden_dim,
act_fn=act_fn,
@ -627,17 +616,35 @@ class Qwen2_5_VisionTransformer(nn.Module):
head_dim = self.hidden_size // self.num_heads
self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
use_upstream_fa = False
self.attn_backend = get_vit_attn_backend(
head_size=head_dim, dtype=torch.get_default_dtype())
if self.attn_backend != _Backend.FLASH_ATTN and \
check_upstream_fa_availability(
torch.get_default_dtype()):
self.attn_backend = _Backend.FLASH_ATTN
use_upstream_fa = True
if self.attn_backend not in {
_Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
_Backend.ROCM_AITER_FA
}:
raise RuntimeError(
f"Qwen2.5-VL does not support {self.attn_backend} backend now."
)
self.blocks = nn.ModuleList([
Qwen2_5_VisionBlock(dim=self.hidden_size,
num_heads=self.num_heads,
mlp_hidden_dim=vision_config.intermediate_size,
act_fn=get_act_and_mul_fn(
vision_config.hidden_act),
norm_layer=norm_layer,
quant_config=quant_config,
prefix=f"{prefix}.blocks.{layer_idx}",
use_data_parallel=use_data_parallel)
for layer_idx in range(depth)
Qwen2_5_VisionBlock(
dim=self.hidden_size,
num_heads=self.num_heads,
mlp_hidden_dim=vision_config.intermediate_size,
act_fn=get_act_and_mul_fn(vision_config.hidden_act),
norm_layer=norm_layer,
quant_config=quant_config,
prefix=f"{prefix}.blocks.{layer_idx}",
use_data_parallel=use_data_parallel,
attn_backend=self.attn_backend,
use_upstream_fa=use_upstream_fa) for layer_idx in range(depth)
])
self.merger = Qwen2_5_VisionPatchMerger(
d_model=vision_config.out_hidden_size,
@ -648,12 +655,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
prefix=f"{prefix}.merger",
use_data_parallel=use_data_parallel,
)
self.attn_backend = get_vit_attn_backend(
head_size=head_dim, dtype=torch.get_default_dtype())
if self.attn_backend != _Backend.FLASH_ATTN and \
check_upstream_fa_availability(
torch.get_default_dtype()):
self.attn_backend = _Backend.FLASH_ATTN
@property
def dtype(self) -> torch.dtype:

View File

@ -63,7 +63,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
PromptReplacement, PromptUpdate,
PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.platforms import _Backend
from vllm.platforms import _Backend, current_platform
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.config import uses_mrope
from vllm.utils import is_list_of
@ -158,6 +158,8 @@ class Qwen3_VisionBlock(nn.Module):
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend: _Backend = _Backend.TORCH_SDPA,
use_upstream_fa: bool = False,
) -> None:
super().__init__()
if norm_layer is None:
@ -170,7 +172,9 @@ class Qwen3_VisionBlock(nn.Module):
projection_size=dim,
quant_config=quant_config,
prefix=f"{prefix}.attn",
use_data_parallel=use_data_parallel)
use_data_parallel=use_data_parallel,
attn_backend=attn_backend,
use_upstream_fa=use_upstream_fa)
self.mlp = Qwen3_VisionMLP(dim,
mlp_hidden_dim,
act_fn=act_fn,
@ -287,19 +291,6 @@ class Qwen3_VisionTransformer(nn.Module):
head_dim = self.hidden_size // self.num_heads
self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
self.blocks = nn.ModuleList([
Qwen3_VisionBlock(
dim=self.hidden_size,
num_heads=self.num_heads,
mlp_hidden_dim=vision_config.intermediate_size,
act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
norm_layer=norm_layer,
quant_config=quant_config,
prefix=f"{prefix}.blocks.{layer_idx}",
use_data_parallel=use_data_parallel)
for layer_idx in range(vision_config.depth)
])
self.merger = Qwen3_VisionPatchMerger(
d_model=vision_config.out_hidden_size,
context_dim=self.hidden_size,
@ -325,10 +316,42 @@ class Qwen3_VisionTransformer(nn.Module):
self.attn_backend = get_vit_attn_backend(
head_size=head_dim, dtype=torch.get_default_dtype())
use_upstream_fa = False
if self.attn_backend != _Backend.FLASH_ATTN and \
check_upstream_fa_availability(
torch.get_default_dtype()):
self.attn_backend = _Backend.FLASH_ATTN
use_upstream_fa = True
if self.attn_backend not in {
_Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
_Backend.ROCM_AITER_FA
}:
raise RuntimeError(
f"Qwen3-VL does not support {self.attn_backend} backend now.")
if current_platform.is_device_capability(
100) and self.attn_backend != _Backend.TORCH_SDPA:
# TODO(Roger/Wentao): remove this after FA
# or XFORMERS's issue fixed on Blackwell
logger.info_once("Qwen3-VL vision attention does not support "
f"{self.attn_backend} backend on Blackwell now. "
"Vision attention backend is set to TORCH_SDPA.")
self.attn_backend = _Backend.TORCH_SDPA
self.blocks = nn.ModuleList([
Qwen3_VisionBlock(
dim=self.hidden_size,
num_heads=self.num_heads,
mlp_hidden_dim=vision_config.intermediate_size,
act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
norm_layer=norm_layer,
quant_config=quant_config,
prefix=f"{prefix}.blocks.{layer_idx}",
use_data_parallel=use_data_parallel,
attn_backend=self.attn_backend,
use_upstream_fa=use_upstream_fa)
for layer_idx in range(vision_config.depth)
])
@property
def dtype(self) -> torch.dtype: