[Model] Consolidate ViTs attention implementation without mask (#10893)

Signed-off-by: Isotr0py <2037008807@qq.com>
Author:     Isotr0py
Date:       2024-12-05 02:11:08 +08:00
Committed:  GitHub
Parent:     01d079fd8e
Commit:     10398b4706

9 changed files with 107 additions and 224 deletions
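All of the vision-encoder diffs below converge on the same wiring: keep the model's own QKV and output projections, chunk the fused QKV output into flat (batch, seq, hidden) tensors, and hand them to the shared MultiHeadAttention layer instead of dispatching to xFormers or torch SDPA per model. A minimal sketch of that post-refactor pattern; the ToyViTAttention name, the plain nn.Linear projections, and the sizes are illustrative rather than taken from the commit, and running the forward end to end assumes a vLLM installation where a ViT attention backend resolves:

import torch
import torch.nn as nn

from vllm.attention.layer import MultiHeadAttention  # layer added by this commit


class ToyViTAttention(nn.Module):
    """Illustrative post-refactor ViT attention block (toy sizes)."""

    def __init__(self, hidden_size: int = 512, num_heads: int = 8):
        super().__init__()
        head_dim = hidden_size // num_heads
        # Stand-ins for the QKVParallelLinear / RowParallelLinear used in vLLM.
        self.qkv = nn.Linear(hidden_size, 3 * hidden_size)
        self.proj = nn.Linear(hidden_size, hidden_size)
        # Backend choice (xFormers vs. torch SDPA) now lives inside this layer.
        self.attn = MultiHeadAttention(num_heads, head_dim, scale=head_dim**-0.5)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Flat (batch, seq, num_heads * head_dim) tensors; no per-model reshapes.
        q, k, v = self.qkv(hidden_states).chunk(3, dim=-1)
        return self.proj(self.attn(q, k, v))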

View File

@@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from vllm.attention import AttentionMetadata, AttentionType
from vllm.attention.selector import backend_name_to_enum, get_attn_backend
@@ -168,6 +169,68 @@ class Attention(nn.Module):
return s
class MultiHeadAttention(nn.Module):
"""Multi-headed attention without any cache, used for ViT."""
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: Optional[int] = None,
):
super().__init__()
self.num_heads = num_heads
self.head_size = head_size
self.scale = scale
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
dtype = torch.get_default_dtype()
attn_backend = get_attn_backend(head_size,
dtype,
kv_cache_dtype=None,
block_size=16,
is_attention_free=False)
if attn_backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
attn_backend = _Backend.XFORMERS
self.attn_backend = attn_backend if attn_backend in {
_Backend.TORCH_SDPA, _Backend.XFORMERS
} else _Backend.TORCH_SDPA
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
) -> torch.Tensor:
"""Input shape: batch_size x seq_len x hidden_size"""
# TODO(Isotr0py): Use existing backend implementations and support FA2
bsz, q_len, _ = query.size()
kv_len = key.size(1)
query = query.view(bsz, q_len, self.num_heads, self.head_size)
key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)
if self.attn_backend == _Backend.XFORMERS:
from xformers import ops as xops
out = xops.memory_efficient_attention_forward(query,
key,
value,
scale=self.scale)
elif self.attn_backend == _Backend.TORCH_SDPA:
query, key, value = (x.transpose(1, 2)
for x in (query, key, value))
out = F.scaled_dot_product_attention(query,
key,
value,
scale=self.scale)
out = out.transpose(1, 2)
return out.view(bsz, q_len, -1)
def unified_attention(
query: torch.Tensor,
key: torch.Tensor,

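For reference, this is what the TORCH_SDPA branch of the new forward computes on those flat inputs: split the hidden dimension into heads, run mask-free scaled dot-product attention per head, and merge the heads back. A self-contained check against an explicit per-head softmax (pure PyTorch; the sizes here are arbitrary):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
bsz, seq_len, num_heads, head_size = 2, 6, 4, 8
scale = head_size ** -0.5

# Flat (batch, seq, hidden) inputs, as the ViT callers now pass them.
query, key, value = (torch.randn(bsz, seq_len, num_heads * head_size)
                     for _ in range(3))

# The TORCH_SDPA path: view into heads, transpose, SDPA, merge heads back.
q, k, v = (x.view(bsz, seq_len, num_heads, head_size).transpose(1, 2)
           for x in (query, key, value))
out = F.scaled_dot_product_attention(q, k, v, scale=scale)
out = out.transpose(1, 2).reshape(bsz, seq_len, -1)

# Explicit per-head reference: softmax(q k^T * scale) v, no mask, not causal.
ref_heads = []
for h in range(num_heads):
    qh, kh, vh = q[:, h], k[:, h], v[:, h]            # (bsz, seq_len, head_size)
    weights = torch.softmax(qh @ kh.transpose(-1, -2) * scale, dim=-1)
    ref_heads.append(weights @ vh)
ref = torch.stack(ref_heads, dim=2).reshape(bsz, seq_len, -1)

torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-4)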
View File

@@ -4,11 +4,10 @@ from typing import Iterable, Optional, Set, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from transformers import Blip2VisionConfig, BlipVisionConfig
from vllm.attention.selector import _Backend
from vllm.attention.layer import MultiHeadAttention
from vllm.config import ModelConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.inputs import DecoderOnlyInputs, token_inputs
@@ -22,8 +21,6 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
repeat_and_pad_placeholder_tokens)
from vllm.sequence import SequenceData
from .utils import get_vit_attn_backend
def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
assert image_size % patch_size == 0
@@ -205,11 +202,8 @@ class BlipAttention(nn.Module):
self.tp_size = get_tensor_model_parallel_world_size()
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
# Detect attention implementation.
self.attn_backend = get_vit_attn_backend(support_fa=False)
if self.attn_backend not in {_Backend.TORCH_SDPA, _Backend.XFORMERS}:
raise RuntimeError(
f"BLIP does not support {self.attn_backend} backend now.")
self.attn = MultiHeadAttention(self.num_heads_per_partition,
self.head_dim, self.scale)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads,
@@ -220,41 +214,10 @@ class BlipAttention(nn.Module):
hidden_states: torch.Tensor,
):
"""Input shape: Batch x Time x Channel"""
bsz, tgt_len, _ = hidden_states.size()
qkv_states, _ = self.qkv(hidden_states)
query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)
query_states = query_states.view(bsz, tgt_len,
self.num_heads_per_partition,
self.head_dim)
key_states = key_states.view(bsz, tgt_len,
self.num_heads_per_partition,
self.head_dim)
value_states = value_states.view(bsz, tgt_len,
self.num_heads_per_partition,
self.head_dim)
if self.attn_backend == _Backend.XFORMERS:
from xformers import ops as xops
out = xops.memory_efficient_attention_forward(query_states,
key_states,
value_states,
p=self.dropout,
scale=self.scale)
elif self.attn_backend == _Backend.TORCH_SDPA:
query_states, key_states, value_states = (x.transpose(1, 2)
for x in (query_states,
key_states,
value_states))
out = F.scaled_dot_product_attention(query_states,
key_states,
value_states,
dropout_p=self.dropout,
scale=self.scale)
out = out.transpose(1, 2)
out = out.view(bsz, tgt_len, -1)
out = self.attn(query_states, key_states, value_states)
attn_output, _ = self.projection(out)
return attn_output, None
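Note that under tensor parallelism each rank constructs the shared layer with only its local share of heads, so MultiHeadAttention itself stays TP-agnostic. A worked example with illustrative numbers (not BLIP's actual configuration):

# Illustrative sizes, not BLIP's real configuration.
num_heads, head_dim, tp_size = 16, 64, 2

# Mirrors what divide() does: check divisibility, then floor-divide.
assert num_heads % tp_size == 0
num_heads_per_partition = num_heads // tp_size                 # 8 heads on this rank

# Width of the fused QKV projection output on this rank, and the flat
# (batch, seq, hidden) width that MultiHeadAttention sees per rank.
qkv_width_per_rank = 3 * num_heads_per_partition * head_dim    # 1536
local_hidden = num_heads_per_partition * head_dim              # 512

# The shared layer is therefore built with the per-partition head count:
#   MultiHeadAttention(num_heads_per_partition, head_dim, scale=head_dim**-0.5)
print(num_heads_per_partition, qkv_width_per_rank, local_hidden)  # 8 1536 512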

View File

@@ -5,11 +5,10 @@ from typing import Iterable, List, Optional, Set, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from transformers import CLIPVisionConfig
from vllm.attention.selector import _Backend
from vllm.attention.layer import MultiHeadAttention
from vllm.config import ModelConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.inputs import DecoderOnlyInputs, token_inputs
@@ -25,8 +24,6 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
resolve_visual_encoder_outputs)
from vllm.sequence import SequenceData
from .utils import get_vit_attn_backend
def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
assert image_size % patch_size == 0
@@ -235,11 +232,8 @@ class CLIPAttention(nn.Module):
self.tp_size = get_tensor_model_parallel_world_size()
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
# Detect attention implementation.
self.attn_backend = get_vit_attn_backend(support_fa=False)
if self.attn_backend not in {_Backend.TORCH_SDPA, _Backend.XFORMERS}:
raise RuntimeError(
f"CLIP does not support {self.attn_backend} backend now.")
self.attn = MultiHeadAttention(self.num_heads_per_partition,
self.head_dim, self.scale)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads,
@@ -250,42 +244,10 @@ class CLIPAttention(nn.Module):
hidden_states: torch.Tensor,
):
"""Input shape: Batch x Time x Channel"""
bsz, tgt_len, _ = hidden_states.size()
qkv_states, _ = self.qkv_proj(hidden_states)
query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)
query_states = query_states.view(bsz, tgt_len,
self.num_heads_per_partition,
self.head_dim)
key_states = key_states.view(bsz, tgt_len,
self.num_heads_per_partition,
self.head_dim)
value_states = value_states.view(bsz, tgt_len,
self.num_heads_per_partition,
self.head_dim)
if self.attn_backend == _Backend.XFORMERS:
from xformers import ops as xops
out = xops.memory_efficient_attention_forward(query_states,
key_states,
value_states,
p=self.dropout,
scale=self.scale)
elif self.attn_backend == _Backend.TORCH_SDPA:
query_states, key_states, value_states = (x.transpose(1, 2)
for x in (query_states,
key_states,
value_states))
out = F.scaled_dot_product_attention(query_states,
key_states,
value_states,
dropout_p=self.dropout,
scale=self.scale)
out = out.transpose(1, 2)
out = out.view(bsz, tgt_len, -1)
out = self.attn(query_states, key_states, value_states)
attn_output, _ = self.out_proj(out)
return attn_output, None

View File

@@ -8,6 +8,7 @@ import torch
from torch import nn
from torch.nn import LayerNorm
from vllm.attention.layer import MultiHeadAttention
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -77,27 +78,16 @@ class Attention(nn.Module):
quant_config=quant_config,
)
self.attn = MultiHeadAttention(self.num_heads_per_rank, self.head_dim,
self.scale)
self.output_dropout = torch.nn.Dropout(config.dropout_prob)
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, L, _ = x.shape
qkv, _ = self.query_key_value(x) # B, L, 3 * H * D
q, k, v = qkv.chunk(3, dim=-1)
q = q.reshape(B, L, self.num_heads_per_rank,
self.head_dim).permute(0, 2, 1, 3) # B, H, L, D
k = k.reshape(B, L, self.num_heads_per_rank,
self.head_dim).permute(0, 2, 1, 3) # B, H, L, D
v = v.reshape(B, L, self.num_heads_per_rank,
self.head_dim).permute(0, 2, 1, 3) # B, H, L, D
out = torch.nn.functional.scaled_dot_product_attention(q,
k,
v,
attn_mask=None,
dropout_p=0.,
is_causal=False)
output, _ = self.dense(out.transpose(1, 2).view(B, L, -1))
out = self.attn(q, k, v)
output, _ = self.dense(out)
output = self.output_dropout(output)
return output

View File

@@ -21,8 +21,8 @@ import torch
from torch import nn
from transformers.models.idefics2.configuration_idefics2 import (
Idefics2Config, Idefics2VisionConfig)
from xformers import ops as xops
from vllm.attention.layer import MultiHeadAttention
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -141,35 +141,18 @@ class Idefics2VisionAttention(nn.Module):
)
self.tp_size = get_tensor_model_parallel_world_size()
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
self.is_causal = False
self.attn = MultiHeadAttention(self.num_heads_per_partition,
self.head_dim, self.scale)
def forward(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor:
batch_size, q_len, _ = hidden_states.size()
qkv, _ = self.qkv_proj(
hidden_states
) # batch_size, q_len, 3 * num_heads_per_partition * head_dim
query_states, key_states, value_states = qkv.chunk(3, dim=-1)
query_states = query_states.view(batch_size, q_len,
self.num_heads_per_partition,
self.head_dim)
key_states = key_states.view(batch_size, q_len,
self.num_heads_per_partition,
self.head_dim)
value_states = value_states.view(batch_size, q_len,
self.num_heads_per_partition,
self.head_dim)
# see: https://facebookresearch.github.io/xformers/components/ops.html
out = xops.memory_efficient_attention_forward(
query_states,
key_states,
value_states,
p=self.dropout,
scale=self.scale,
)
out = out.view(batch_size, q_len, -1)
out = self.attn(query_states, key_states, value_states)
attn_output, _ = self.out_proj(out)
return attn_output
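Idefics2 previously imported xFormers at module import time; with the shared layer, xFormers is only imported lazily inside the forward pass when that backend is actually selected, and torch SDPA is used otherwise. A minimal sketch of that pattern (illustrative helper, not vLLM's selector logic; inputs are (batch, seq, heads, head_dim)):

import torch
import torch.nn.functional as F


def vit_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                  scale: float, use_xformers: bool) -> torch.Tensor:
    """Mask-free attention on (batch, seq, heads, head_dim) tensors."""
    if use_xformers:
        # Imported lazily, only when this backend is actually used.
        from xformers import ops as xops
        return xops.memory_efficient_attention_forward(q, k, v, scale=scale)
    # SDPA expects (batch, heads, seq, head_dim); transpose in and out.
    q, k, v = (x.transpose(1, 2) for x in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v, scale=scale)
    return out.transpose(1, 2)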

View File

@@ -12,7 +12,7 @@ import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig
from vllm.attention.selector import _Backend
from vllm.attention.layer import MultiHeadAttention
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
split_tensor_along_last_dim,
@@ -25,8 +25,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from .utils import get_vit_attn_backend
NORM2FN = {
'rms_norm': RMSNorm,
'layer_norm': nn.LayerNorm,
@@ -183,10 +181,8 @@ class InternParallelAttention(nn.Module):
prefix=f"{prefix}.proj",
)
self.attn_backend = get_vit_attn_backend(support_fa=False)
if self.attn_backend not in {_Backend.TORCH_SDPA, _Backend.XFORMERS}:
raise RuntimeError(
f"InternViT does not support {self.attn_backend} backend now.")
self.attn = MultiHeadAttention(self.num_heads_per_partition,
self.head_dim, self.scale)
def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor):
if self.tp_size > 1:
@@ -209,23 +205,7 @@ class InternParallelAttention(nn.Module):
if self.qk_normalization:
q, k = self._apply_qk_norm(q, k)
q = q.view(B, N, self.num_heads_per_partition, self.head_dim)
k = k.view(B, N, self.num_heads_per_partition, self.head_dim)
v = v.view(B, N, self.num_heads_per_partition, self.head_dim)
if self.attn_backend == _Backend.XFORMERS:
from xformers import ops as xops
out = xops.memory_efficient_attention_forward(q,
k,
v,
scale=self.scale)
elif self.attn_backend == _Backend.TORCH_SDPA:
q, k, v = (x.transpose(1, 2) for x in (q, k, v))
out = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
out = out.transpose(1, 2)
out = out.view(B, N, -1)
out = self.attn(q, k, v)
out, _ = self.proj(out)
return out

View File

@@ -482,6 +482,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
self.mlp1 = self._init_mlp1(config)
self.img_context_token_id = None
self.visual_token_mask = None
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
@ -635,13 +636,12 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
return image_embeds
def _get_visual_token_mask(self, input_ids: torch.Tensor) -> torch.Tensor:
def _set_visual_token_mask(self, input_ids: torch.Tensor) -> torch.Tensor:
if self.is_mono:
visual_token_mask = (
self.visual_token_mask = (
input_ids == self.img_context_token_id).reshape(-1, 1)
else:
visual_token_mask = None
return visual_token_mask
self.visual_token_mask = None
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
image_input = self._parse_and_validate_image_input(**kwargs)
@ -658,6 +658,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
assert self.img_context_token_id is not None
self._set_visual_token_mask(input_ids)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
self.img_context_token_id)
@ -674,7 +675,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
**kwargs: object,
) -> Union[SamplerOutput, IntermediateTensors]:
visual_token_mask = None
if intermediate_tensors is not None:
input_ids = None
inputs_embeds = None
@ -695,16 +695,15 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
"intermediate_tensors": intermediate_tensors,
"inputs_embeds": inputs_embeds,
}
if self.img_context_token_id is not None:
visual_token_mask = self._get_visual_token_mask(input_ids)
# We always overwrite it back to None after computing visual token
# mask so that this doesn't need to depend on encoder output
if self.visual_token_mask is not None:
# overwrite visual_token_mask and img_context_token_id back to None,
# so that this doesn't need to depend on encoder output
forward_kwargs.update(
{"visual_token_mask": self.visual_token_mask})
self.visual_token_mask = None
self.img_context_token_id = None
if self.is_mono:
forward_kwargs.update({"visual_token_mask": visual_token_mask})
hidden_states = self.language_model.model(**forward_kwargs)
return hidden_states
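For the mono architecture, the visual token mask is now computed in get_input_embeddings via _set_visual_token_mask and stashed on the module until forward consumes and clears it. The mask itself is just an equality test against the image-context token id; a small worked example with made-up ids:

import torch

# Made-up ids for illustration; the real img_context_token_id comes from the
# tokenizer's <IMG_CONTEXT> token.
img_context_token_id = 92546
input_ids = torch.tensor([1, 92546, 92546, 42, 7])

# Mono case: a (num_tokens, 1) boolean mask marking visual-token positions.
visual_token_mask = (input_ids == img_context_token_id).reshape(-1, 1)
print(visual_token_mask.squeeze(-1))
# tensor([False,  True,  True, False, False])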

View File

@@ -13,6 +13,7 @@ from torch.nn import functional as F
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.attention.layer import MultiHeadAttention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
@@ -38,14 +39,12 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.inputs import NestedTensors
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.platforms import _Backend
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
SequenceData)
from vllm.transformers_utils.processor import get_processor
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, get_vit_attn_backend,
is_pp_missing_parameter,
from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
@@ -188,13 +187,11 @@ class MultiHeadDotProductAttention(nn.Module):
quant_config=quant_config,
)
# Detect attention implementation.
self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
if self.attn_backend not in {
_Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS
}:
raise RuntimeError(
f"Molmo does not support {self.attn_backend} backend now.")
self.scale = self.head_dim**-0.5
self.attn = MultiHeadAttention(self.num_heads,
self.head_dim,
self.scale,
num_kv_heads=self.num_kv_heads)
def forward(self,
inputs_q: torch.Tensor,
@ -210,25 +207,8 @@ class MultiHeadDotProductAttention(nn.Module):
xq, _ = self.wq(inputs_q)
xk, _ = self.wk(inputs_k)
xv, _ = self.wv(inputs_v)
q_shape = xq.size()[:-1] + (self.num_heads, self.head_dim)
kv_shape = xk.size()[:-1] + (self.num_kv_heads, self.head_dim)
xq = xq.view(*q_shape)
xk = xk.view(*kv_shape)
xv = xv.view(*kv_shape)
if self.attn_backend == _Backend.FLASH_ATTN:
from flash_attn import flash_attn_func
output = flash_attn_func(xq, xk, xv, dropout_p=0.0, causal=False)
elif self.attn_backend == _Backend.TORCH_SDPA:
xq, xk, xv = (rearrange(x, "b s h d -> b h s d")
for x in (xq, xk, xv))
output = F.scaled_dot_product_attention(xq, xk, xv)
output = rearrange(output, "b h s d -> b s h d ")
elif self.attn_backend == _Backend.XFORMERS:
from xformers import ops as xops
output = xops.memory_efficient_attention_forward(xq, xk, xv, p=0)
output = rearrange(output, "b s h d -> b s (h d)").contiguous()
output = self.attn(xq, xk, xv)
output, _ = self.wo(output)
return output
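Molmo's vision backbone keeps separate wq/wk/wv projections, which is why this call site passes num_kv_heads explicitly; MultiHeadAttention recovers the head layout from the flat projection widths. A shape sketch with illustrative sizes (not Molmo's real configuration):

import torch

bsz, seq, num_heads, num_kv_heads, head_dim = 2, 10, 8, 8, 64

# Flat widths produced by the separate wq / wk / wv projections.
xq = torch.randn(bsz, seq, num_heads * head_dim)      # (2, 10, 512)
xk = torch.randn(bsz, seq, num_kv_heads * head_dim)   # (2, 10, 512)
xv = torch.randn(bsz, seq, num_kv_heads * head_dim)

# MultiHeadAttention.forward recovers the head layout from these widths:
q = xq.view(bsz, seq, num_heads, head_dim)
k = xk.view(bsz, seq, num_kv_heads, head_dim)
v = xv.view(bsz, seq, num_kv_heads, head_dim)
print(q.shape, k.shape, v.shape)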

View File

@@ -6,12 +6,11 @@ from typing import Iterable, List, Optional, Set, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torch import nn
from transformers import SiglipVisionConfig
from vllm.attention.selector import _Backend
from vllm.attention.layer import MultiHeadAttention
from vllm.config import ModelConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.inputs import DecoderOnlyInputs, token_inputs
@@ -29,8 +28,6 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
resolve_visual_encoder_outputs)
from vllm.sequence import SequenceData
from .utils import get_vit_attn_backend
def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
# Since interpolation is applied, the image size need not be divisible
@@ -291,52 +288,18 @@ class SiglipAttention(nn.Module):
self.tp_size = get_tensor_model_parallel_world_size()
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
self.attn_backend = get_vit_attn_backend(support_fa=False)
if self.attn_backend not in {_Backend.TORCH_SDPA, _Backend.XFORMERS}:
raise RuntimeError(
f"SIGLIP does not support {self.attn_backend} backend now.")
self.attn = MultiHeadAttention(self.num_heads_per_partition,
self.head_dim, self.scale)
def forward(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor:
"""Input shape: Batch x Time x Channel"""
batch_size, q_len, _ = hidden_states.size()
qkv_states, _ = self.qkv_proj(hidden_states)
query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)
query_states = query_states.view(batch_size, q_len,
self.num_heads_per_partition,
self.head_dim)
key_states = key_states.view(batch_size, q_len,
self.num_heads_per_partition,
self.head_dim)
value_states = value_states.view(batch_size, q_len,
self.num_heads_per_partition,
self.head_dim)
if self.attn_backend == _Backend.XFORMERS:
from xformers import ops as xops
out = xops.memory_efficient_attention_forward(query_states,
key_states,
value_states,
p=self.dropout,
scale=self.scale)
elif self.attn_backend == _Backend.TORCH_SDPA:
query_states, key_states, value_states = (x.transpose(1, 2)
for x in (query_states,
key_states,
value_states))
out = F.scaled_dot_product_attention(query_states,
key_states,
value_states,
dropout_p=self.dropout,
scale=self.scale)
out = out.transpose(1, 2)
out = out.view(batch_size, q_len, -1)
out = self.attn(query_states, key_states, value_states)
attn_output, _ = self.out_proj(out)
return attn_output, None
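Since BLIP, CLIP, and SigLIP now all route through the same shared layer, a quick sanity check is that the two backends it can select, xFormers and torch SDPA, agree on mask-free ViT attention. This check assumes a CUDA device with xFormers installed; the sizes are arbitrary:

import torch
import torch.nn.functional as F
from xformers import ops as xops

bsz, seq, heads, dim = 2, 197, 12, 64
scale = dim ** -0.5
q, k, v = (torch.randn(bsz, seq, heads, dim, device="cuda", dtype=torch.float16)
           for _ in range(3))

# xFormers path: (batch, seq, heads, head_dim) layout, no mask.
out_xf = xops.memory_efficient_attention_forward(q, k, v, scale=scale)

# SDPA path: transpose to (batch, heads, seq, head_dim) and back.
qs, ks, vs = (x.transpose(1, 2) for x in (q, k, v))
out_sdpa = F.scaled_dot_product_attention(qs, ks, vs, scale=scale).transpose(1, 2)

torch.testing.assert_close(out_xf, out_sdpa, rtol=2e-3, atol=2e-3)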