[VLM][Model] TP support for ViTs (#7186)

Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
Co-authored-by: Roger Wang <ywang@roblox.com>
This commit is contained in:
Jungho Christopher Cho
2024-08-31 00:19:27 +09:00
committed by GitHub
parent afd39a4511
commit f97be32d1d
9 changed files with 336 additions and 285 deletions

View File

@ -6,8 +6,6 @@ import torch.nn as nn
from huggingface_hub import snapshot_download
from transformers import AutoConfig, AutoModel, CLIPImageProcessor
from vllm.model_executor.models.intern_vit import InternVisionModel
from ..conftest import _ImageAssets, cleanup
pytestmark = pytest.mark.vlm
@ -49,6 +47,7 @@ def run_intern_vit_test(
for pixel_value in pixel_values
]
from vllm.model_executor.models.intern_vit import InternVisionModel
vllm_model = InternVisionModel(config)
vllm_model.load_weights(hf_model.state_dict().items())

View File

@ -6,9 +6,6 @@ import torch
from PIL.Image import Image
from transformers import AutoConfig
from vllm.model_executor.models.internvl import (IMG_CONTEXT, IMG_END,
IMG_START,
image_to_pixel_values)
from vllm.multimodal.utils import rescale_image_size
from vllm.utils import is_cpu
@ -33,35 +30,6 @@ models = [
]
class InternVLProcessor:
"""A simple processor for InternVL2 HF model which misses a processor."""
def __init__(self, hf_runner: HfRunner):
self.num_image_token = hf_runner.model.num_image_token
self.tokenizer = hf_runner.tokenizer
self.dtype = hf_runner.model.dtype
self.config = AutoConfig.from_pretrained(hf_runner.model_name)
self.vision_config = self.config.vision_config
self.use_thumbnail = self.config.use_thumbnail
self.min_num = self.config.min_dynamic_patch
self.max_num = self.config.max_dynamic_patch
self.image_size = self.vision_config.image_size
def __call__(self, text: str, images: Image, **kwargs):
pixel_values = image_to_pixel_values(images, self.image_size,
self.min_num, self.max_num,
self.use_thumbnail).to(self.dtype)
num_patches_list = [pixel_values.shape[0]]
for num_patches in num_patches_list:
context_tokens = IMG_CONTEXT * self.num_image_token * num_patches
image_tokens = IMG_START + context_tokens + IMG_END
text = text.replace('<image>', image_tokens, 1)
prompt = self.tokenizer(text, return_tensors="pt")
prompt.update({"pixel_values": pixel_values})
return prompt
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B/blob/main/modeling_internvl_chat.py
def generate(
self,
@ -127,6 +95,37 @@ def run_test(
# if we run HF first, the cuda initialization will be done and it
# will hurt multiprocessing backend with fork method (the default method).
class InternVLProcessor:
"""A simple processor for InternVL2 which misses a processor."""
def __init__(self, hf_runner: HfRunner):
self.num_image_token = hf_runner.model.num_image_token
self.tokenizer = hf_runner.tokenizer
self.dtype = hf_runner.model.dtype
self.config = AutoConfig.from_pretrained(hf_runner.model_name)
self.vision_config = self.config.vision_config
self.use_thumbnail = self.config.use_thumbnail
self.min_num = self.config.min_dynamic_patch
self.max_num = self.config.max_dynamic_patch
self.image_size = self.vision_config.image_size
def __call__(self, text: str, images: Image, **kwargs):
from vllm.model_executor.models.internvl import (
IMG_CONTEXT, IMG_END, IMG_START, image_to_pixel_values)
pixel_values = image_to_pixel_values(
images, self.image_size, self.min_num, self.max_num,
self.use_thumbnail).to(self.dtype)
num_patches_list = [pixel_values.shape[0]]
for num_patches in num_patches_list:
context_tokens = IMG_CONTEXT * self.num_image_token \
* num_patches
image_tokens = IMG_START + context_tokens + IMG_END
text = text.replace('<image>', image_tokens, 1)
prompt = self.tokenizer(text, return_tensors="pt")
prompt.update({"pixel_values": pixel_values})
return prompt
# max_model_len should be greater than image_feature_size
with vllm_runner(model,
max_model_len=4096,

View File

@ -7,12 +7,14 @@ import torch
import torch.nn as nn
from PIL import Image
from transformers import Blip2VisionConfig, BlipVisionConfig
from transformers.models.blip.modeling_blip import BlipAttention
from xformers import ops as xops
from vllm.config import ModelConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.inputs import LLMInputs
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal.utils import (cached_get_tokenizer,
@ -154,6 +156,77 @@ class BlipVisionEmbeddings(nn.Module):
return embeddings
class BlipAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(
self,
config: BlipVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
"embed_dim must be divisible by num_heads "
f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads}).")
self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout
self.qkv = QKVParallelLinear(
self.embed_dim,
self.head_dim,
self.num_heads,
bias=config.qkv_bias,
quant_config=quant_config,
)
self.projection = RowParallelLinear(
self.embed_dim,
self.embed_dim,
quant_config=quant_config,
)
self.tp_size = get_tensor_model_parallel_world_size()
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads,
self.head_dim).transpose(1, 2).contiguous()
def forward(
self,
hidden_states: torch.Tensor,
):
"""Input shape: Batch x Time x Channel"""
bsz, tgt_len, _ = hidden_states.size()
qkv_states, _ = self.qkv(hidden_states)
query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)
query_states = query_states.view(bsz, tgt_len,
self.num_heads_per_partition,
self.head_dim)
key_states = key_states.view(bsz, tgt_len,
self.num_heads_per_partition,
self.head_dim)
value_states = value_states.view(bsz, tgt_len,
self.num_heads_per_partition,
self.head_dim)
out = xops.memory_efficient_attention_forward(query_states,
key_states,
value_states,
p=self.dropout,
scale=self.scale)
out = out.view(bsz, tgt_len, -1)
attn_output, _ = self.projection(out)
return attn_output
class BlipMLP(nn.Module):
def __init__(self,
@ -188,7 +261,7 @@ class BlipEncoderLayer(nn.Module):
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.self_attn = BlipAttention(config)
self.self_attn = BlipAttention(config, quant_config=quant_config)
self.layer_norm1 = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.mlp = BlipMLP(config, quant_config=quant_config)
@ -199,7 +272,7 @@ class BlipEncoderLayer(nn.Module):
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
hidden_states, _ = self.self_attn(hidden_states=hidden_states)
hidden_states = self.self_attn(hidden_states=hidden_states)
hidden_states = residual + hidden_states
residual = hidden_states

View File

@ -714,8 +714,7 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
use_default_weight_loading = False
if "vision" in name:
if self.vision_model is not None:
# We only do sharding for language model and
# not vision model for now.
# BlipVisionModel does not need sharding
use_default_weight_loading = True
else:
for (param_name, weight_name,

View File

@ -7,12 +7,14 @@ import torch
import torch.nn as nn
from PIL import Image
from transformers import CLIPVisionConfig
from transformers.models.clip.modeling_clip import CLIPAttention
from xformers import ops as xops
from vllm.config import ModelConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.inputs import LLMInputs
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@ -160,6 +162,78 @@ class CLIPVisionEmbeddings(nn.Module):
return embeddings
class CLIPAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(
self,
config: CLIPVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
"embed_dim must be divisible by num_heads "
f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads}).")
self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout
self.qkv_proj = QKVParallelLinear(
hidden_size=self.embed_dim,
head_size=self.head_dim,
total_num_heads=self.num_heads,
quant_config=quant_config,
)
self.out_proj = RowParallelLinear(
input_size=self.embed_dim,
output_size=self.embed_dim,
quant_config=quant_config,
)
self.tp_size = get_tensor_model_parallel_world_size()
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads,
self.head_dim).transpose(1, 2).contiguous()
def forward(
self,
hidden_states: torch.Tensor,
):
"""Input shape: Batch x Time x Channel"""
bsz, tgt_len, _ = hidden_states.size()
qkv_states, _ = self.qkv_proj(hidden_states)
query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)
query_states = query_states.view(bsz, tgt_len,
self.num_heads_per_partition,
self.head_dim)
key_states = key_states.view(bsz, tgt_len,
self.num_heads_per_partition,
self.head_dim)
value_states = value_states.view(bsz, tgt_len,
self.num_heads_per_partition,
self.head_dim)
out = xops.memory_efficient_attention_forward(query_states,
key_states,
value_states,
p=self.dropout,
scale=self.scale)
out = out.view(bsz, tgt_len, -1)
attn_output, _ = self.out_proj(out)
return attn_output
class CLIPMLP(nn.Module):
def __init__(self,
@ -192,7 +266,7 @@ class CLIPEncoderLayer(nn.Module):
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.self_attn = CLIPAttention(config)
self.self_attn = CLIPAttention(config, quant_config=quant_config)
self.layer_norm1 = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.mlp = CLIPMLP(config, quant_config=quant_config)
@ -204,7 +278,7 @@ class CLIPEncoderLayer(nn.Module):
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
hidden_states, _ = self.self_attn(hidden_states=hidden_states)
hidden_states = self.self_attn(hidden_states=hidden_states)
hidden_states = residual + hidden_states
residual = hidden_states
@ -304,7 +378,15 @@ class CLIPVisionModel(nn.Module):
def device(self):
return next(self.parameters()).device
# (TODO) Add prefix argument for filtering out weights to be loaded
# ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
params_dict = dict(self.named_parameters())
layer_count = len(self.vision_model.encoder.layers)
@ -318,7 +400,16 @@ class CLIPVisionModel(nn.Module):
if layer_idx >= layer_count:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
param = params_dict[name.replace(weight_name, param_name)]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

View File

@ -10,10 +10,13 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig
from xformers import ops as xops
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@ -81,7 +84,11 @@ class InternVisionEmbeddings(nn.Module):
class InternAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: PretrainedConfig):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
@ -94,9 +101,13 @@ class InternAttention(nn.Module):
f' {self.num_heads}).')
self.scale = self.head_dim**-0.5
self.qkv = nn.Linear(self.embed_dim,
3 * self.embed_dim,
bias=config.qkv_bias)
self.qkv = QKVParallelLinear(
self.embed_dim,
self.head_dim,
self.num_heads,
bias=config.qkv_bias,
quant_config=quant_config,
)
self.qk_normalization = config.qk_normalization
@ -104,25 +115,40 @@ class InternAttention(nn.Module):
self.q_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
self.k_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
self.proj = nn.Linear(self.embed_dim, self.embed_dim)
self.proj = RowParallelLinear(
self.embed_dim,
self.embed_dim,
quant_config=quant_config,
)
self.tp_size = get_tensor_model_parallel_world_size()
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
def forward(self, x):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
qkv, _ = self.qkv(x)
q, k, v = qkv.chunk(3, dim=-1)
q = q.view(B, N, self.num_heads_per_partition, self.head_dim)
k = k.view(B, N, self.num_heads_per_partition, self.head_dim)
v = v.view(B, N, self.num_heads_per_partition, self.head_dim)
if self.qk_normalization:
B_, H_, N_, D_ = q.shape
q = self.q_norm.forward_native(q.transpose(1, 2).flatten(
-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
k = self.k_norm.forward_native(k.transpose(1, 2).flatten(
-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
B_, N_, H_, D_ = q.shape
q = self.q_norm.forward_native(q.flatten(-2,
-1)).view(B_, N_, H_, D_)
k = self.k_norm.forward_native(k.flatten(-2,
-1)).view(B_, N_, H_, D_)
x = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
x = x.transpose(1, 2).reshape(B, N, C)
x = xops.memory_efficient_attention_forward(
q,
k,
v,
scale=self.scale,
)
x = x.view(B, N, -1)
x = self.proj(x)
x, _ = self.proj(x)
return x
@ -161,7 +187,7 @@ class InternVisionEncoderLayer(nn.Module):
self.intermediate_size = config.intermediate_size
self.norm_type = config.norm_type
self.attn = InternAttention(config)
self.attn = InternAttention(config, quant_config=quant_config)
self.mlp = InternMLP(config, quant_config=quant_config)
self.norm1 = NORM2FN[self.norm_type](self.embed_dim,
eps=config.layer_norm_eps)

View File

@ -145,7 +145,6 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):
self.config = config
self.multimodal_config = multimodal_config
# TODO(ywang96): Port over SiglipVisionModel & TP
self.vision_tower = SiglipVisionModel(config.vision_config)
self.multi_modal_projector = PaliGemmaMultiModalProjector(
vision_hidden_size=config.vision_config.hidden_size,
@ -308,34 +307,27 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):
if key_to_modify in name:
name = name.replace(key_to_modify, new_key)
use_default_weight_loading = False
if "vision" in name:
if self.vision_tower is not None:
# We only do sharding for language model and
# not vision model for now.
use_default_weight_loading = True
for (param_name, shard_name, shard_id) in stacked_params_mapping:
if shard_name not in name:
continue
name = name.replace(shard_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
for (param_name, shard_name,
shard_id) in stacked_params_mapping:
if shard_name not in name:
continue
name = name.replace(shard_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# lm_head is not used in vllm as it is tied with
# embed_token. To prevent errors, skip loading
# lm_head.weight.
if "lm_head.weight" in name:
continue
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
use_default_weight_loading = True
# lm_head is not used in vllm as it is tied with
# embed_token. To prevent errors, skip loading
# lm_head.weight.
if "lm_head.weight" in name:
continue
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
use_default_weight_loading = True
if use_default_weight_loading:
param = params_dict[name]

View File

@ -71,6 +71,23 @@ CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(dropout=0.0,
projection_dim=768)
def _init_img_processor(hf_config: PretrainedConfig):
clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG
layer_idx = hf_config.img_processor.get('layer_idx', -2)
# Initialize the CLIP only up to the required feature layer
if layer_idx < 0:
num_hidden_layers = clip_config.num_hidden_layers + \
layer_idx + 1
else:
num_hidden_layers = layer_idx + 1
img_processor = CLIPVisionModel(
clip_config, num_hidden_layers_override=num_hidden_layers)
return img_processor
class Phi3VImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: Union[torch.Tensor, List[torch.Tensor]]
@ -139,18 +156,8 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
hidden_size = config.n_embd if hasattr(
config, 'n_embd') else config.hidden_size
clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG
self.layer_idx = config.img_processor.get('layer_idx', -2)
self.img_processor = _init_img_processor(config)
# Initialize the CLIP only up to the required feature layer
if self.layer_idx < 0:
num_hidden_layers = clip_config.num_hidden_layers + \
self.layer_idx + 1
else:
num_hidden_layers = self.layer_idx + 1
self.img_processor = CLIPVisionModel(
clip_config, num_hidden_layers_override=num_hidden_layers)
image_dim_out = config.img_processor['image_dim_out']
self.num_img_tokens = config.img_processor['num_img_tokens']
@ -656,23 +663,27 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal):
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
# TODO(ChristopherCho): This is a temporary fix to load
# the vision weights with CLIPVisionModel.load_weights()
vision_weights = []
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
# post_layernorm is not needed in CLIPVisionModel
if "vision_model.post_layernorm" in name:
# Skip loading the img_processor weights since they are
# loaded separately.
if "vision_embed_tokens.img_processor" in name:
vision_weights.append((name, loaded_weight))
continue
for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
if key_to_modify in name:
name = name.replace(key_to_modify, new_key)
for (param_name, weight_name, shard_id) in stacked_params_mapping:
# We only do sharding for language model
# and not vision model for now.
if "vision_embed_tokens" in name and self.vision_embed_tokens:
continue
if weight_name not in name:
continue
param = params_dict[name.replace(weight_name, param_name)]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
@ -686,3 +697,11 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
# We use regex to extract the sub-module name
# from "model.vision_embed_tokens.img_processor.*"
vision_weights = [
(re.search(r"vision_embed_tokens\.img_processor\.(.*)",
n).group(1), w) for n, w in vision_weights
]
self.vision_embed_tokens.img_processor.load_weights(vision_weights)

View File

@ -9,12 +9,10 @@ import torch
from PIL import Image
from torch import nn
from transformers import SiglipVisionConfig
from transformers.models.siglip.modeling_siglip import SiglipAttention
from vllm_flash_attn import flash_attn_func
from xformers.ops import memory_efficient_attention
from xformers import ops as xops
from vllm.config import ModelConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.inputs import LLMInputs
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@ -221,9 +219,7 @@ class SiglipVisionEmbeddings(nn.Module):
return embeddings
# NOTE: Not used - kept for later when we TP the ViT
# TODO(ChristopherCho): Implement TP version of Attention
class SiglipTPAttention(nn.Module):
class SiglipAttention(nn.Module):
def __init__(
self,
@ -233,38 +229,30 @@ class SiglipTPAttention(nn.Module):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = config.num_attention_heads
if self.total_num_heads % tp_size != 0:
raise ValueError(
f"Number of attention heads ({self.total_num_heads}) "
"must be divisible by the tensor model parallel size"
f" ({tp_size}).")
self.num_heads = self.total_num_heads // tp_size
self.head_dim = self.embed_dim // self.total_num_heads
if self.head_dim * self.total_num_heads != self.embed_dim:
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(f"embed_dim must be divisible by num_heads (got "
"`embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads}).")
self.qkv_size = self.num_heads * self.head_dim
self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout
self.qkv_proj = QKVParallelLinear(
hidden_size=self.embed_dim,
head_size=self.head_dim,
total_num_heads=self.total_num_heads,
total_num_heads=self.num_heads,
quant_config=quant_config,
)
self.out_proj = RowParallelLinear(
input_size=self.embed_dim,
output_size=self.embed_dim,
quant_config=quant_config,
)
self.attn_fn = self._basic_attention_forward
self.tp_size = get_tensor_model_parallel_world_size()
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
def forward(
self,
@ -274,163 +262,29 @@ class SiglipTPAttention(nn.Module):
batch_size, q_len, _ = hidden_states.size()
qkv_states, _ = self.qkv_proj(hidden_states)
query_states, key_states, value_states = qkv_states.split(
[self.qkv_size] * 3, dim=-1)
query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)
attn_output = self.attn_fn(
q=query_states,
k=key_states,
v=value_states,
batch_size=batch_size,
q_len=q_len,
)
query_states = query_states.view(batch_size, q_len,
self.num_heads_per_partition,
self.head_dim)
key_states = key_states.view(batch_size, q_len,
self.num_heads_per_partition,
self.head_dim)
value_states = value_states.view(batch_size, q_len,
self.num_heads_per_partition,
self.head_dim)
attn_output, _ = self.out_proj(attn_output)
return attn_output
def _basic_attention_forward(self, q, k, v, batch_size, q_len):
q = q.view(batch_size, q_len, self.num_heads,
self.head_dim).transpose(1, 2)
k = k.view(batch_size, q_len, self.num_heads,
self.head_dim).transpose(1, 2)
v = v.view(batch_size, q_len, self.num_heads,
self.head_dim).transpose(1, 2)
k_v_seq_len = k.shape[-2]
attn_weights = torch.matmul(q, k.transpose(2, 3)) * self.scale
if attn_weights.size() != (
batch_size,
self.num_heads,
q_len,
k_v_seq_len,
):
raise ValueError(
"Attention weights should be of size "
f"{(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
f" {attn_weights.size()}")
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights,
dim=-1,
dtype=torch.float32).to(q.dtype)
attn_weights = nn.functional.dropout(attn_weights,
p=self.dropout,
training=self.training)
attn_output = torch.matmul(attn_weights, v)
if attn_output.size() != (
batch_size,
self.num_heads,
q_len,
self.head_dim,
):
raise ValueError(
"`attn_output` should be of size "
f"{(batch_size, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}")
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
out = xops.memory_efficient_attention_forward(query_states,
key_states,
value_states,
p=self.dropout,
scale=self.scale)
out = out.view(batch_size, q_len, -1)
attn_output, _ = self.out_proj(out)
return attn_output
# NOTE: Not used - kept for later when we TP the ViT
# TODO(ChristopherCho): flash_attn_func is not working properly.
# It constantly throws a CUDA error.
class SiglipFlashAttention2(SiglipTPAttention):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.attn_fn = self._flash_attention_forward
# Ported from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L449
# and https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/modeling_flash_attention_utils.py#L133
def _flash_attention_forward(self, q, k, v, batch_size, q_len, *args,
**kwargs):
"""Implements the multihead softmax attention.
Arguments
---------
q, k, v: The tensor containing the
query, key, and value. (B, S, H, D)
"""
q = q.view(batch_size, q_len, self.num_heads, self.head_dim)
k = k.view(batch_size, q_len, self.num_heads, self.head_dim)
v = v.view(batch_size, q_len, self.num_heads, self.head_dim)
attn_output = flash_attn_func(
q,
k,
v,
dropout_p=self.dropout,
causal=False,
)
attn_output = attn_output.reshape(batch_size, q_len,
self.embed_dim).contiguous()
return attn_output
# NOTE: Not used - kept for later when we TP the ViT
class SiglipSdpaAttention(SiglipTPAttention):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.is_causal = False
self.attn_fn = self._sdpa_attention_forward
def _sdpa_attention_forward(self, q, k, v, batch_size, q_len):
q = q.view(batch_size, q_len, self.num_heads,
self.head_dim).transpose(1, 2)
k = k.view(batch_size, q_len, self.num_heads,
self.head_dim).transpose(1, 2)
v = v.view(batch_size, q_len, self.num_heads,
self.head_dim).transpose(1, 2)
attn_output = torch.nn.functional.scaled_dot_product_attention(
q, k, v, dropout_p=self.dropout, is_causal=False, scale=self.scale)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(batch_size, q_len, self.embed_dim)
return attn_output
# NOTE: Not used - kept for later when we TP the ViT
class SiglipxFormersAttention(SiglipTPAttention):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.attn_fn = self._xformers_attention_forward
def _xformers_attention_forward(self, q, k, v, batch_size, q_len):
q = q.view(batch_size, q_len, self.num_heads, self.head_dim)
k = k.view(batch_size, q_len, self.num_heads, self.head_dim)
v = v.view(batch_size, q_len, self.num_heads, self.head_dim)
attn_output = memory_efficient_attention(q,
k,
v,
p=0.0,
scale=self.scale)
attn_output = attn_output.reshape(batch_size, q_len,
self.embed_dim).contiguous()
return attn_output
# NOTE: Not used - kept for later when we TP the ViT
SIGLIP_ATTENTION_CLASSES = {
"eager": SiglipTPAttention,
"flash_attention_2": SiglipFlashAttention2,
"sdpa": SiglipSdpaAttention,
"xformers": SiglipxFormersAttention,
}
class SiglipMLP(nn.Module):
def __init__(
@ -473,8 +327,7 @@ class SiglipEncoderLayer(nn.Module):
super().__init__()
self.embed_dim = config.hidden_size
# TODO(ChristopherCho): use TP'ed Attention block
self.self_attn = SiglipAttention(config)
self.self_attn = SiglipAttention(config, quant_config=quant_config)
self.layer_norm1 = nn.LayerNorm(self.embed_dim,
eps=config.layer_norm_eps)
self.mlp = SiglipMLP(
@ -491,7 +344,7 @@ class SiglipEncoderLayer(nn.Module):
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
hidden_states, _ = self.self_attn(hidden_states=hidden_states)
hidden_states = self.self_attn(hidden_states=hidden_states)
hidden_states = residual + hidden_states
residual = hidden_states