[Misc] Qwen2.5-VL Optimization (#13155)

commit 02ed8a1fbe (parent 2092a6fa7d)
committed by GitHub on 2025-02-13 22:17:57 +08:00
2 changed files with 47 additions and 51 deletions


@@ -45,6 +45,7 @@ from vllm.distributed import utils as dist_utils
 from vllm.logger import init_logger
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
+from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -271,8 +272,13 @@ class Qwen2_5_VisionAttention(nn.Module):
         q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous()
                    for x in (q, k, v))
         if rotary_pos_emb is not None:
-            q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
-            k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)
+            use_flash_attn = self.attn_backend == _Backend.FLASH_ATTN
+            q = apply_rotary_pos_emb_vision(q,
+                                            rotary_pos_emb,
+                                            use_flash_attn=use_flash_attn)
+            k = apply_rotary_pos_emb_vision(k,
+                                            rotary_pos_emb,
+                                            use_flash_attn=use_flash_attn)
 
         if self.attn_backend == _Backend.FLASH_ATTN:
             # from vllm_flash_attn.flash_attn_interface import (
@@ -296,20 +302,23 @@ class Qwen2_5_VisionAttention(nn.Module):
                                       "(b s) ... -> b s ...",
                                       b=batch_size)
         elif self.attn_backend == _Backend.TORCH_SDPA:
-            seq_length = q.size(1)
-            q, k, v = (rearrange(x, "b s h d -> b h s d") for x in [q, k, v])
-            attention_mask = torch.zeros([1, seq_length, seq_length],
-                                         device=q.device,
-                                         dtype=torch.bool)
+            # Execute attention entry by entry for speed & less VRAM.
+            outputs = []
             for i in range(1, len(cu_seqlens)):
-                attention_mask[..., cu_seqlens[i - 1]:cu_seqlens[i],
-                               cu_seqlens[i - 1]:cu_seqlens[i]] = True
-            output = F.scaled_dot_product_attention(q,
-                                                    k,
-                                                    v,
-                                                    attention_mask,
-                                                    dropout_p=0.0)
-            context_layer = rearrange(output, "b h s d -> b s h d ")
+                start_idx = cu_seqlens[i - 1]
+                end_idx = cu_seqlens[i]
+                q_i = q[:, start_idx:end_idx]
+                k_i = k[:, start_idx:end_idx]
+                v_i = v[:, start_idx:end_idx]
+                q_i, k_i, v_i = (rearrange(x, "b s h d -> b h s d")
+                                 for x in [q_i, k_i, v_i])
+                output_i = F.scaled_dot_product_attention(q_i,
+                                                          k_i,
+                                                          v_i,
+                                                          dropout_p=0.0)
+                output_i = rearrange(output_i, "b h s d -> b s h d ")
+                outputs.append(output_i)
+            context_layer = torch.cat(outputs, dim=1)
         elif self.attn_backend == _Backend.XFORMERS:
             from xformers import ops as xops
             from xformers.ops.fmha.attn_bias import BlockDiagonalMask
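The rewritten TORCH_SDPA branch above trades one masked attention call over the whole packed sequence for one unmasked call per segment. A standalone sketch (not part of the commit, with made-up shapes) showing that the two are numerically equivalent, while the per-segment loop never materializes the full seq_len x seq_len mask or score matrix:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
heads, head_dim = 4, 32
cu_seqlens = torch.tensor([0, 5, 12, 20])      # packed segment boundaries
total = int(cu_seqlens[-1])
q, k, v = (torch.randn(1, total, heads, head_dim) for _ in range(3))

# Old approach: block-diagonal boolean mask over the full packed sequence.
mask = torch.zeros(1, total, total, dtype=torch.bool)
for i in range(1, len(cu_seqlens)):
    s, e = int(cu_seqlens[i - 1]), int(cu_seqlens[i])
    mask[..., s:e, s:e] = True
qh, kh, vh = (x.transpose(1, 2) for x in (q, k, v))      # -> (b, h, s, d)
masked = F.scaled_dot_product_attention(qh, kh, vh, mask).transpose(1, 2)

# New approach: one unmasked SDPA call per segment, concatenated afterwards.
outputs = []
for i in range(1, len(cu_seqlens)):
    s, e = int(cu_seqlens[i - 1]), int(cu_seqlens[i])
    qi, ki, vi = (x[:, s:e].transpose(1, 2) for x in (q, k, v))
    outputs.append(F.scaled_dot_product_attention(qi, ki, vi).transpose(1, 2))
looped = torch.cat(outputs, dim=1)

torch.testing.assert_close(masked, looped, rtol=1e-4, atol=1e-5)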
@@ -327,25 +336,6 @@ class Qwen2_5_VisionAttention(nn.Module):
         return output
 
 
-class Qwen2RMSNorm(nn.Module):
-
-    def __init__(self, hidden_size, eps=1e-6):
-        super().__init__()
-        self.weight = nn.Parameter(torch.ones(hidden_size))
-        self.variance_epsilon = eps
-
-    def forward(self, hidden_states):
-        input_dtype = hidden_states.dtype
-        hidden_states = hidden_states.to(torch.float32)
-        variance = hidden_states.pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance +
-                                                    self.variance_epsilon)
-        return self.weight * hidden_states.to(input_dtype)
-
-    def extra_repr(self):
-        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
-
-
 class Qwen2_5_VisionBlock(nn.Module):
 
     def __init__(
@@ -516,8 +506,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
             hidden_size=self.hidden_size,
         )
-        # NOTE: We use torch native RMSNorm here for precision purposes.
-        norm_layer = partial(Qwen2RMSNorm, eps=norm_eps)
+        norm_layer = partial(RMSNorm, eps=norm_eps)
         head_dim = self.hidden_size // self.num_heads
         self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
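Both the removed Qwen2RMSNorm and vLLM's RMSNorm layer implement the standard RMSNorm, y = w * x / sqrt(mean(x^2) + eps); the swap mainly changes which kernel runs it. A minimal check of that formula against PyTorch's built-in, with hypothetical shapes (assumes torch.nn.functional.rms_norm, available since PyTorch 2.4):

import torch
import torch.nn.functional as F

hidden_size, eps = 1280, 1e-6
x = torch.randn(4, 17, hidden_size)
weight = torch.rand(hidden_size)

# The removed class's math: normalize by the root-mean-square, then scale.
variance = x.pow(2).mean(-1, keepdim=True)
manual = weight * (x * torch.rsqrt(variance + eps))

torch.testing.assert_close(manual, F.rms_norm(x, (hidden_size,), weight, eps))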


@@ -226,11 +226,15 @@ def apply_rotary_emb_torch(x: torch.Tensor,
 def apply_rotary_pos_emb_vision(t: torch.Tensor,
-                                freqs: torch.Tensor) -> torch.Tensor:
+                                freqs: torch.Tensor,
+                                use_flash_attn=False) -> torch.Tensor:
     t_ = t.float()
     cos = freqs.cos()
     sin = freqs.sin()
-    output = apply_rotary_emb_torch(t_, cos, sin).type_as(t)
+    apply_rotary_emb = apply_rotary_emb_torch
+    if use_flash_attn:
+        from flash_attn.layers.rotary import apply_rotary_emb
+    output = apply_rotary_emb(t_, cos, sin).type_as(t)
     return output
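With use_flash_attn=True, the fused flash_attn.layers.rotary.apply_rotary_emb kernel is used, which expects (batch, seq, heads, head_dim) inputs and (seq, rotary_dim // 2) cos/sin; otherwise the pure-PyTorch apply_rotary_emb_torch fallback runs. A minimal sketch of that fallback under the non-interleaved (GPT-NeoX-style) convention; names here are illustrative and the in-tree implementation may differ in details such as interleaved support:

import torch
from einops import repeat


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Non-interleaved convention: swap and negate the two halves of the last dim.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def rotary_emb_torch_sketch(x: torch.Tensor, cos: torch.Tensor,
                            sin: torch.Tensor) -> torch.Tensor:
    # x: (batch, seq, heads, head_dim); cos/sin: (seq, rotary_dim // 2).
    ro_dim = cos.shape[-1] * 2
    cos = repeat(cos, "s d -> s 1 (2 d)")   # broadcast over heads, tile halves
    sin = repeat(sin, "s d -> s 1 (2 d)")
    rotated = x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim]) * sin
    return torch.cat([rotated, x[..., ro_dim:]], dim=-1)


q = torch.randn(2, 32, 8, 64)               # (batch, seq, heads, head_dim)
freqs = torch.randn(32, 32)                 # (seq, head_dim // 2)
out = rotary_emb_torch_sketch(q.float(), freqs.cos(), freqs.sin()).type_as(q)
assert out.shape == q.shape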
@@ -336,20 +340,23 @@ class Qwen2VisionAttention(nn.Module):
                                       "(b s) ... -> b s ...",
                                       b=batch_size)
         elif self.attn_backend == _Backend.TORCH_SDPA:
-            seq_length = q.size(1)
-            q, k, v = (rearrange(x, "b s h d -> b h s d") for x in [q, k, v])
-            attention_mask = torch.zeros([1, seq_length, seq_length],
-                                         device=q.device,
-                                         dtype=torch.bool)
+            # Execute attention entry by entry for speed & less VRAM.
+            outputs = []
             for i in range(1, len(cu_seqlens)):
-                attention_mask[..., cu_seqlens[i - 1]:cu_seqlens[i],
-                               cu_seqlens[i - 1]:cu_seqlens[i]] = True
-            output = F.scaled_dot_product_attention(q,
-                                                    k,
-                                                    v,
-                                                    attention_mask,
-                                                    dropout_p=0.0)
-            context_layer = rearrange(output, "b h s d -> b s h d ")
+                start_idx = cu_seqlens[i - 1]
+                end_idx = cu_seqlens[i]
+                q_i = q[:, start_idx:end_idx]
+                k_i = k[:, start_idx:end_idx]
+                v_i = v[:, start_idx:end_idx]
+                q_i, k_i, v_i = (rearrange(x, "b s h d -> b h s d")
+                                 for x in [q_i, k_i, v_i])
+                output_i = F.scaled_dot_product_attention(q_i,
+                                                          k_i,
+                                                          v_i,
+                                                          dropout_p=0.0)
+                output_i = rearrange(output_i, "b h s d -> b s h d ")
+                outputs.append(output_i)
+            context_layer = torch.cat(outputs, dim=1)
         elif self.attn_backend == _Backend.XFORMERS:
             from xformers import ops as xops
             from xformers.ops.fmha.attn_bias import BlockDiagonalMask
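For context on the cu_seqlens used by both rewritten TORCH_SDPA branches: the vision patches of all images and video frames are packed into one long sequence, and cu_seqlens holds the cumulative boundary of each frame's patch block, so attention never crosses frames. An illustrative construction from grid_thw (t, h, w per input); the exact construction in vLLM may differ, e.g. around spatial merging or window attention:

import torch

grid_thw = torch.tensor([[1, 4, 6], [2, 2, 8]])      # two inputs: 1x4x6 and 2x2x8
tokens_per_frame = grid_thw[:, 1] * grid_thw[:, 2]   # h * w patches per frame
cu_seqlens = torch.repeat_interleave(tokens_per_frame, grid_thw[:, 0]).cumsum(dim=0)
cu_seqlens = torch.cat([torch.zeros(1, dtype=cu_seqlens.dtype), cu_seqlens])
print(cu_seqlens)   # tensor([ 0, 24, 40, 56]) -> three frames in the packed sequence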