mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[VLM][Model] TP support for ViTs (#7186)
Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com> Co-authored-by: Roger Wang <ywang@roblox.com>
This commit is contained in:
committed by
GitHub
parent
afd39a4511
commit
f97be32d1d
@ -6,8 +6,6 @@ import torch.nn as nn
|
||||
from huggingface_hub import snapshot_download
|
||||
from transformers import AutoConfig, AutoModel, CLIPImageProcessor
|
||||
|
||||
from vllm.model_executor.models.intern_vit import InternVisionModel
|
||||
|
||||
from ..conftest import _ImageAssets, cleanup
|
||||
|
||||
pytestmark = pytest.mark.vlm
|
||||
@ -49,6 +47,7 @@ def run_intern_vit_test(
|
||||
for pixel_value in pixel_values
|
||||
]
|
||||
|
||||
from vllm.model_executor.models.intern_vit import InternVisionModel
|
||||
vllm_model = InternVisionModel(config)
|
||||
vllm_model.load_weights(hf_model.state_dict().items())
|
||||
|
||||
|
@ -6,9 +6,6 @@ import torch
|
||||
from PIL.Image import Image
|
||||
from transformers import AutoConfig
|
||||
|
||||
from vllm.model_executor.models.internvl import (IMG_CONTEXT, IMG_END,
|
||||
IMG_START,
|
||||
image_to_pixel_values)
|
||||
from vllm.multimodal.utils import rescale_image_size
|
||||
from vllm.utils import is_cpu
|
||||
|
||||
@ -33,35 +30,6 @@ models = [
|
||||
]
|
||||
|
||||
|
||||
class InternVLProcessor:
|
||||
"""A simple processor for InternVL2 HF model which misses a processor."""
|
||||
|
||||
def __init__(self, hf_runner: HfRunner):
|
||||
self.num_image_token = hf_runner.model.num_image_token
|
||||
self.tokenizer = hf_runner.tokenizer
|
||||
self.dtype = hf_runner.model.dtype
|
||||
|
||||
self.config = AutoConfig.from_pretrained(hf_runner.model_name)
|
||||
self.vision_config = self.config.vision_config
|
||||
self.use_thumbnail = self.config.use_thumbnail
|
||||
self.min_num = self.config.min_dynamic_patch
|
||||
self.max_num = self.config.max_dynamic_patch
|
||||
self.image_size = self.vision_config.image_size
|
||||
|
||||
def __call__(self, text: str, images: Image, **kwargs):
|
||||
pixel_values = image_to_pixel_values(images, self.image_size,
|
||||
self.min_num, self.max_num,
|
||||
self.use_thumbnail).to(self.dtype)
|
||||
num_patches_list = [pixel_values.shape[0]]
|
||||
for num_patches in num_patches_list:
|
||||
context_tokens = IMG_CONTEXT * self.num_image_token * num_patches
|
||||
image_tokens = IMG_START + context_tokens + IMG_END
|
||||
text = text.replace('<image>', image_tokens, 1)
|
||||
prompt = self.tokenizer(text, return_tensors="pt")
|
||||
prompt.update({"pixel_values": pixel_values})
|
||||
return prompt
|
||||
|
||||
|
||||
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B/blob/main/modeling_internvl_chat.py
|
||||
def generate(
|
||||
self,
|
||||
@ -127,6 +95,37 @@ def run_test(
|
||||
# if we run HF first, the cuda initialization will be done and it
|
||||
# will hurt multiprocessing backend with fork method (the default method).
|
||||
|
||||
class InternVLProcessor:
|
||||
"""A simple processor for InternVL2 which misses a processor."""
|
||||
|
||||
def __init__(self, hf_runner: HfRunner):
|
||||
self.num_image_token = hf_runner.model.num_image_token
|
||||
self.tokenizer = hf_runner.tokenizer
|
||||
self.dtype = hf_runner.model.dtype
|
||||
|
||||
self.config = AutoConfig.from_pretrained(hf_runner.model_name)
|
||||
self.vision_config = self.config.vision_config
|
||||
self.use_thumbnail = self.config.use_thumbnail
|
||||
self.min_num = self.config.min_dynamic_patch
|
||||
self.max_num = self.config.max_dynamic_patch
|
||||
self.image_size = self.vision_config.image_size
|
||||
|
||||
def __call__(self, text: str, images: Image, **kwargs):
|
||||
from vllm.model_executor.models.internvl import (
|
||||
IMG_CONTEXT, IMG_END, IMG_START, image_to_pixel_values)
|
||||
pixel_values = image_to_pixel_values(
|
||||
images, self.image_size, self.min_num, self.max_num,
|
||||
self.use_thumbnail).to(self.dtype)
|
||||
num_patches_list = [pixel_values.shape[0]]
|
||||
for num_patches in num_patches_list:
|
||||
context_tokens = IMG_CONTEXT * self.num_image_token \
|
||||
* num_patches
|
||||
image_tokens = IMG_START + context_tokens + IMG_END
|
||||
text = text.replace('<image>', image_tokens, 1)
|
||||
prompt = self.tokenizer(text, return_tensors="pt")
|
||||
prompt.update({"pixel_values": pixel_values})
|
||||
return prompt
|
||||
|
||||
# max_model_len should be greater than image_feature_size
|
||||
with vllm_runner(model,
|
||||
max_model_len=4096,
|
||||
|
@ -7,12 +7,14 @@ import torch
|
||||
import torch.nn as nn
|
||||
from PIL import Image
|
||||
from transformers import Blip2VisionConfig, BlipVisionConfig
|
||||
from transformers.models.blip.modeling_blip import BlipAttention
|
||||
from xformers import ops as xops
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
||||
from vllm.inputs import LLMInputs
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.multimodal.utils import (cached_get_tokenizer,
|
||||
@ -154,6 +156,77 @@ class BlipVisionEmbeddings(nn.Module):
|
||||
return embeddings
|
||||
|
||||
|
||||
class BlipAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: BlipVisionConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = self.embed_dim // self.num_heads
|
||||
if self.head_dim * self.num_heads != self.embed_dim:
|
||||
raise ValueError(
|
||||
"embed_dim must be divisible by num_heads "
|
||||
f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
|
||||
f" {self.num_heads}).")
|
||||
self.scale = self.head_dim**-0.5
|
||||
self.dropout = config.attention_dropout
|
||||
|
||||
self.qkv = QKVParallelLinear(
|
||||
self.embed_dim,
|
||||
self.head_dim,
|
||||
self.num_heads,
|
||||
bias=config.qkv_bias,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.projection = RowParallelLinear(
|
||||
self.embed_dim,
|
||||
self.embed_dim,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
|
||||
|
||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||
return tensor.view(bsz, seq_len, self.num_heads,
|
||||
self.head_dim).transpose(1, 2).contiguous()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
):
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
bsz, tgt_len, _ = hidden_states.size()
|
||||
|
||||
qkv_states, _ = self.qkv(hidden_states)
|
||||
query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)
|
||||
query_states = query_states.view(bsz, tgt_len,
|
||||
self.num_heads_per_partition,
|
||||
self.head_dim)
|
||||
key_states = key_states.view(bsz, tgt_len,
|
||||
self.num_heads_per_partition,
|
||||
self.head_dim)
|
||||
value_states = value_states.view(bsz, tgt_len,
|
||||
self.num_heads_per_partition,
|
||||
self.head_dim)
|
||||
|
||||
out = xops.memory_efficient_attention_forward(query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
p=self.dropout,
|
||||
scale=self.scale)
|
||||
out = out.view(bsz, tgt_len, -1)
|
||||
attn_output, _ = self.projection(out)
|
||||
|
||||
return attn_output
|
||||
|
||||
|
||||
class BlipMLP(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
@ -188,7 +261,7 @@ class BlipEncoderLayer(nn.Module):
|
||||
quant_config: Optional[QuantizationConfig] = None):
|
||||
super().__init__()
|
||||
|
||||
self.self_attn = BlipAttention(config)
|
||||
self.self_attn = BlipAttention(config, quant_config=quant_config)
|
||||
self.layer_norm1 = nn.LayerNorm(config.hidden_size,
|
||||
eps=config.layer_norm_eps)
|
||||
self.mlp = BlipMLP(config, quant_config=quant_config)
|
||||
@ -199,7 +272,7 @@ class BlipEncoderLayer(nn.Module):
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.layer_norm1(hidden_states)
|
||||
hidden_states, _ = self.self_attn(hidden_states=hidden_states)
|
||||
hidden_states = self.self_attn(hidden_states=hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
residual = hidden_states
|
||||
|
@ -714,8 +714,7 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
use_default_weight_loading = False
|
||||
if "vision" in name:
|
||||
if self.vision_model is not None:
|
||||
# We only do sharding for language model and
|
||||
# not vision model for now.
|
||||
# BlipVisionModel does not need sharding
|
||||
use_default_weight_loading = True
|
||||
else:
|
||||
for (param_name, weight_name,
|
||||
|
@ -7,12 +7,14 @@ import torch
|
||||
import torch.nn as nn
|
||||
from PIL import Image
|
||||
from transformers import CLIPVisionConfig
|
||||
from transformers.models.clip.modeling_clip import CLIPAttention
|
||||
from xformers import ops as xops
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
||||
from vllm.inputs import LLMInputs
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
@ -160,6 +162,78 @@ class CLIPVisionEmbeddings(nn.Module):
|
||||
return embeddings
|
||||
|
||||
|
||||
class CLIPAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: CLIPVisionConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = self.embed_dim // self.num_heads
|
||||
if self.head_dim * self.num_heads != self.embed_dim:
|
||||
raise ValueError(
|
||||
"embed_dim must be divisible by num_heads "
|
||||
f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
|
||||
f" {self.num_heads}).")
|
||||
self.scale = self.head_dim**-0.5
|
||||
self.dropout = config.attention_dropout
|
||||
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
hidden_size=self.embed_dim,
|
||||
head_size=self.head_dim,
|
||||
total_num_heads=self.num_heads,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
self.out_proj = RowParallelLinear(
|
||||
input_size=self.embed_dim,
|
||||
output_size=self.embed_dim,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
|
||||
|
||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||
return tensor.view(bsz, seq_len, self.num_heads,
|
||||
self.head_dim).transpose(1, 2).contiguous()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
):
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
bsz, tgt_len, _ = hidden_states.size()
|
||||
|
||||
qkv_states, _ = self.qkv_proj(hidden_states)
|
||||
query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)
|
||||
|
||||
query_states = query_states.view(bsz, tgt_len,
|
||||
self.num_heads_per_partition,
|
||||
self.head_dim)
|
||||
key_states = key_states.view(bsz, tgt_len,
|
||||
self.num_heads_per_partition,
|
||||
self.head_dim)
|
||||
value_states = value_states.view(bsz, tgt_len,
|
||||
self.num_heads_per_partition,
|
||||
self.head_dim)
|
||||
|
||||
out = xops.memory_efficient_attention_forward(query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
p=self.dropout,
|
||||
scale=self.scale)
|
||||
out = out.view(bsz, tgt_len, -1)
|
||||
attn_output, _ = self.out_proj(out)
|
||||
|
||||
return attn_output
|
||||
|
||||
|
||||
class CLIPMLP(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
@ -192,7 +266,7 @@ class CLIPEncoderLayer(nn.Module):
|
||||
quant_config: Optional[QuantizationConfig] = None):
|
||||
super().__init__()
|
||||
|
||||
self.self_attn = CLIPAttention(config)
|
||||
self.self_attn = CLIPAttention(config, quant_config=quant_config)
|
||||
self.layer_norm1 = nn.LayerNorm(config.hidden_size,
|
||||
eps=config.layer_norm_eps)
|
||||
self.mlp = CLIPMLP(config, quant_config=quant_config)
|
||||
@ -204,7 +278,7 @@ class CLIPEncoderLayer(nn.Module):
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.layer_norm1(hidden_states)
|
||||
hidden_states, _ = self.self_attn(hidden_states=hidden_states)
|
||||
hidden_states = self.self_attn(hidden_states=hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
residual = hidden_states
|
||||
@ -304,7 +378,15 @@ class CLIPVisionModel(nn.Module):
|
||||
def device(self):
|
||||
return next(self.parameters()).device
|
||||
|
||||
# (TODO) Add prefix argument for filtering out weights to be loaded
|
||||
# ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
layer_count = len(self.vision_model.encoder.layers)
|
||||
|
||||
@ -318,7 +400,16 @@ class CLIPVisionModel(nn.Module):
|
||||
if layer_idx >= layer_count:
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
|
||||
param = params_dict[name.replace(weight_name, param_name)]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
@ -10,10 +10,13 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers import PretrainedConfig
|
||||
from xformers import ops as xops
|
||||
|
||||
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
@ -81,7 +84,11 @@ class InternVisionEmbeddings(nn.Module):
|
||||
class InternAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(self, config: PretrainedConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
@ -94,9 +101,13 @@ class InternAttention(nn.Module):
|
||||
f' {self.num_heads}).')
|
||||
|
||||
self.scale = self.head_dim**-0.5
|
||||
self.qkv = nn.Linear(self.embed_dim,
|
||||
3 * self.embed_dim,
|
||||
bias=config.qkv_bias)
|
||||
self.qkv = QKVParallelLinear(
|
||||
self.embed_dim,
|
||||
self.head_dim,
|
||||
self.num_heads,
|
||||
bias=config.qkv_bias,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
self.qk_normalization = config.qk_normalization
|
||||
|
||||
@ -104,25 +115,40 @@ class InternAttention(nn.Module):
|
||||
self.q_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
|
||||
self.k_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
|
||||
|
||||
self.proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
self.proj = RowParallelLinear(
|
||||
self.embed_dim,
|
||||
self.embed_dim,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
|
||||
|
||||
def forward(self, x):
|
||||
B, N, C = x.shape
|
||||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
|
||||
C // self.num_heads).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv.unbind(0)
|
||||
qkv, _ = self.qkv(x)
|
||||
q, k, v = qkv.chunk(3, dim=-1)
|
||||
|
||||
q = q.view(B, N, self.num_heads_per_partition, self.head_dim)
|
||||
k = k.view(B, N, self.num_heads_per_partition, self.head_dim)
|
||||
v = v.view(B, N, self.num_heads_per_partition, self.head_dim)
|
||||
|
||||
if self.qk_normalization:
|
||||
B_, H_, N_, D_ = q.shape
|
||||
q = self.q_norm.forward_native(q.transpose(1, 2).flatten(
|
||||
-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
|
||||
k = self.k_norm.forward_native(k.transpose(1, 2).flatten(
|
||||
-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
|
||||
B_, N_, H_, D_ = q.shape
|
||||
q = self.q_norm.forward_native(q.flatten(-2,
|
||||
-1)).view(B_, N_, H_, D_)
|
||||
k = self.k_norm.forward_native(k.flatten(-2,
|
||||
-1)).view(B_, N_, H_, D_)
|
||||
|
||||
x = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
|
||||
x = x.transpose(1, 2).reshape(B, N, C)
|
||||
x = xops.memory_efficient_attention_forward(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
scale=self.scale,
|
||||
)
|
||||
x = x.view(B, N, -1)
|
||||
|
||||
x = self.proj(x)
|
||||
x, _ = self.proj(x)
|
||||
return x
|
||||
|
||||
|
||||
@ -161,7 +187,7 @@ class InternVisionEncoderLayer(nn.Module):
|
||||
self.intermediate_size = config.intermediate_size
|
||||
self.norm_type = config.norm_type
|
||||
|
||||
self.attn = InternAttention(config)
|
||||
self.attn = InternAttention(config, quant_config=quant_config)
|
||||
self.mlp = InternMLP(config, quant_config=quant_config)
|
||||
self.norm1 = NORM2FN[self.norm_type](self.embed_dim,
|
||||
eps=config.layer_norm_eps)
|
||||
|
@ -145,7 +145,6 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
self.config = config
|
||||
self.multimodal_config = multimodal_config
|
||||
|
||||
# TODO(ywang96): Port over SiglipVisionModel & TP
|
||||
self.vision_tower = SiglipVisionModel(config.vision_config)
|
||||
self.multi_modal_projector = PaliGemmaMultiModalProjector(
|
||||
vision_hidden_size=config.vision_config.hidden_size,
|
||||
@ -308,34 +307,27 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
if key_to_modify in name:
|
||||
name = name.replace(key_to_modify, new_key)
|
||||
use_default_weight_loading = False
|
||||
if "vision" in name:
|
||||
if self.vision_tower is not None:
|
||||
# We only do sharding for language model and
|
||||
# not vision model for now.
|
||||
use_default_weight_loading = True
|
||||
for (param_name, shard_name, shard_id) in stacked_params_mapping:
|
||||
if shard_name not in name:
|
||||
continue
|
||||
name = name.replace(shard_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
for (param_name, shard_name,
|
||||
shard_id) in stacked_params_mapping:
|
||||
if shard_name not in name:
|
||||
continue
|
||||
name = name.replace(shard_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
# lm_head is not used in vllm as it is tied with
|
||||
# embed_token. To prevent errors, skip loading
|
||||
# lm_head.weight.
|
||||
if "lm_head.weight" in name:
|
||||
continue
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
use_default_weight_loading = True
|
||||
# lm_head is not used in vllm as it is tied with
|
||||
# embed_token. To prevent errors, skip loading
|
||||
# lm_head.weight.
|
||||
if "lm_head.weight" in name:
|
||||
continue
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
use_default_weight_loading = True
|
||||
|
||||
if use_default_weight_loading:
|
||||
param = params_dict[name]
|
||||
|
@ -71,6 +71,23 @@ CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(dropout=0.0,
|
||||
projection_dim=768)
|
||||
|
||||
|
||||
def _init_img_processor(hf_config: PretrainedConfig):
|
||||
clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG
|
||||
layer_idx = hf_config.img_processor.get('layer_idx', -2)
|
||||
|
||||
# Initialize the CLIP only up to the required feature layer
|
||||
if layer_idx < 0:
|
||||
num_hidden_layers = clip_config.num_hidden_layers + \
|
||||
layer_idx + 1
|
||||
else:
|
||||
num_hidden_layers = layer_idx + 1
|
||||
|
||||
img_processor = CLIPVisionModel(
|
||||
clip_config, num_hidden_layers_override=num_hidden_layers)
|
||||
|
||||
return img_processor
|
||||
|
||||
|
||||
class Phi3VImagePixelInputs(TypedDict):
|
||||
type: Literal["pixel_values"]
|
||||
data: Union[torch.Tensor, List[torch.Tensor]]
|
||||
@ -139,18 +156,8 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
|
||||
hidden_size = config.n_embd if hasattr(
|
||||
config, 'n_embd') else config.hidden_size
|
||||
|
||||
clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG
|
||||
self.layer_idx = config.img_processor.get('layer_idx', -2)
|
||||
self.img_processor = _init_img_processor(config)
|
||||
|
||||
# Initialize the CLIP only up to the required feature layer
|
||||
if self.layer_idx < 0:
|
||||
num_hidden_layers = clip_config.num_hidden_layers + \
|
||||
self.layer_idx + 1
|
||||
else:
|
||||
num_hidden_layers = self.layer_idx + 1
|
||||
|
||||
self.img_processor = CLIPVisionModel(
|
||||
clip_config, num_hidden_layers_override=num_hidden_layers)
|
||||
image_dim_out = config.img_processor['image_dim_out']
|
||||
self.num_img_tokens = config.img_processor['num_img_tokens']
|
||||
|
||||
@ -656,23 +663,27 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal):
|
||||
(".gate_up_proj", ".gate_proj", 0),
|
||||
(".gate_up_proj", ".up_proj", 1),
|
||||
]
|
||||
|
||||
# TODO(ChristopherCho): This is a temporary fix to load
|
||||
# the vision weights with CLIPVisionModel.load_weights()
|
||||
vision_weights = []
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
# post_layernorm is not needed in CLIPVisionModel
|
||||
if "vision_model.post_layernorm" in name:
|
||||
# Skip loading the img_processor weights since they are
|
||||
# loaded separately.
|
||||
if "vision_embed_tokens.img_processor" in name:
|
||||
vision_weights.append((name, loaded_weight))
|
||||
continue
|
||||
|
||||
for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
|
||||
if key_to_modify in name:
|
||||
name = name.replace(key_to_modify, new_key)
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
# We only do sharding for language model
|
||||
# and not vision model for now.
|
||||
if "vision_embed_tokens" in name and self.vision_embed_tokens:
|
||||
continue
|
||||
if weight_name not in name:
|
||||
continue
|
||||
|
||||
param = params_dict[name.replace(weight_name, param_name)]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
@ -686,3 +697,11 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal):
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
# We use regex to extract the sub-module name
|
||||
# from "model.vision_embed_tokens.img_processor.*"
|
||||
vision_weights = [
|
||||
(re.search(r"vision_embed_tokens\.img_processor\.(.*)",
|
||||
n).group(1), w) for n, w in vision_weights
|
||||
]
|
||||
self.vision_embed_tokens.img_processor.load_weights(vision_weights)
|
||||
|
@ -9,12 +9,10 @@ import torch
|
||||
from PIL import Image
|
||||
from torch import nn
|
||||
from transformers import SiglipVisionConfig
|
||||
from transformers.models.siglip.modeling_siglip import SiglipAttention
|
||||
from vllm_flash_attn import flash_attn_func
|
||||
from xformers.ops import memory_efficient_attention
|
||||
from xformers import ops as xops
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
||||
from vllm.inputs import LLMInputs
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
@ -221,9 +219,7 @@ class SiglipVisionEmbeddings(nn.Module):
|
||||
return embeddings
|
||||
|
||||
|
||||
# NOTE: Not used - kept for later when we TP the ViT
|
||||
# TODO(ChristopherCho): Implement TP version of Attention
|
||||
class SiglipTPAttention(nn.Module):
|
||||
class SiglipAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -233,38 +229,30 @@ class SiglipTPAttention(nn.Module):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
self.total_num_heads = config.num_attention_heads
|
||||
if self.total_num_heads % tp_size != 0:
|
||||
raise ValueError(
|
||||
f"Number of attention heads ({self.total_num_heads}) "
|
||||
"must be divisible by the tensor model parallel size"
|
||||
f" ({tp_size}).")
|
||||
|
||||
self.num_heads = self.total_num_heads // tp_size
|
||||
self.head_dim = self.embed_dim // self.total_num_heads
|
||||
if self.head_dim * self.total_num_heads != self.embed_dim:
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = self.embed_dim // self.num_heads
|
||||
if self.head_dim * self.num_heads != self.embed_dim:
|
||||
raise ValueError(f"embed_dim must be divisible by num_heads (got "
|
||||
"`embed_dim`: {self.embed_dim} and `num_heads`:"
|
||||
f" {self.num_heads}).")
|
||||
self.qkv_size = self.num_heads * self.head_dim
|
||||
|
||||
self.scale = self.head_dim**-0.5
|
||||
self.dropout = config.attention_dropout
|
||||
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
hidden_size=self.embed_dim,
|
||||
head_size=self.head_dim,
|
||||
total_num_heads=self.total_num_heads,
|
||||
total_num_heads=self.num_heads,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
self.out_proj = RowParallelLinear(
|
||||
input_size=self.embed_dim,
|
||||
output_size=self.embed_dim,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
self.attn_fn = self._basic_attention_forward
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -274,163 +262,29 @@ class SiglipTPAttention(nn.Module):
|
||||
batch_size, q_len, _ = hidden_states.size()
|
||||
|
||||
qkv_states, _ = self.qkv_proj(hidden_states)
|
||||
query_states, key_states, value_states = qkv_states.split(
|
||||
[self.qkv_size] * 3, dim=-1)
|
||||
query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)
|
||||
|
||||
attn_output = self.attn_fn(
|
||||
q=query_states,
|
||||
k=key_states,
|
||||
v=value_states,
|
||||
batch_size=batch_size,
|
||||
q_len=q_len,
|
||||
)
|
||||
query_states = query_states.view(batch_size, q_len,
|
||||
self.num_heads_per_partition,
|
||||
self.head_dim)
|
||||
key_states = key_states.view(batch_size, q_len,
|
||||
self.num_heads_per_partition,
|
||||
self.head_dim)
|
||||
value_states = value_states.view(batch_size, q_len,
|
||||
self.num_heads_per_partition,
|
||||
self.head_dim)
|
||||
|
||||
attn_output, _ = self.out_proj(attn_output)
|
||||
return attn_output
|
||||
|
||||
def _basic_attention_forward(self, q, k, v, batch_size, q_len):
|
||||
q = q.view(batch_size, q_len, self.num_heads,
|
||||
self.head_dim).transpose(1, 2)
|
||||
k = k.view(batch_size, q_len, self.num_heads,
|
||||
self.head_dim).transpose(1, 2)
|
||||
v = v.view(batch_size, q_len, self.num_heads,
|
||||
self.head_dim).transpose(1, 2)
|
||||
|
||||
k_v_seq_len = k.shape[-2]
|
||||
attn_weights = torch.matmul(q, k.transpose(2, 3)) * self.scale
|
||||
|
||||
if attn_weights.size() != (
|
||||
batch_size,
|
||||
self.num_heads,
|
||||
q_len,
|
||||
k_v_seq_len,
|
||||
):
|
||||
raise ValueError(
|
||||
"Attention weights should be of size "
|
||||
f"{(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
|
||||
f" {attn_weights.size()}")
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights,
|
||||
dim=-1,
|
||||
dtype=torch.float32).to(q.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights,
|
||||
p=self.dropout,
|
||||
training=self.training)
|
||||
attn_output = torch.matmul(attn_weights, v)
|
||||
|
||||
if attn_output.size() != (
|
||||
batch_size,
|
||||
self.num_heads,
|
||||
q_len,
|
||||
self.head_dim,
|
||||
):
|
||||
raise ValueError(
|
||||
"`attn_output` should be of size "
|
||||
f"{(batch_size, self.num_heads, q_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}")
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
|
||||
out = xops.memory_efficient_attention_forward(query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
p=self.dropout,
|
||||
scale=self.scale)
|
||||
out = out.view(batch_size, q_len, -1)
|
||||
attn_output, _ = self.out_proj(out)
|
||||
|
||||
return attn_output
|
||||
|
||||
|
||||
# NOTE: Not used - kept for later when we TP the ViT
|
||||
# TODO(ChristopherCho): flash_attn_func is not working properly.
|
||||
# It constantly throws a CUDA error.
|
||||
class SiglipFlashAttention2(SiglipTPAttention):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.attn_fn = self._flash_attention_forward
|
||||
|
||||
# Ported from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L449
|
||||
# and https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/modeling_flash_attention_utils.py#L133
|
||||
def _flash_attention_forward(self, q, k, v, batch_size, q_len, *args,
|
||||
**kwargs):
|
||||
"""Implements the multihead softmax attention.
|
||||
Arguments
|
||||
---------
|
||||
q, k, v: The tensor containing the
|
||||
query, key, and value. (B, S, H, D)
|
||||
"""
|
||||
|
||||
q = q.view(batch_size, q_len, self.num_heads, self.head_dim)
|
||||
k = k.view(batch_size, q_len, self.num_heads, self.head_dim)
|
||||
v = v.view(batch_size, q_len, self.num_heads, self.head_dim)
|
||||
|
||||
attn_output = flash_attn_func(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
dropout_p=self.dropout,
|
||||
causal=False,
|
||||
)
|
||||
|
||||
attn_output = attn_output.reshape(batch_size, q_len,
|
||||
self.embed_dim).contiguous()
|
||||
|
||||
return attn_output
|
||||
|
||||
|
||||
# NOTE: Not used - kept for later when we TP the ViT
|
||||
class SiglipSdpaAttention(SiglipTPAttention):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.is_causal = False
|
||||
self.attn_fn = self._sdpa_attention_forward
|
||||
|
||||
def _sdpa_attention_forward(self, q, k, v, batch_size, q_len):
|
||||
q = q.view(batch_size, q_len, self.num_heads,
|
||||
self.head_dim).transpose(1, 2)
|
||||
k = k.view(batch_size, q_len, self.num_heads,
|
||||
self.head_dim).transpose(1, 2)
|
||||
v = v.view(batch_size, q_len, self.num_heads,
|
||||
self.head_dim).transpose(1, 2)
|
||||
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
q, k, v, dropout_p=self.dropout, is_causal=False, scale=self.scale)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.view(batch_size, q_len, self.embed_dim)
|
||||
|
||||
return attn_output
|
||||
|
||||
|
||||
# NOTE: Not used - kept for later when we TP the ViT
|
||||
class SiglipxFormersAttention(SiglipTPAttention):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.attn_fn = self._xformers_attention_forward
|
||||
|
||||
def _xformers_attention_forward(self, q, k, v, batch_size, q_len):
|
||||
q = q.view(batch_size, q_len, self.num_heads, self.head_dim)
|
||||
k = k.view(batch_size, q_len, self.num_heads, self.head_dim)
|
||||
v = v.view(batch_size, q_len, self.num_heads, self.head_dim)
|
||||
|
||||
attn_output = memory_efficient_attention(q,
|
||||
k,
|
||||
v,
|
||||
p=0.0,
|
||||
scale=self.scale)
|
||||
attn_output = attn_output.reshape(batch_size, q_len,
|
||||
self.embed_dim).contiguous()
|
||||
|
||||
return attn_output
|
||||
|
||||
|
||||
# NOTE: Not used - kept for later when we TP the ViT
|
||||
SIGLIP_ATTENTION_CLASSES = {
|
||||
"eager": SiglipTPAttention,
|
||||
"flash_attention_2": SiglipFlashAttention2,
|
||||
"sdpa": SiglipSdpaAttention,
|
||||
"xformers": SiglipxFormersAttention,
|
||||
}
|
||||
|
||||
|
||||
class SiglipMLP(nn.Module):
|
||||
|
||||
def __init__(
|
||||
@ -473,8 +327,7 @@ class SiglipEncoderLayer(nn.Module):
|
||||
super().__init__()
|
||||
self.embed_dim = config.hidden_size
|
||||
|
||||
# TODO(ChristopherCho): use TP'ed Attention block
|
||||
self.self_attn = SiglipAttention(config)
|
||||
self.self_attn = SiglipAttention(config, quant_config=quant_config)
|
||||
self.layer_norm1 = nn.LayerNorm(self.embed_dim,
|
||||
eps=config.layer_norm_eps)
|
||||
self.mlp = SiglipMLP(
|
||||
@ -491,7 +344,7 @@ class SiglipEncoderLayer(nn.Module):
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.layer_norm1(hidden_states)
|
||||
hidden_states, _ = self.self_attn(hidden_states=hidden_states)
|
||||
hidden_states = self.self_attn(hidden_states=hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
residual = hidden_states
|
||||
|
Reference in New Issue
Block a user