mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[Model] Consolidate ViTs attention implementation without mask (#10893)
Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm.attention import AttentionMetadata, AttentionType
|
||||
from vllm.attention.selector import backend_name_to_enum, get_attn_backend
|
||||
@ -168,6 +169,68 @@ class Attention(nn.Module):
|
||||
return s
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
"""Multi-headed attention without any cache, used for ViT."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = scale
|
||||
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
|
||||
|
||||
dtype = torch.get_default_dtype()
|
||||
attn_backend = get_attn_backend(head_size,
|
||||
dtype,
|
||||
kv_cache_dtype=None,
|
||||
block_size=16,
|
||||
is_attention_free=False)
|
||||
if attn_backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
|
||||
attn_backend = _Backend.XFORMERS
|
||||
|
||||
self.attn_backend = attn_backend if attn_backend in {
|
||||
_Backend.TORCH_SDPA, _Backend.XFORMERS
|
||||
} else _Backend.TORCH_SDPA
|
||||
|
||||
def forward(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Input shape: batch_size x seq_len x hidden_size"""
|
||||
# TODO(Isotr0py): Use existing backend implementations and support FA2
|
||||
bsz, q_len, _ = query.size()
|
||||
kv_len = key.size(1)
|
||||
|
||||
query = query.view(bsz, q_len, self.num_heads, self.head_size)
|
||||
key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
|
||||
value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)
|
||||
|
||||
if self.attn_backend == _Backend.XFORMERS:
|
||||
from xformers import ops as xops
|
||||
|
||||
out = xops.memory_efficient_attention_forward(query,
|
||||
key,
|
||||
value,
|
||||
scale=self.scale)
|
||||
elif self.attn_backend == _Backend.TORCH_SDPA:
|
||||
query, key, value = (x.transpose(1, 2)
|
||||
for x in (query, key, value))
|
||||
out = F.scaled_dot_product_attention(query,
|
||||
key,
|
||||
value,
|
||||
scale=self.scale)
|
||||
out = out.transpose(1, 2)
|
||||
return out.view(bsz, q_len, -1)
|
||||
|
||||
|
||||
def unified_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
|
@ -4,11 +4,10 @@ from typing import Iterable, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from PIL import Image
|
||||
from transformers import Blip2VisionConfig, BlipVisionConfig
|
||||
|
||||
from vllm.attention.selector import _Backend
|
||||
from vllm.attention.layer import MultiHeadAttention
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
||||
from vllm.inputs import DecoderOnlyInputs, token_inputs
|
||||
@ -22,8 +21,6 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
|
||||
repeat_and_pad_placeholder_tokens)
|
||||
from vllm.sequence import SequenceData
|
||||
|
||||
from .utils import get_vit_attn_backend
|
||||
|
||||
|
||||
def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
|
||||
assert image_size % patch_size == 0
|
||||
@ -205,11 +202,8 @@ class BlipAttention(nn.Module):
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
|
||||
|
||||
# Detect attention implementation.
|
||||
self.attn_backend = get_vit_attn_backend(support_fa=False)
|
||||
if self.attn_backend not in {_Backend.TORCH_SDPA, _Backend.XFORMERS}:
|
||||
raise RuntimeError(
|
||||
f"BLIP does not support {self.attn_backend} backend now.")
|
||||
self.attn = MultiHeadAttention(self.num_heads_per_partition,
|
||||
self.head_dim, self.scale)
|
||||
|
||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||
return tensor.view(bsz, seq_len, self.num_heads,
|
||||
@ -220,41 +214,10 @@ class BlipAttention(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
):
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
bsz, tgt_len, _ = hidden_states.size()
|
||||
|
||||
qkv_states, _ = self.qkv(hidden_states)
|
||||
query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)
|
||||
query_states = query_states.view(bsz, tgt_len,
|
||||
self.num_heads_per_partition,
|
||||
self.head_dim)
|
||||
key_states = key_states.view(bsz, tgt_len,
|
||||
self.num_heads_per_partition,
|
||||
self.head_dim)
|
||||
value_states = value_states.view(bsz, tgt_len,
|
||||
self.num_heads_per_partition,
|
||||
self.head_dim)
|
||||
|
||||
if self.attn_backend == _Backend.XFORMERS:
|
||||
from xformers import ops as xops
|
||||
|
||||
out = xops.memory_efficient_attention_forward(query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
p=self.dropout,
|
||||
scale=self.scale)
|
||||
elif self.attn_backend == _Backend.TORCH_SDPA:
|
||||
query_states, key_states, value_states = (x.transpose(1, 2)
|
||||
for x in (query_states,
|
||||
key_states,
|
||||
value_states))
|
||||
out = F.scaled_dot_product_attention(query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
dropout_p=self.dropout,
|
||||
scale=self.scale)
|
||||
out = out.transpose(1, 2)
|
||||
|
||||
out = out.view(bsz, tgt_len, -1)
|
||||
out = self.attn(query_states, key_states, value_states)
|
||||
attn_output, _ = self.projection(out)
|
||||
|
||||
return attn_output, None
|
||||
|
@ -5,11 +5,10 @@ from typing import Iterable, List, Optional, Set, Tuple, Union
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from PIL import Image
|
||||
from transformers import CLIPVisionConfig
|
||||
|
||||
from vllm.attention.selector import _Backend
|
||||
from vllm.attention.layer import MultiHeadAttention
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
||||
from vllm.inputs import DecoderOnlyInputs, token_inputs
|
||||
@ -25,8 +24,6 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
|
||||
resolve_visual_encoder_outputs)
|
||||
from vllm.sequence import SequenceData
|
||||
|
||||
from .utils import get_vit_attn_backend
|
||||
|
||||
|
||||
def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
|
||||
assert image_size % patch_size == 0
|
||||
@ -235,11 +232,8 @@ class CLIPAttention(nn.Module):
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
|
||||
|
||||
# Detect attention implementation.
|
||||
self.attn_backend = get_vit_attn_backend(support_fa=False)
|
||||
if self.attn_backend not in {_Backend.TORCH_SDPA, _Backend.XFORMERS}:
|
||||
raise RuntimeError(
|
||||
f"CLIP does not support {self.attn_backend} backend now.")
|
||||
self.attn = MultiHeadAttention(self.num_heads_per_partition,
|
||||
self.head_dim, self.scale)
|
||||
|
||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||
return tensor.view(bsz, seq_len, self.num_heads,
|
||||
@ -250,42 +244,10 @@ class CLIPAttention(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
):
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
bsz, tgt_len, _ = hidden_states.size()
|
||||
|
||||
qkv_states, _ = self.qkv_proj(hidden_states)
|
||||
query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)
|
||||
|
||||
query_states = query_states.view(bsz, tgt_len,
|
||||
self.num_heads_per_partition,
|
||||
self.head_dim)
|
||||
key_states = key_states.view(bsz, tgt_len,
|
||||
self.num_heads_per_partition,
|
||||
self.head_dim)
|
||||
value_states = value_states.view(bsz, tgt_len,
|
||||
self.num_heads_per_partition,
|
||||
self.head_dim)
|
||||
|
||||
if self.attn_backend == _Backend.XFORMERS:
|
||||
from xformers import ops as xops
|
||||
|
||||
out = xops.memory_efficient_attention_forward(query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
p=self.dropout,
|
||||
scale=self.scale)
|
||||
elif self.attn_backend == _Backend.TORCH_SDPA:
|
||||
query_states, key_states, value_states = (x.transpose(1, 2)
|
||||
for x in (query_states,
|
||||
key_states,
|
||||
value_states))
|
||||
out = F.scaled_dot_product_attention(query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
dropout_p=self.dropout,
|
||||
scale=self.scale)
|
||||
out = out.transpose(1, 2)
|
||||
|
||||
out = out.view(bsz, tgt_len, -1)
|
||||
out = self.attn(query_states, key_states, value_states)
|
||||
attn_output, _ = self.out_proj(out)
|
||||
|
||||
return attn_output, None
|
||||
|
@ -8,6 +8,7 @@ import torch
|
||||
from torch import nn
|
||||
from torch.nn import LayerNorm
|
||||
|
||||
from vllm.attention.layer import MultiHeadAttention
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
@ -77,27 +78,16 @@ class Attention(nn.Module):
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
self.attn = MultiHeadAttention(self.num_heads_per_rank, self.head_dim,
|
||||
self.scale)
|
||||
self.output_dropout = torch.nn.Dropout(config.dropout_prob)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
B, L, _ = x.shape
|
||||
qkv, _ = self.query_key_value(x) # B, L, 3 * H * D
|
||||
q, k, v = qkv.chunk(3, dim=-1)
|
||||
q = q.reshape(B, L, self.num_heads_per_rank,
|
||||
self.head_dim).permute(0, 2, 1, 3) # B, H, L, D
|
||||
k = k.reshape(B, L, self.num_heads_per_rank,
|
||||
self.head_dim).permute(0, 2, 1, 3) # B, H, L, D
|
||||
v = v.reshape(B, L, self.num_heads_per_rank,
|
||||
self.head_dim).permute(0, 2, 1, 3) # B, H, L, D
|
||||
|
||||
out = torch.nn.functional.scaled_dot_product_attention(q,
|
||||
k,
|
||||
v,
|
||||
attn_mask=None,
|
||||
dropout_p=0.,
|
||||
is_causal=False)
|
||||
|
||||
output, _ = self.dense(out.transpose(1, 2).view(B, L, -1))
|
||||
out = self.attn(q, k, v)
|
||||
output, _ = self.dense(out)
|
||||
output = self.output_dropout(output)
|
||||
return output
|
||||
|
||||
|
@ -21,8 +21,8 @@ import torch
|
||||
from torch import nn
|
||||
from transformers.models.idefics2.configuration_idefics2 import (
|
||||
Idefics2Config, Idefics2VisionConfig)
|
||||
from xformers import ops as xops
|
||||
|
||||
from vllm.attention.layer import MultiHeadAttention
|
||||
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
@ -141,35 +141,18 @@ class Idefics2VisionAttention(nn.Module):
|
||||
)
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
|
||||
self.is_causal = False
|
||||
self.attn = MultiHeadAttention(self.num_heads_per_partition,
|
||||
self.head_dim, self.scale)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
batch_size, q_len, _ = hidden_states.size()
|
||||
qkv, _ = self.qkv_proj(
|
||||
hidden_states
|
||||
) # batch_size, q_len, 3 * num_heads_per_partition * head_dim
|
||||
query_states, key_states, value_states = qkv.chunk(3, dim=-1)
|
||||
query_states = query_states.view(batch_size, q_len,
|
||||
self.num_heads_per_partition,
|
||||
self.head_dim)
|
||||
key_states = key_states.view(batch_size, q_len,
|
||||
self.num_heads_per_partition,
|
||||
self.head_dim)
|
||||
value_states = value_states.view(batch_size, q_len,
|
||||
self.num_heads_per_partition,
|
||||
self.head_dim)
|
||||
# see: https://facebookresearch.github.io/xformers/components/ops.html
|
||||
out = xops.memory_efficient_attention_forward(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
p=self.dropout,
|
||||
scale=self.scale,
|
||||
)
|
||||
out = out.view(batch_size, q_len, -1)
|
||||
out = self.attn(query_states, key_states, value_states)
|
||||
attn_output, _ = self.out_proj(out)
|
||||
return attn_output
|
||||
|
||||
|
@ -12,7 +12,7 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.attention.selector import _Backend
|
||||
from vllm.attention.layer import MultiHeadAttention
|
||||
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
split_tensor_along_last_dim,
|
||||
@ -25,8 +25,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
from .utils import get_vit_attn_backend
|
||||
|
||||
NORM2FN = {
|
||||
'rms_norm': RMSNorm,
|
||||
'layer_norm': nn.LayerNorm,
|
||||
@ -183,10 +181,8 @@ class InternParallelAttention(nn.Module):
|
||||
prefix=f"{prefix}.proj",
|
||||
)
|
||||
|
||||
self.attn_backend = get_vit_attn_backend(support_fa=False)
|
||||
if self.attn_backend not in {_Backend.TORCH_SDPA, _Backend.XFORMERS}:
|
||||
raise RuntimeError(
|
||||
f"InternViT does not support {self.attn_backend} backend now.")
|
||||
self.attn = MultiHeadAttention(self.num_heads_per_partition,
|
||||
self.head_dim, self.scale)
|
||||
|
||||
def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor):
|
||||
if self.tp_size > 1:
|
||||
@ -209,23 +205,7 @@ class InternParallelAttention(nn.Module):
|
||||
if self.qk_normalization:
|
||||
q, k = self._apply_qk_norm(q, k)
|
||||
|
||||
q = q.view(B, N, self.num_heads_per_partition, self.head_dim)
|
||||
k = k.view(B, N, self.num_heads_per_partition, self.head_dim)
|
||||
v = v.view(B, N, self.num_heads_per_partition, self.head_dim)
|
||||
|
||||
if self.attn_backend == _Backend.XFORMERS:
|
||||
from xformers import ops as xops
|
||||
|
||||
out = xops.memory_efficient_attention_forward(q,
|
||||
k,
|
||||
v,
|
||||
scale=self.scale)
|
||||
elif self.attn_backend == _Backend.TORCH_SDPA:
|
||||
q, k, v = (x.transpose(1, 2) for x in (q, k, v))
|
||||
out = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
|
||||
out = out.transpose(1, 2)
|
||||
|
||||
out = out.view(B, N, -1)
|
||||
out = self.attn(q, k, v)
|
||||
out, _ = self.proj(out)
|
||||
return out
|
||||
|
||||
|
@ -482,6 +482,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
self.mlp1 = self._init_mlp1(config)
|
||||
|
||||
self.img_context_token_id = None
|
||||
self.visual_token_mask = None
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.language_model.make_empty_intermediate_tensors)
|
||||
|
||||
@ -635,13 +636,12 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
|
||||
return image_embeds
|
||||
|
||||
def _get_visual_token_mask(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
def _set_visual_token_mask(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
if self.is_mono:
|
||||
visual_token_mask = (
|
||||
self.visual_token_mask = (
|
||||
input_ids == self.img_context_token_id).reshape(-1, 1)
|
||||
else:
|
||||
visual_token_mask = None
|
||||
return visual_token_mask
|
||||
self.visual_token_mask = None
|
||||
|
||||
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
@ -658,6 +658,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
||||
if multimodal_embeddings is not None:
|
||||
assert self.img_context_token_id is not None
|
||||
self._set_visual_token_mask(input_ids)
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids, inputs_embeds, multimodal_embeddings,
|
||||
self.img_context_token_id)
|
||||
@ -674,7 +675,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
**kwargs: object,
|
||||
) -> Union[SamplerOutput, IntermediateTensors]:
|
||||
|
||||
visual_token_mask = None
|
||||
if intermediate_tensors is not None:
|
||||
input_ids = None
|
||||
inputs_embeds = None
|
||||
@ -695,16 +695,15 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
"intermediate_tensors": intermediate_tensors,
|
||||
"inputs_embeds": inputs_embeds,
|
||||
}
|
||||
if self.img_context_token_id is not None:
|
||||
visual_token_mask = self._get_visual_token_mask(input_ids)
|
||||
|
||||
# We always overwrite it back to None after computing visual token
|
||||
# mask so that this doesn't need to depend on encoder output
|
||||
if self.visual_token_mask is not None:
|
||||
# overwrite visual_token_mask and img_context_token_id back to None,
|
||||
# so that this doesn't need to depend on encoder output
|
||||
forward_kwargs.update(
|
||||
{"visual_token_mask": self.visual_token_mask})
|
||||
self.visual_token_mask = None
|
||||
self.img_context_token_id = None
|
||||
|
||||
if self.is_mono:
|
||||
forward_kwargs.update({"visual_token_mask": visual_token_mask})
|
||||
|
||||
hidden_states = self.language_model.model(**forward_kwargs)
|
||||
return hidden_states
|
||||
|
||||
|
@ -13,6 +13,7 @@ from torch.nn import functional as F
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.attention.layer import MultiHeadAttention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
||||
@ -38,14 +39,12 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
|
||||
from vllm.multimodal.inputs import NestedTensors
|
||||
from vllm.multimodal.utils import cached_get_tokenizer
|
||||
from vllm.platforms import _Backend
|
||||
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
|
||||
SequenceData)
|
||||
from vllm.transformers_utils.processor import get_processor
|
||||
|
||||
from .interfaces import SupportsMultiModal, SupportsPP
|
||||
from .utils import (AutoWeightsLoader, WeightsMapper, get_vit_attn_backend,
|
||||
is_pp_missing_parameter,
|
||||
from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
|
||||
@ -188,13 +187,11 @@ class MultiHeadDotProductAttention(nn.Module):
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
# Detect attention implementation.
|
||||
self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
|
||||
if self.attn_backend not in {
|
||||
_Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS
|
||||
}:
|
||||
raise RuntimeError(
|
||||
f"Molmo does not support {self.attn_backend} backend now.")
|
||||
self.scale = self.head_dim**-0.5
|
||||
self.attn = MultiHeadAttention(self.num_heads,
|
||||
self.head_dim,
|
||||
self.scale,
|
||||
num_kv_heads=self.num_kv_heads)
|
||||
|
||||
def forward(self,
|
||||
inputs_q: torch.Tensor,
|
||||
@ -210,25 +207,8 @@ class MultiHeadDotProductAttention(nn.Module):
|
||||
xq, _ = self.wq(inputs_q)
|
||||
xk, _ = self.wk(inputs_k)
|
||||
xv, _ = self.wv(inputs_v)
|
||||
q_shape = xq.size()[:-1] + (self.num_heads, self.head_dim)
|
||||
kv_shape = xk.size()[:-1] + (self.num_kv_heads, self.head_dim)
|
||||
xq = xq.view(*q_shape)
|
||||
xk = xk.view(*kv_shape)
|
||||
xv = xv.view(*kv_shape)
|
||||
|
||||
if self.attn_backend == _Backend.FLASH_ATTN:
|
||||
from flash_attn import flash_attn_func
|
||||
output = flash_attn_func(xq, xk, xv, dropout_p=0.0, causal=False)
|
||||
elif self.attn_backend == _Backend.TORCH_SDPA:
|
||||
xq, xk, xv = (rearrange(x, "b s h d -> b h s d")
|
||||
for x in (xq, xk, xv))
|
||||
output = F.scaled_dot_product_attention(xq, xk, xv)
|
||||
output = rearrange(output, "b h s d -> b s h d ")
|
||||
elif self.attn_backend == _Backend.XFORMERS:
|
||||
from xformers import ops as xops
|
||||
output = xops.memory_efficient_attention_forward(xq, xk, xv, p=0)
|
||||
|
||||
output = rearrange(output, "b s h d -> b s (h d)").contiguous()
|
||||
output = self.attn(xq, xk, xv)
|
||||
output, _ = self.wo(output)
|
||||
|
||||
return output
|
||||
|
@ -6,12 +6,11 @@ from typing import Iterable, List, Optional, Set, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from PIL import Image
|
||||
from torch import nn
|
||||
from transformers import SiglipVisionConfig
|
||||
|
||||
from vllm.attention.selector import _Backend
|
||||
from vllm.attention.layer import MultiHeadAttention
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
||||
from vllm.inputs import DecoderOnlyInputs, token_inputs
|
||||
@ -29,8 +28,6 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
|
||||
resolve_visual_encoder_outputs)
|
||||
from vllm.sequence import SequenceData
|
||||
|
||||
from .utils import get_vit_attn_backend
|
||||
|
||||
|
||||
def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
|
||||
# Since interpolation is applied, the image size need not be divisible
|
||||
@ -291,52 +288,18 @@ class SiglipAttention(nn.Module):
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
|
||||
|
||||
self.attn_backend = get_vit_attn_backend(support_fa=False)
|
||||
if self.attn_backend not in {_Backend.TORCH_SDPA, _Backend.XFORMERS}:
|
||||
raise RuntimeError(
|
||||
f"SIGLIP does not support {self.attn_backend} backend now.")
|
||||
self.attn = MultiHeadAttention(self.num_heads_per_partition,
|
||||
self.head_dim, self.scale)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
batch_size, q_len, _ = hidden_states.size()
|
||||
|
||||
qkv_states, _ = self.qkv_proj(hidden_states)
|
||||
query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)
|
||||
|
||||
query_states = query_states.view(batch_size, q_len,
|
||||
self.num_heads_per_partition,
|
||||
self.head_dim)
|
||||
key_states = key_states.view(batch_size, q_len,
|
||||
self.num_heads_per_partition,
|
||||
self.head_dim)
|
||||
value_states = value_states.view(batch_size, q_len,
|
||||
self.num_heads_per_partition,
|
||||
self.head_dim)
|
||||
|
||||
if self.attn_backend == _Backend.XFORMERS:
|
||||
from xformers import ops as xops
|
||||
|
||||
out = xops.memory_efficient_attention_forward(query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
p=self.dropout,
|
||||
scale=self.scale)
|
||||
elif self.attn_backend == _Backend.TORCH_SDPA:
|
||||
query_states, key_states, value_states = (x.transpose(1, 2)
|
||||
for x in (query_states,
|
||||
key_states,
|
||||
value_states))
|
||||
out = F.scaled_dot_product_attention(query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
dropout_p=self.dropout,
|
||||
scale=self.scale)
|
||||
out = out.transpose(1, 2)
|
||||
|
||||
out = out.view(batch_size, q_len, -1)
|
||||
out = self.attn(query_states, key_states, value_states)
|
||||
attn_output, _ = self.out_proj(out)
|
||||
|
||||
return attn_output, None
|
||||
|
Reference in New Issue
Block a user