Mirror of https://github.com/vllm-project/vllm.git
[Misc] Qwen2.5-VL Optimization (#13155)
@@ -45,6 +45,7 @@ from vllm.distributed import utils as dist_utils
 from vllm.logger import init_logger
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
+from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -271,8 +272,13 @@ class Qwen2_5_VisionAttention(nn.Module):
         q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous()
                    for x in (q, k, v))
         if rotary_pos_emb is not None:
-            q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
-            k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)
+            use_flash_attn = self.attn_backend == _Backend.FLASH_ATTN
+            q = apply_rotary_pos_emb_vision(q,
+                                            rotary_pos_emb,
+                                            use_flash_attn=use_flash_attn)
+            k = apply_rotary_pos_emb_vision(k,
+                                            rotary_pos_emb,
+                                            use_flash_attn=use_flash_attn)

         if self.attn_backend == _Backend.FLASH_ATTN:
             # from vllm_flash_attn.flash_attn_interface import (
@@ -296,20 +302,23 @@ class Qwen2_5_VisionAttention(nn.Module):
                                       "(b s) ... -> b s ...",
                                       b=batch_size)
         elif self.attn_backend == _Backend.TORCH_SDPA:
-            seq_length = q.size(1)
-            q, k, v = (rearrange(x, "b s h d -> b h s d") for x in [q, k, v])
-            attention_mask = torch.zeros([1, seq_length, seq_length],
-                                         device=q.device,
-                                         dtype=torch.bool)
+            # Execute attention entry by entry for speed & less VRAM.
+            outputs = []
             for i in range(1, len(cu_seqlens)):
-                attention_mask[..., cu_seqlens[i - 1]:cu_seqlens[i],
-                               cu_seqlens[i - 1]:cu_seqlens[i]] = True
-            output = F.scaled_dot_product_attention(q,
-                                                    k,
-                                                    v,
-                                                    attention_mask,
-                                                    dropout_p=0.0)
-            context_layer = rearrange(output, "b h s d -> b s h d ")
+                start_idx = cu_seqlens[i - 1]
+                end_idx = cu_seqlens[i]
+                q_i = q[:, start_idx:end_idx]
+                k_i = k[:, start_idx:end_idx]
+                v_i = v[:, start_idx:end_idx]
+                q_i, k_i, v_i = (rearrange(x, "b s h d -> b h s d")
+                                 for x in [q_i, k_i, v_i])
+                output_i = F.scaled_dot_product_attention(q_i,
+                                                          k_i,
+                                                          v_i,
+                                                          dropout_p=0.0)
+                output_i = rearrange(output_i, "b h s d -> b s h d ")
+                outputs.append(output_i)
+            context_layer = torch.cat(outputs, dim=1)
         elif self.attn_backend == _Backend.XFORMERS:
             from xformers import ops as xops
             from xformers.ops.fmha.attn_bias import BlockDiagonalMask
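The TORCH_SDPA rewrite above swaps one masked attention call over the whole packed sequence for an unmasked call per image/video entry, so the [seq_length, seq_length] boolean block-diagonal mask is never materialized. A minimal standalone sketch of that per-entry pattern, assuming packed [1, seq, heads, head_dim] tensors; the segment_sdpa name and the example shapes are illustrative, not part of the diff:

import torch
import torch.nn.functional as F


def segment_sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                 cu_seqlens) -> torch.Tensor:
    """Run SDPA independently per segment of a packed [1, s, h, d] batch.

    cu_seqlens holds cumulative segment boundaries (e.g. [0, 16, 48]),
    mirroring how the vision attention above slices q/k/v.
    """
    outputs = []
    for i in range(1, len(cu_seqlens)):
        start, end = cu_seqlens[i - 1], cu_seqlens[i]
        # Slice one segment and move heads in front of the sequence dim.
        q_i, k_i, v_i = (x[:, start:end].transpose(1, 2) for x in (q, k, v))
        out = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
        outputs.append(out.transpose(1, 2))  # back to [1, s_i, h, d]
    return torch.cat(outputs, dim=1)


# Example: two segments of lengths 16 and 32, 4 heads, head_dim 64.
q = k = v = torch.randn(1, 48, 4, 64)
print(segment_sdpa(q, k, v, [0, 16, 48]).shape)  # torch.Size([1, 48, 4, 64])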
@@ -327,25 +336,6 @@ class Qwen2_5_VisionAttention(nn.Module):
         return output


-class Qwen2RMSNorm(nn.Module):
-
-    def __init__(self, hidden_size, eps=1e-6):
-        super().__init__()
-        self.weight = nn.Parameter(torch.ones(hidden_size))
-        self.variance_epsilon = eps
-
-    def forward(self, hidden_states):
-        input_dtype = hidden_states.dtype
-        hidden_states = hidden_states.to(torch.float32)
-        variance = hidden_states.pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance +
-                                                    self.variance_epsilon)
-        return self.weight * hidden_states.to(input_dtype)
-
-    def extra_repr(self):
-        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
-
-
 class Qwen2_5_VisionBlock(nn.Module):

     def __init__(
@@ -516,8 +506,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
             hidden_size=self.hidden_size,
         )

-        # NOTE: We use torch native RMSNorm here for precision purposes.
-        norm_layer = partial(Qwen2RMSNorm, eps=norm_eps)
+        norm_layer = partial(RMSNorm, eps=norm_eps)
         head_dim = self.hidden_size // self.num_heads
         self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)

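Together, the two hunks above retire the file-local Qwen2RMSNorm and build the vision norms from vLLM's shared RMSNorm layer imported in the first hunk (norm_layer = partial(RMSNorm, eps=norm_eps)). For a numerical sanity check of such a swap, here is a sketch of the removed reference computation; the reference_rms_norm helper and the example shapes are illustrative, not from the diff:

import torch


def reference_rms_norm(x: torch.Tensor, weight: torch.Tensor,
                       eps: float = 1e-6) -> torch.Tensor:
    # Same math as the removed Qwen2RMSNorm.forward: compute in fp32,
    # scale by the learned weight, cast back to the input dtype.
    x_fp32 = x.to(torch.float32)
    variance = x_fp32.pow(2).mean(-1, keepdim=True)
    x_fp32 = x_fp32 * torch.rsqrt(variance + eps)
    return weight * x_fp32.to(x.dtype)


# Compare against whichever RMSNorm implementation replaces it, e.g.
# (assumed usage, not from the diff):
#   norm = RMSNorm(hidden_size, eps=norm_eps)
#   torch.testing.assert_close(norm(x), reference_rms_norm(x, norm.weight))
x = torch.randn(2, 8, 1280, dtype=torch.float16)
w = torch.ones(1280, dtype=torch.float16)
print(reference_rms_norm(x, w).dtype)  # torch.float16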
@@ -226,11 +226,15 @@ def apply_rotary_emb_torch(x: torch.Tensor,


 def apply_rotary_pos_emb_vision(t: torch.Tensor,
-                                freqs: torch.Tensor) -> torch.Tensor:
+                                freqs: torch.Tensor,
+                                use_flash_attn=False) -> torch.Tensor:
     t_ = t.float()
     cos = freqs.cos()
     sin = freqs.sin()
-    output = apply_rotary_emb_torch(t_, cos, sin).type_as(t)
+    apply_rotary_emb = apply_rotary_emb_torch
+    if use_flash_attn:
+        from flash_attn.layers.rotary import apply_rotary_emb
+    output = apply_rotary_emb(t_, cos, sin).type_as(t)
     return output

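apply_rotary_pos_emb_vision now defaults to the torch rotary implementation and switches to flash-attn's fused apply_rotary_emb only when the caller passes use_flash_attn=True, which the attention hunks derive from the FLASH_ATTN backend being selected (so the flash_attn package is present). A standalone sketch of that dispatch with an import guard added for portability; rotate_half_torch is a simplified stand-in for apply_rotary_emb_torch, and the real flash-attn kernel expects its own cos/sin layout, so the shapes below are illustrative only:

import torch


def rotate_half_torch(x: torch.Tensor, cos: torch.Tensor,
                      sin: torch.Tensor) -> torch.Tensor:
    # Simplified stand-in for apply_rotary_emb_torch: non-interleaved
    # half rotation; cos/sin must broadcast against x[..., :d // 2].
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)


def apply_rotary_vision(t: torch.Tensor, freqs: torch.Tensor,
                        use_flash_attn: bool = False) -> torch.Tensor:
    # Same control flow as the updated apply_rotary_pos_emb_vision: keep the
    # torch path by default, pick flash-attn's fused kernel on request
    # (guarded here so the sketch also runs without flash_attn installed).
    t_ = t.float()
    cos, sin = freqs.cos(), freqs.sin()
    rotary_fn = rotate_half_torch
    if use_flash_attn:
        try:
            from flash_attn.layers.rotary import apply_rotary_emb as rotary_fn
        except ImportError:
            pass  # fall back to the torch path
    return rotary_fn(t_, cos, sin).type_as(t)


# [batch, seq, heads, head_dim] query; freqs broadcasts over the head dim.
q = torch.randn(1, 48, 4, 64, dtype=torch.float16)
freqs = torch.randn(48, 1, 32)
print(apply_rotary_vision(q, freqs).shape)  # torch.Size([1, 48, 4, 64])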
@@ -336,20 +340,23 @@ class Qwen2VisionAttention(nn.Module):
                                       "(b s) ... -> b s ...",
                                       b=batch_size)
         elif self.attn_backend == _Backend.TORCH_SDPA:
-            seq_length = q.size(1)
-            q, k, v = (rearrange(x, "b s h d -> b h s d") for x in [q, k, v])
-            attention_mask = torch.zeros([1, seq_length, seq_length],
-                                         device=q.device,
-                                         dtype=torch.bool)
+            # Execute attention entry by entry for speed & less VRAM.
+            outputs = []
             for i in range(1, len(cu_seqlens)):
-                attention_mask[..., cu_seqlens[i - 1]:cu_seqlens[i],
-                               cu_seqlens[i - 1]:cu_seqlens[i]] = True
-            output = F.scaled_dot_product_attention(q,
-                                                    k,
-                                                    v,
-                                                    attention_mask,
-                                                    dropout_p=0.0)
-            context_layer = rearrange(output, "b h s d -> b s h d ")
+                start_idx = cu_seqlens[i - 1]
+                end_idx = cu_seqlens[i]
+                q_i = q[:, start_idx:end_idx]
+                k_i = k[:, start_idx:end_idx]
+                v_i = v[:, start_idx:end_idx]
+                q_i, k_i, v_i = (rearrange(x, "b s h d -> b h s d")
+                                 for x in [q_i, k_i, v_i])
+                output_i = F.scaled_dot_product_attention(q_i,
+                                                          k_i,
+                                                          v_i,
+                                                          dropout_p=0.0)
+                output_i = rearrange(output_i, "b h s d -> b s h d ")
+                outputs.append(output_i)
+            context_layer = torch.cat(outputs, dim=1)
         elif self.attn_backend == _Backend.XFORMERS:
             from xformers import ops as xops
             from xformers.ops.fmha.attn_bias import BlockDiagonalMask