Mirror of https://github.com/vllm-project/vllm.git
[Misc] Qwen2.5-VL Optimization (#13155)
@@ -45,6 +45,7 @@ from vllm.distributed import utils as dist_utils
 from vllm.logger import init_logger
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
+from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -271,8 +272,13 @@ class Qwen2_5_VisionAttention(nn.Module):
         q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous()
                    for x in (q, k, v))
         if rotary_pos_emb is not None:
-            q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
-            k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)
+            use_flash_attn = self.attn_backend == _Backend.FLASH_ATTN
+            q = apply_rotary_pos_emb_vision(q,
+                                            rotary_pos_emb,
+                                            use_flash_attn=use_flash_attn)
+            k = apply_rotary_pos_emb_vision(k,
+                                            rotary_pos_emb,
+                                            use_flash_attn=use_flash_attn)
 
         if self.attn_backend == _Backend.FLASH_ATTN:
             # from vllm_flash_attn.flash_attn_interface import (
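
For orientation, the context lines in the hunk above first move q/k/v from the packed (seq, batch, ...) layout into the (batch, seq, ...) layout that all three attention backends below expect. A tiny einops sketch with made-up sizes (treating the trailing dims as (heads, head_dim) is an assumption for illustration only):

import torch
from einops import rearrange

# Made-up sizes, purely for illustration.
seq, batch, num_heads, head_dim = 12, 1, 2, 8
x = torch.randn(seq, batch, num_heads, head_dim)     # assumed (s, b, h, d) layout from the QKV split
y = rearrange(x, "s b ... -> b s ...").contiguous()  # (b, s, h, d) expected by the backends
assert y.shape == (batch, seq, num_heads, head_dim)
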
@@ -296,20 +302,23 @@ class Qwen2_5_VisionAttention(nn.Module):
                                       "(b s) ... -> b s ...",
                                       b=batch_size)
         elif self.attn_backend == _Backend.TORCH_SDPA:
-            seq_length = q.size(1)
-            q, k, v = (rearrange(x, "b s h d -> b h s d") for x in [q, k, v])
-            attention_mask = torch.zeros([1, seq_length, seq_length],
-                                         device=q.device,
-                                         dtype=torch.bool)
+            # Execute attention entry by entry for speed & less VRAM.
+            outputs = []
             for i in range(1, len(cu_seqlens)):
-                attention_mask[..., cu_seqlens[i - 1]:cu_seqlens[i],
-                               cu_seqlens[i - 1]:cu_seqlens[i]] = True
-            output = F.scaled_dot_product_attention(q,
-                                                    k,
-                                                    v,
-                                                    attention_mask,
-                                                    dropout_p=0.0)
-            context_layer = rearrange(output, "b h s d -> b s h d ")
+                start_idx = cu_seqlens[i - 1]
+                end_idx = cu_seqlens[i]
+                q_i = q[:, start_idx:end_idx]
+                k_i = k[:, start_idx:end_idx]
+                v_i = v[:, start_idx:end_idx]
+                q_i, k_i, v_i = (rearrange(x, "b s h d -> b h s d")
+                                 for x in [q_i, k_i, v_i])
+                output_i = F.scaled_dot_product_attention(q_i,
+                                                          k_i,
+                                                          v_i,
+                                                          dropout_p=0.0)
+                output_i = rearrange(output_i, "b h s d -> b s h d ")
+                outputs.append(output_i)
+            context_layer = torch.cat(outputs, dim=1)
         elif self.attn_backend == _Backend.XFORMERS:
             from xformers import ops as xops
             from xformers.ops.fmha.attn_bias import BlockDiagonalMask
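
The hunk above is the core of the TORCH_SDPA optimization: instead of materializing one [seq, seq] boolean block-diagonal mask and calling scaled_dot_product_attention once over all packed images, attention runs per cu_seqlens window, avoiding the O(seq^2) mask and the wasted compute on masked-out cross-image pairs. A self-contained sketch (not the vLLM module itself) that checks the two strategies agree numerically:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_heads, head_dim = 2, 8
cu_seqlens = torch.tensor([0, 3, 7, 12])   # three images packed into one sequence
seq = int(cu_seqlens[-1])
q = torch.randn(1, seq, num_heads, head_dim)
k = torch.randn(1, seq, num_heads, head_dim)
v = torch.randn(1, seq, num_heads, head_dim)

# Old approach: one call with a block-diagonal boolean mask over the whole sequence.
mask = torch.zeros(1, seq, seq, dtype=torch.bool)
for i in range(1, len(cu_seqlens)):
    mask[..., cu_seqlens[i - 1]:cu_seqlens[i],
         cu_seqlens[i - 1]:cu_seqlens[i]] = True
full = F.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2),
                                      v.transpose(1, 2), mask, dropout_p=0.0)
full = full.transpose(1, 2)

# New approach: one call per image window, then concatenate along the sequence dim.
outputs = []
for i in range(1, len(cu_seqlens)):
    s, e = cu_seqlens[i - 1], cu_seqlens[i]
    o = F.scaled_dot_product_attention(q[:, s:e].transpose(1, 2),
                                       k[:, s:e].transpose(1, 2),
                                       v[:, s:e].transpose(1, 2),
                                       dropout_p=0.0)
    outputs.append(o.transpose(1, 2))
blockwise = torch.cat(outputs, dim=1)

assert torch.allclose(full, blockwise, atol=1e-5)
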
@@ -327,25 +336,6 @@ class Qwen2_5_VisionAttention(nn.Module):
         return output
 
 
-class Qwen2RMSNorm(nn.Module):
-
-    def __init__(self, hidden_size, eps=1e-6):
-        super().__init__()
-        self.weight = nn.Parameter(torch.ones(hidden_size))
-        self.variance_epsilon = eps
-
-    def forward(self, hidden_states):
-        input_dtype = hidden_states.dtype
-        hidden_states = hidden_states.to(torch.float32)
-        variance = hidden_states.pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance +
-                                                    self.variance_epsilon)
-        return self.weight * hidden_states.to(input_dtype)
-
-    def extra_repr(self):
-        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
-
-
 class Qwen2_5_VisionBlock(nn.Module):
 
     def __init__(
@@ -516,8 +506,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
             hidden_size=self.hidden_size,
         )
 
-        # NOTE: We use torch native RMSNorm here for precision purposes.
-        norm_layer = partial(Qwen2RMSNorm, eps=norm_eps)
+        norm_layer = partial(RMSNorm, eps=norm_eps)
         head_dim = self.hidden_size // self.num_heads
         self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
 
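
The swap above only changes which RMSNorm class the factory builds; norm_layer stays a partial that is later called with a hidden size. A rough stand-in sketch of that pattern, using a local torch-only module that mirrors the deleted Qwen2RMSNorm math rather than vLLM's RMSNorm class:

from functools import partial

import torch
import torch.nn as nn

class TinyRMSNorm(nn.Module):
    """Torch-only stand-in mirroring the deleted Qwen2RMSNorm computation."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        dtype = x.dtype
        x = x.float()
        x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return (self.weight * x).to(dtype)

norm_eps = 1e-6
norm_layer = partial(TinyRMSNorm, eps=norm_eps)  # same factory pattern as the diff
norm = norm_layer(1280)                          # hidden size chosen arbitrarily here
out = norm(torch.randn(2, 4, 1280))
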
@@ -226,11 +226,15 @@ def apply_rotary_emb_torch(x: torch.Tensor,
 
 
 def apply_rotary_pos_emb_vision(t: torch.Tensor,
-                                freqs: torch.Tensor) -> torch.Tensor:
+                                freqs: torch.Tensor,
+                                use_flash_attn=False) -> torch.Tensor:
     t_ = t.float()
     cos = freqs.cos()
     sin = freqs.sin()
-    output = apply_rotary_emb_torch(t_, cos, sin).type_as(t)
+    apply_rotary_emb = apply_rotary_emb_torch
+    if use_flash_attn:
+        from flash_attn.layers.rotary import apply_rotary_emb
+    output = apply_rotary_emb(t_, cos, sin).type_as(t)
     return output
 
 
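
A simplified sketch of the dispatch introduced above: a hand-written rotate-half fallback stands in for apply_rotary_emb_torch (its exact broadcasting here is an assumption, not a copy of the vLLM helper), and the fused flash-attn kernel is imported lazily only when use_flash_attn is set, as in the diff:

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the last dim in two and rotate: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_emb_ref(t: torch.Tensor, cos: torch.Tensor,
                         sin: torch.Tensor) -> torch.Tensor:
    # Assumed fallback behaviour: t is (batch, seq, heads, head_dim) and
    # cos/sin are (seq, head_dim // 2); both are doubled to cover the full
    # head_dim and broadcast over the batch and head dimensions.
    cos = torch.cat((cos, cos), dim=-1)[None, :, None, :]
    sin = torch.cat((sin, sin), dim=-1)[None, :, None, :]
    return t * cos + rotate_half(t) * sin

def apply_rotary_pos_emb_vision_sketch(t: torch.Tensor, freqs: torch.Tensor,
                                       use_flash_attn: bool = False) -> torch.Tensor:
    t_ = t.float()
    cos, sin = freqs.cos(), freqs.sin()
    apply_fn = apply_rotary_emb_ref
    if use_flash_attn:
        # Same lazy import path as the diff; requires flash-attn to be installed.
        from flash_attn.layers.rotary import apply_rotary_emb as apply_fn
    return apply_fn(t_, cos, sin).type_as(t)

# Torch-only path with made-up shapes.
t = torch.randn(1, 12, 2, 8, dtype=torch.float16)
freqs = torch.randn(12, 4)
out = apply_rotary_pos_emb_vision_sketch(t, freqs)
assert out.shape == t.shape and out.dtype == t.dtype
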
@@ -336,20 +340,23 @@ class Qwen2VisionAttention(nn.Module):
                                       "(b s) ... -> b s ...",
                                       b=batch_size)
         elif self.attn_backend == _Backend.TORCH_SDPA:
-            seq_length = q.size(1)
-            q, k, v = (rearrange(x, "b s h d -> b h s d") for x in [q, k, v])
-            attention_mask = torch.zeros([1, seq_length, seq_length],
-                                         device=q.device,
-                                         dtype=torch.bool)
+            # Execute attention entry by entry for speed & less VRAM.
+            outputs = []
             for i in range(1, len(cu_seqlens)):
-                attention_mask[..., cu_seqlens[i - 1]:cu_seqlens[i],
-                               cu_seqlens[i - 1]:cu_seqlens[i]] = True
-            output = F.scaled_dot_product_attention(q,
-                                                    k,
-                                                    v,
-                                                    attention_mask,
-                                                    dropout_p=0.0)
-            context_layer = rearrange(output, "b h s d -> b s h d ")
+                start_idx = cu_seqlens[i - 1]
+                end_idx = cu_seqlens[i]
+                q_i = q[:, start_idx:end_idx]
+                k_i = k[:, start_idx:end_idx]
+                v_i = v[:, start_idx:end_idx]
+                q_i, k_i, v_i = (rearrange(x, "b s h d -> b h s d")
+                                 for x in [q_i, k_i, v_i])
+                output_i = F.scaled_dot_product_attention(q_i,
+                                                          k_i,
+                                                          v_i,
+                                                          dropout_p=0.0)
+                output_i = rearrange(output_i, "b h s d -> b s h d ")
+                outputs.append(output_i)
+            context_layer = torch.cat(outputs, dim=1)
         elif self.attn_backend == _Backend.XFORMERS:
             from xformers import ops as xops
             from xformers.ops.fmha.attn_bias import BlockDiagonalMask
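
Both TORCH_SDPA loops above rely on the same bookkeeping: cu_seqlens[i - 1]:cu_seqlens[i] delimits the i-th image's patch tokens inside the packed sequence, and the per-image lengths can be recovered with a diff (which is what a block-diagonal attention bias is built from on the xformers path). A small sketch with hypothetical patch counts:

import torch

# Hypothetical patch counts for three images packed into one sequence.
seqlens = torch.tensor([4, 6, 3])
cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.long), seqlens.cumsum(0)])
print(cu_seqlens.tolist())  # [0, 4, 10, 13]

# Each (start, end) window is what the per-entry SDPA loop attends over.
windows = [(int(cu_seqlens[i - 1]), int(cu_seqlens[i]))
           for i in range(1, len(cu_seqlens))]
print(windows)  # [(0, 4), (4, 10), (10, 13)]

# Recovering per-image lengths from the cumulative form.
assert torch.equal(cu_seqlens.diff(), seqlens)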