Files
vllm/vllm/model_executor/models/transformers.py
Raushan Turganbay ab5e7d93f4 [Bugfix] Fix mrope in Transformers Backend (#26087)
Signed-off-by: raushan <raushan@huggingface.co>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
2025-10-06 11:40:50 +00:00

971 lines
37 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2024 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wrapper around `transformers` models"""
from collections.abc import Iterable, Mapping
from contextlib import contextmanager
from pathlib import Path
from typing import Literal, Optional, Union
import regex as re
import torch
import transformers
from packaging.version import Version
from torch import nn
from transformers import AutoModel, BatchFeature, PretrainedConfig, PreTrainedModel
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from vllm.attention import Attention, AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (
CacheConfig,
DeviceConfig,
ModelConfig,
ParallelConfig,
VllmConfig,
)
from vllm.config.multimodal import BaseDummyOptions
from vllm.config.utils import getattr_iter
from vllm.distributed import get_pp_group, get_tp_group
from vllm.distributed.utils import get_pp_indices
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems
from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalFieldConfig,
MultiModalInputs,
MultiModalUUIDDict,
PlaceholderRange,
)
from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataItems
from vllm.multimodal.processing import BaseMultiModalProcessor, BaseProcessingInfo
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from .interfaces import (
MultiModalEmbeddings,
SupportsLoRA,
SupportsMultiModal,
SupportsPP,
SupportsQuant,
)
from .utils import (
AutoWeightsLoader,
PPMissingLayer,
WeightsMapper,
make_empty_intermediate_tensors_factory,
maybe_prefix,
)
# Module-level logger, created through vLLM's `init_logger` with this module's name.
logger = init_logger(__name__)
def get_feature_request_tip(
    model: str,
    trust_remote_code: bool,
) -> str:
    """Build the hint appended to "feature not supported" errors.

    Args:
        model: Model name or path as given by the user.
        trust_remote_code: Whether the model is loaded with remote code. When
            true the user is pointed at the model's Hugging Face discussions,
            otherwise at the Transformers issue tracker.

    Returns:
        A sentence pointing to the vLLM docs on writing custom models. Unless
        `model` is an existing filesystem path, it is preceded by a sentence
        asking the user to open a feature request.
    """
    tip = (
        "See https://docs.vllm.ai/en/latest/models/supported_models.html"
        "#writing-custom-models for instructions on how to add support yourself."
    )
    # A path that exists on disk has no upstream page to file a request
    # against, so only the documentation pointer is returned.
    if Path(model).exists():
        return tip
    if trust_remote_code:
        where = f"a discussion at https://huggingface.co/{model}/discussions/new"
    else:
        where = (
            "an issue at "
            "https://github.com/huggingface/transformers/issues/new/choose"
        )
    return f"Please open {where} to request support for this feature. {tip}"
def vllm_flash_attention_forward(
    # Transformers args
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor,
    # Transformers kwargs
    scaling: Optional[float] = None,
    # vLLM kwargs
    attention_instances: Optional[dict[int, Attention]] = None,
    **kwargs,
):
    """Attention function registered with Transformers under the name "vllm".

    Looks up the vLLM `Attention` instance for `module.layer_idx` and runs it
    on the given query/key/value.

    Args:
        module: The Transformers attention module being executed; only its
            `layer_idx` attribute is read.
        query: Query tensor.
        key: Key tensor.
        value: Value tensor.
        attention_mask: Accepted for interface compatibility; not used here.
        scaling: If given, overwrites the scale of the selected attention
            implementation before it runs.
        attention_instances: Mapping from layer index to vLLM `Attention`.
        **kwargs: Ignored.

    Returns:
        A `(output, None)` tuple. The second element is presumably the slot
        Transformers reserves for attention weights - TODO confirm.

    Raises:
        ValueError: If `attention_instances` is not provided.
    """
    if attention_instances is None:
        raise ValueError(
            "`attention_instances` must be passed to the model's forward when "
            "using the 'vllm' attention implementation."
        )
    self_attn = attention_instances[module.layer_idx]
    if scaling is not None:
        self_attn.impl.scale = float(scaling)
    # Size of the second-to-last query dimension, read before the transpose.
    # Presumably the number of tokens of a (batch, heads, tokens, head_dim)
    # tensor with batch size 1 - TODO confirm.
    hidden = query.shape[-2]
    query, key, value = (x.transpose(1, 2) for x in (query, key, value))
    # Collapse everything else into the last dimension: (hidden, -1).
    query, key, value = (x.reshape(hidden, -1) for x in (query, key, value))
    return self_attn.forward(query, key, value), None


# Make the function selectable through `_attn_implementation = "vllm"`.
ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward
def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module):
    """Emit a debug log line recording that a module was swapped out.

    Args:
        name: Qualified name of the replaced module.
        old_module: The module that was replaced.
        new_module: The module that took its place.
    """
    logger.debug(
        "%s: %s -> %s",
        name,
        old_module,
        new_module,
    )
def can_enable_torch_compile(vllm_config: VllmConfig) -> bool:
    """Decide whether `@support_torch_compile` should be active.

    Intended to be passed as that decorator's `enable_if` argument.

    Args:
        vllm_config: The vLLM config; only the text config of its Hugging
            Face model config is inspected.

    Returns:
        `False` when the text config's `rope_scaling` has
        `rope_type == "dynamic"`, `True` in every other case (including when
        `rope_scaling` is missing or empty).
    """
    text_config = vllm_config.model_config.hf_config.get_text_config()
    # Dynamic rope scaling is not compatible with torch.compile
    rope_scaling: dict = getattr(text_config, "rope_scaling", None) or {}
    return rope_scaling.get("rope_type") != "dynamic"
Style = Literal["colwise", "colwise_rep", "rowwise", "rowwise_rep", "replicate"]


def replace_linear_class(
    linear: nn.Linear,
    style: Style = "replicate",
    quant_config: Optional[QuantizationConfig] = None,
    *,
    prefix: str = "",
) -> Union[ColumnParallelLinear, RowParallelLinear, ReplicatedLinear]:
    """
    Build the vLLM linear layer that stands in for an `nn.Linear`.

    Args:
        linear: `nn.Linear` to be replaced. Only its `in_features`,
            `out_features` and whether it has a bias are read.
        style: Tensor parallel style of the new linear, e.g. "colwise".
            Unrecognised strings fall back to a replicated linear.
        quant_config: Quantization config for the new linear.
        prefix: Prefix forwarded to the new linear.

    Returns:
        The new linear, constructed with `return_bias=False`.

    Raises:
        ValueError: If `style` is not a string.
    """
    if not isinstance(style, str):
        raise ValueError(f"Unsupported parallel style type {type(style)}, expected str")

    if style == "colwise":
        linear_cls, extra_kwargs = ColumnParallelLinear, {}
    elif style == "colwise_rep":
        linear_cls, extra_kwargs = ColumnParallelLinear, {"gather_output": True}
    elif style == "rowwise":
        linear_cls, extra_kwargs = RowParallelLinear, {}
    elif style == "rowwise_rep":
        linear_cls, extra_kwargs = RowParallelLinear, {"input_is_parallel": False}
    else:
        # "replicate" and any unknown style both map to ReplicatedLinear
        linear_cls, extra_kwargs = ReplicatedLinear, {}

    has_bias = linear.bias is not None
    return linear_cls(
        input_size=linear.in_features,
        output_size=linear.out_features,
        bias=has_bias,
        quant_config=quant_config,
        prefix=prefix,
        return_bias=False,
        **extra_kwargs,
    )
def replace_rms_norm_class(rms_norm: nn.Module, hidden_size: int) -> RMSNorm:
    """Build a vLLM `RMSNorm` equivalent to a Transformers RMSNorm module.

    Assumptions made about `rms_norm`:
    - The weight, if any, is stored as `weight`.
    - Epsilon is stored as `eps` or `variance_epsilon` (default `1e-6`).
    - `with_scale` indicates whether the layer has a weight (Gemma3n only).
    - `var_hidden_size` is only ever used for Intern vision encoder in vLLM
      and Transformers doesn't appear to have the same concept, so it is
      never set here.

    Args:
        rms_norm: The Transformers norm module to mirror.
        hidden_size: Hidden size passed to the new `RMSNorm`.

    Returns:
        The new `RMSNorm`. It is weightless when `rms_norm` has no `weight`;
        otherwise it takes its dtype from that weight.
    """
    eps = getattr_iter(rms_norm, ("eps", "variance_epsilon"), 1e-6)
    has_weight = getattr(rms_norm, "with_scale", True)
    weight = getattr(rms_norm, "weight", None)

    if weight is None:
        # No weight, fall back to weightless RMSNorm
        return RMSNorm(hidden_size=hidden_size, eps=eps, has_weight=False)

    # If weight is a Parameter, get its data tensor
    weight_tensor = getattr(weight, "data", weight)
    return RMSNorm(
        hidden_size=hidden_size,
        eps=eps,
        has_weight=has_weight,
        dtype=weight_tensor.dtype,
    )
# Adapted from `accelerate`
@contextmanager
def init_on_device_without_buffers(device: torch.device):
    """
    A context manager under which models are initialized with all
    parameters on the specified device. However buffers are not
    initialized on specified device.

    While the context is active, `nn.Module.register_parameter` is replaced
    by a wrapper that moves every newly registered parameter to `device`.
    The original method is restored on exit, including when the body raises.

    Args:
        device (`torch.device`):
            Device to initialize all parameters on.
    """
    old_register_parameter = nn.Module.register_parameter

    def register_empty_parameter(module, name, param):
        old_register_parameter(module, name, param)
        if param is None:
            return
        registered = module._parameters[name]
        # Copy the attribute dict instead of writing `requires_grad` into it,
        # so the parameter object handed in by the caller is left untouched.
        kwargs = dict(registered.__dict__)
        kwargs["requires_grad"] = param.requires_grad
        # Rebuild with the same class so parameter subclasses are preserved.
        module._parameters[name] = type(registered)(registered.to(device), **kwargs)

    try:
        nn.Module.register_parameter = register_empty_parameter
        yield
    finally:
        nn.Module.register_parameter = old_register_parameter
class MultiModalProcessingInfo(BaseProcessingInfo):
    """Processing info for the Transformers backend; only images are handled."""

    def get_supported_mm_limits(self):
        """Return the supported modalities; `None` is the limit for "image"."""
        return {"image": None}

    def get_mm_max_tokens_per_item(self, seq_len, mm_counts):
        """Return the maximum token count per image.

        `seq_len` and `mm_counts` are accepted but not used.
        """
        max_tokens = self.get_max_image_tokens()
        return {"image": max_tokens}

    def get_max_image_tokens(self) -> int:
        """Ask the HF processor how many tokens the largest image yields."""
        max_width, max_height = self.get_max_image_size()
        hf_processor = self.get_hf_processor()
        mm_config = self.ctx.model_config.multimodal_config
        extra_kwargs = mm_config.mm_processor_kwargs or {}
        # The processor is queried with a single (height, width) entry.
        token_counts = hf_processor._get_num_multimodal_tokens(
            image_sizes=([max_height, max_width],), **extra_kwargs
        )
        return token_counts["num_image_tokens"][0]

    def get_max_image_size(self):
        """Return the `(width, height)` used as the largest image size."""
        # hardcode for arbitrary very large size
        return 10_000, 10_000
class MultiModalDummyInputsBuilder(BaseDummyInputsBuilder[MultiModalProcessingInfo]):
    """Builds dummy text and image inputs for the Transformers backend."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return the image placeholder token repeated once per dummy image.

        Processors whose class name contains "gemma3" use `boi_token`; all
        others use `image_token`, or an empty string when it is missing.
        """
        hf_processor = self.info.get_hf_processor()
        is_gemma3 = "gemma3" in hf_processor.__class__.__name__.lower()
        placeholder = (
            hf_processor.boi_token
            if is_gemma3
            else getattr(hf_processor, "image_token", "")
        )
        return placeholder * mm_counts.get("image", 0)

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
    ) -> MultiModalDataDict:
        """Return dummy images at the maximum image size.

        Args:
            seq_len: Not used.
            mm_counts: Number of items per modality; only "image" is read.
            mm_options: Optional per-modality overrides; the "image" entry is
                forwarded to `_get_dummy_images`.
        """
        image_count = mm_counts.get("image", 0)
        max_width, max_height = self.info.get_max_image_size()
        overrides = None
        if mm_options:
            overrides = mm_options.get("image")
        dummy_images = self._get_dummy_images(
            width=max_width,
            height=max_height,
            num_images=image_count,
            overrides=overrides,
        )
        return {"image": dummy_images}
class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
    """Multi-modal processor that relies on the HF processor for everything.

    Placeholder positions are derived from the token type ids returned by the
    HF processor instead of vLLM-side prompt updates.
    """

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ):
        """
        Given the original multi-modal items for this modality
        and HF-processed data, output the updates to perform.

        The information returned by this method is used to update token inputs
        which bypass the HF processor. It is also used to update the output of
        HF processor if the HF process does not apply prompt updates to text
        inputs.

        Moreover, this information is critical to determine the token positions
        in order to construct :class:`~vllm.multimodal.inputs.PlaceholderRange`
        for each multi-modal item.

        This implementation always returns `None`.
        """
        return None

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe how each HF output key is laid out across images.

        Note that `attention_mask` is removed from `hf_inputs` in place.
        """
        # HF Processors always return a mask but vLLM doesn't need it
        hf_inputs.pop("attention_mask", None)
        patch_counts = hf_inputs.get("num_image_patches")

        def flat_image_field():
            return MultiModalFieldConfig.flat_from_sizes("image", patch_counts)

        # By default every key is a flat field split by patches per image
        fields = {key: flat_image_field() for key in hf_inputs}
        fields["image_embeds"] = flat_image_field()

        # Keep these as batched, as they always have batch size as first dim
        for batched_key in ("image_grid_thw", "video_grid_thw", "num_image_patches"):
            fields[batched_key] = MultiModalFieldConfig.batched("image")
        return fields

    def _get_hf_mm_data(
        self,
        mm_items: MultiModalDataItems,
    ) -> tuple[Mapping[str, object], Mapping[str, object]]:
        """
        In contrast to the base class, this method always adds
        `return_mm_token_type_ids` to the processor data
        """
        hf_data, passthrough = super()._get_hf_mm_data(mm_items)
        hf_data["return_mm_token_type_ids"] = True
        return hf_data, passthrough

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Optional[Mapping[str, object]] = None,
        mm_uuids: Optional[MultiModalUUIDDict] = None,
    ) -> MultiModalInputs:
        """
        Process multi-modal inputs to be used in vLLM.

        Apply HF Processor on prompt text and multi-modal data together,
        outputting token IDs and processed tensors.

        Args:
            prompt: Prompt text, or token ids which are decoded back to text
                with the HF processor.
            mm_data: Raw multi-modal data.
            hf_processor_mm_kwargs: Keyword arguments for the HF processor.
            tokenization_kwargs: Keyword arguments for tokenization; an empty
                dict is used when omitted.
            mm_uuids: Optional overrides used when hashing the items.
        """
        tokenization_kwargs = {} if tokenization_kwargs is None else tokenization_kwargs

        mm_items = self._to_mm_items(mm_data)
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        if not isinstance(prompt, str):
            # the prompt is the tokenized ids which is not supported
            # by the hf_processor, which is why we would need to decode the ids
            # into string
            prompt = hf_processor.decode(prompt)

        # Bypass cached processor and always apply to the full set of mm inputs
        # NOTE: we can't just set caching=False because base class method
        # transforms outputs to `MultiModalKwargs` which is not going to
        # work for Transformers. We have a lot of logic tied to
        # `mm_tokens_per_modality` below
        prompt_ids, processed_data, _ = self._apply_hf_processor_text_mm(
            prompt_text=prompt,
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            tokenization_kwargs=tokenization_kwargs,
        )

        # For gemma3 we check `token_type_ids` as the key
        if "mm_token_type_ids" in processed_data:
            type_ids_key = "mm_token_type_ids"
        else:
            type_ids_key = "token_type_ids"
        token_type_ids = processed_data.pop(type_ids_key)

        # We can infer vLLM style placeholder from token type ids, if we split
        # it for each input `mm_data`.
        mm_positions = torch.where(token_type_ids == 1)[1]
        images = mm_items.get_items("image", ImageProcessorItems)
        mm_config = self.info.ctx.model_config.multimodal_config
        mm_processor_kwargs = mm_config.mm_processor_kwargs or {}

        # (height, width) of every image, in item order
        image_sizes = [
            (size.height, size.width)
            for size in (images.get_image_size(idx) for idx in range(len(images)))
        ]
        mm_tokens_per_modality = hf_processor._get_num_multimodal_tokens(
            image_sizes=image_sizes, **mm_processor_kwargs
        )

        mm_placeholders = {}
        tokens_per_image = mm_tokens_per_modality["num_image_tokens"]
        if tokens_per_image:
            position_chunks = torch.split(mm_positions, tokens_per_image)
            all_mm_tokens = torch.tensor(prompt_ids)[token_type_ids[0].bool()]
            token_chunks = torch.split(all_mm_tokens, tokens_per_image)
            image_ranges = []
            for chunk_positions, chunk_tokens in zip(position_chunks, token_chunks):
                # Only tokens equal to the image token id are flagged in
                # `is_embed`; the range spans every multi-modal token.
                embed_mask = (chunk_tokens == hf_processor.image_token_id).bool()
                image_ranges.append(
                    PlaceholderRange(
                        offset=chunk_positions[0].item(),
                        length=chunk_positions.shape[0],
                        is_embed=embed_mask,
                    )
                )
            mm_placeholders = {"image": image_ranges}

        processed_data["num_image_patches"] = torch.tensor(
            mm_tokens_per_modality["num_image_patches"]
        )
        fields_config = self._get_mm_fields_config(
            processed_data, hf_processor_mm_kwargs
        )
        mm_kwargs = MultiModalKwargsItems.from_hf_inputs(processed_data, fields_config)

        # Use overrides if provided; fallback to data-dependent hashing.
        mm_hashes = self._hash_mm_items(
            mm_items, hf_processor_mm_kwargs, tokenization_kwargs, mm_uuids=mm_uuids
        )

        return MultiModalInputs(
            type="multimodal",
            prompt_token_ids=prompt_ids,
            mm_kwargs=mm_kwargs,
            mm_hashes=mm_hashes,
            mm_placeholders=mm_placeholders,
        )
class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
    """Shared base for vLLM models backed by a Transformers model.

    The Transformers model is created from the Hugging Face config with its
    parameters on the "meta" device, trimmed to the layers owned by this
    pipeline parallel rank, has its `nn.Linear` layers and input embeddings
    replaced with vLLM equivalents, and gets one vLLM `Attention` instance
    per local layer.
    """

    embedding_padding_modules = ["lm_head"]
    embedding_modules = ["embed_tokens"]  # TODO transformers will have a util to get it

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the wrapped Transformers model.

        Args:
            vllm_config: The vLLM config the model is built from.
            prefix: Not used by this base class.

        Raises:
            NotImplementedError: If the quantization method is "mxfp4".
        """
        super().__init__()
        logger.info("Using Transformers backend.")

        self.config: PretrainedConfig = vllm_config.model_config.hf_config
        self.text_config: PretrainedConfig = self.config.get_text_config()
        self.cache_config: CacheConfig = vllm_config.cache_config
        self.device_config: DeviceConfig = vllm_config.device_config
        self.model_config: ModelConfig = vllm_config.model_config
        self.parallel_config: ParallelConfig = vllm_config.parallel_config
        self.quant_config: Optional[QuantizationConfig] = vllm_config.quant_config

        self.pp_group = get_pp_group()
        self.tp_group = get_tp_group()

        # Weights to skip in `self.load_weights`
        self.skip_prefixes: list[str] = []
        """Skip loading weights whose qualname starts with these prefixes."""
        self.skip_substrs: list[str] = []
        """Skip loading weights whose qualname contains these substrings."""
        self.ignore_unexpected_prefixes: list[str] = []
        """Ignore unexpected weights whose qualname starts with these prefixes.
        """
        self.ignore_unexpected_suffixes: list[str] = []
        """Ignore unexpected weights whose qualname ends with these suffixes."""

        if self.quant_config:
            quant_method_name = self.quant_config.get_name()
            # Check for unsupported quantization methods.
            if quant_method_name == "mxfp4":
                raise NotImplementedError(
                    "Transformers backend does not support MXFP4 quantization yet."
                )
            # Skip loading extra bias for GPTQ models.
            if "gptq" in quant_method_name:
                self.ignore_unexpected_suffixes.append(".bias")

        # Set correct attn and init on "meta" to delay allocating GPU tensors
        self.text_config._attn_implementation = "vllm"
        with init_on_device_without_buffers("meta"):
            self.model: PreTrainedModel = AutoModel.from_config(
                self.config,
                torch_dtype=self.model_config.dtype,
                trust_remote_code=self.model_config.trust_remote_code,
            )

        # Remove layers not on this pipeline parallel rank
        self.pipeline_parallel()
        # Substitute remaining layers with vLLM's layers as needed
        self.recursive_replace()
        # Create attention instances for KV cache allocation
        self.attention_instances = self.create_attention_instances()

        # Input embeddings
        if not isinstance(self.model.get_input_embeddings(), PPMissingLayer):
            names = ("embedding_size", "hidden_size")
            embedding_dim = getattr_iter(self.text_config, names, None)
            assert embedding_dim is not None
            self.model.set_input_embeddings(
                VocabParallelEmbedding(
                    self.text_config.vocab_size,
                    embedding_dim=embedding_dim,
                    org_num_embeddings=self.text_config.vocab_size,
                    quant_config=self.quant_config,
                )
            )

        # Initialize any parameters that have not had their modules replaced
        self.init_parameters(self.model)

        # Pipeline parallel intermediate tensors
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states"], self.text_config.hidden_size
        )

    def pipeline_parallel(self):
        """
        Apply the model's pipeline parallelization plan.

        Modules that do not belong on this pipeline parallel rank are replaced
        with `PPMissingLayer`. Does nothing when the pipeline parallel world
        size is 1 or less.

        Raises:
            ValueError: If the model does not support a pipeline parallel
                plan, or if the plan does not contain exactly one
                `nn.ModuleList`.
        """
        if self.pp_group.world_size <= 1:
            return

        if not self.model.supports_pp_plan:
            tip = get_feature_request_tip(
                self.model_config.model, self.model_config.trust_remote_code
            )
            raise ValueError(
                f"{type(self.model)} does not support pipeline parallel. {tip}"
            )

        # Locate the (single) module list holding the decoder layers
        module_lists = []
        module_list_idx = None
        pp_plan = list(self.model._pp_plan.keys())
        for i, name in enumerate(pp_plan):
            if isinstance(getattr(self.model, name), nn.ModuleList):
                module_lists.append(name)
                module_list_idx = i

        if len(module_lists) > 1:
            raise ValueError(
                "Pipeline parallel of models with multiple `ModuleList`s "
                "in the base model are not supported yet!"
            )
        if module_list_idx is None:
            raise ValueError(f"Could not find `ModuleList` in {type(self.model)}")

        # Layers before module list
        for name in pp_plan[:module_list_idx]:
            # Kept on the first rank, and also on the last rank when the word
            # embeddings are tied.
            if self.pp_group.is_first_rank or (
                self.text_config.tie_word_embeddings and self.pp_group.is_last_rank
            ):
                continue
            setattr(self.model, name, PPMissingLayer())

        # Module list
        start_layer, end_layer = get_pp_indices(
            self.text_config.num_hidden_layers,
            self.pp_group.rank_in_group,
            self.pp_group.world_size,
        )
        layers_name = pp_plan[module_list_idx]
        layers = getattr(self.model, layers_name)
        for i in range(len(layers)):
            if start_layer <= i < end_layer:
                continue
            layers[i] = PPMissingLayer()

        # Layers after module list
        for name in pp_plan[module_list_idx + 1 :]:
            # Modules that should be on last rank
            if not self.pp_group.is_last_rank:
                setattr(self.model, name, PPMissingLayer())

    def recursive_replace(self):
        """Recursively replace modules in the model as needed.

        Currently, this replaces:

        - `nn.Linear` with vLLM's tensor parallel linear classes
        - `*RMSNorm` with vLLM's `RMSNorm` (currently disabled, see the TODO
          in the body)

        Raises:
            ValueError: If the model has no tensor parallel plan while the
                tensor parallel world size is greater than 1.
        """
        tp_plan = self.model.tp_plan

        if not tp_plan and self.tp_group.world_size > 1:
            tip = get_feature_request_tip(
                self.model_config.model, self.model_config.trust_remote_code
            )
            raise ValueError(
                f"{type(self.model)} does not support tensor parallel. {tip}"
            )

        # Without tensor parallelism a missing plan is acceptable, so treat
        # `None` like an empty plan: every linear then falls back to
        # "replicate" instead of failing on `None.items()`.
        # Prefix the patterns because we always start from `self.model`
        tp_plan = {maybe_prefix("model", k): v for k, v in (tp_plan or {}).items()}

        def _recursive_replace(module: nn.Module, prefix: str):
            for child_name, child_module in module.named_children():
                new_module = child_module
                qual_name = maybe_prefix(prefix, child_name)
                if isinstance(child_module, nn.Linear):
                    # First plan pattern matching the start of the qualname
                    generator = (p for p in tp_plan if re.match(p, qual_name))
                    pattern = next(generator, None)
                    # Some weight loaders expect all linear layers to inherit
                    # LinearBase, so we set a default style which causes any
                    # unspecified layers to be replaced with ReplicatedLinear
                    style = tp_plan.get(pattern, "replicate")
                    new_module = replace_linear_class(
                        child_module, style, self.quant_config, prefix=qual_name
                    )
                # TODO(hmellor): Enable RMSNorm replacement once we have a way
                # to choose RMSNorm vs GemmaRMSNorm
                # elif child_module.__class__.__name__.endswith("RMSNorm"):
                #     new_module = replace_rms_norm_class(
                #         child_module, self.config.hidden_size)
                else:
                    _recursive_replace(child_module, prefix=qual_name)

                if new_module is not child_module:
                    setattr(module, child_name, new_module)
                    log_replacement(qual_name, child_module, new_module)

        _recursive_replace(self.model, prefix="model")

    def create_attention_instances(
        self, attn_type: AttentionType = AttentionType.DECODER
    ) -> dict[int, Attention]:
        """
        Create `Attention` instances to inform KV cache allocation.

        Args:
            attn_type: Attention type given to every created instance.

        Returns:
            A mapping from layer index to `Attention`, covering only the
            layers assigned to this pipeline parallel rank.
        """
        num_heads = self.model_config.get_num_attention_heads(self.parallel_config)
        head_size = self.model_config.get_head_size()
        num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
        start, end = get_pp_indices(
            self.text_config.num_hidden_layers,
            self.pp_group.rank_in_group,
            self.pp_group.world_size,
        )

        attention_instances = {}
        for i in range(start, end):
            # Handle interleaved sliding window attention
            # NOTE(review): `layer_types` is read from `self.config`, not
            # `self.text_config` - confirm this is intended for configs whose
            # text config is nested.
            per_layer_sliding_window = None
            if (
                hasattr(self.config, "layer_types")
                and self.config.layer_types[i] == "sliding_attention"
            ):
                per_layer_sliding_window = self.config.sliding_window

            attention_instances[i] = Attention(
                num_heads=num_heads,
                head_size=head_size,
                # NOTE: We use Llama scale as default, if it's set by
                # Transformers, it's updated in vllm_flash_attention_forward
                scale=head_size**-0.5,
                num_kv_heads=num_kv_heads,
                cache_config=self.cache_config,
                quant_config=self.quant_config,
                per_layer_sliding_window=per_layer_sliding_window,
                prefix=f"{i}.attn",
                attn_type=attn_type,
            )
        return attention_instances

    def init_parameters(self, module: nn.Module, dtype: Optional[torch.dtype] = None):
        """
        Replace every parameter still on the `meta` device with an
        uninitialized parameter on the configured device.

        If a `parameter` is on the `meta` device, then its parent
        `module` is the original module created by:

        ```python
        with init_on_device_without_buffers("meta"):
            self.model: PreTrainedModel = AutoModel.from_config(...)
        ```

        Args:
            module: Root module to process, including all of its descendants.
            dtype: Dtype of the new parameters. Falls back to the model
                config's dtype when not given.
        """

        def _init_parameters(module: nn.Module, dtype: Optional[torch.dtype]):
            for name, param in module.named_parameters(recurse=False):
                if param.device == torch.device("meta"):
                    new_param = nn.Parameter(
                        torch.empty_like(
                            param.data,
                            dtype=dtype or self.model_config.dtype,
                            device=self.device_config.device,
                        )
                    )
                    setattr(module, name, new_param)
            for child in module.children():
                _init_parameters(child, dtype)

        _init_parameters(module, dtype)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the wrapped Transformers model.

        Args:
            input_ids: Token ids; ignored on ranks other than the first
                pipeline parallel rank.
            positions: Position ids.
            intermediate_tensors: Output of the previous pipeline parallel
                rank; required on every rank but the first.
            inputs_embeds: Precomputed input embeddings; on ranks other than
                the first they are replaced by the received hidden states.

        Returns:
            The hidden states on the last pipeline parallel rank, otherwise
            `IntermediateTensors` holding them under "hidden_states".
        """
        if not self.pp_group.is_first_rank:
            assert intermediate_tensors is not None
            input_ids = None
            inputs_embeds = intermediate_tensors["hidden_states"]

        # Transformers expects a leading batch dimension
        if input_ids is not None:
            input_ids = input_ids[None, ...]
        if inputs_embeds is not None:
            inputs_embeds = inputs_embeds[None, ...]

        # With mrope the new axis is inserted at index 1 rather than index 0
        if self.model_config.uses_mrope:
            position_ids = positions[:, None]
        else:
            position_ids = positions[None, ...]

        hidden_states = self.model(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            use_cache=False,
            position_ids=position_ids,
            attention_instances=self.attention_instances,
            return_dict=False,
        )[0][0, ...]  # we remove batch dimension for now

        if not self.pp_group.is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})

        return hidden_states

    def load_weights(
        self,
        weights: Iterable[tuple[str, torch.Tensor]],
    ) -> set[str]:
        """Load weights with `AutoWeightsLoader`.

        The skip/ignore lists collected on this instance and
        `self.hf_to_vllm_mapper` are forwarded to the loader.

        Args:
            weights: Iterable of `(name, tensor)` pairs.

        Returns:
            Whatever `AutoWeightsLoader.load_weights` returns.
        """
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=self.skip_prefixes,
            skip_substrs=self.skip_substrs,
            ignore_unexpected_prefixes=self.ignore_unexpected_prefixes,
            ignore_unexpected_suffixes=self.ignore_unexpected_suffixes,
        )
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    def check_version(self, min_version: str, feature: str):
        """Ensure the installed `transformers` is recent enough for a feature.

        Args:
            min_version: Minimum required `transformers` version.
            feature: Name of the feature, used in the error message.

        Raises:
            ImportError: If the installed version is older than `min_version`.
        """
        installed = Version(transformers.__version__)
        required = Version(min_version)
        if installed < required:
            raise ImportError(
                f"Transformers backend requires transformers>={required} "
                f"for {feature}, but got {installed}"
            )
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersForCausalLM(TransformersBase):
    """Transformers backend model with a language modelling head."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the base model, then the LM head on the last pipeline rank.

        Args:
            vllm_config: The vLLM config the model is built from.
            prefix: Prefix used for the `lm_head` module name.
        """
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        tie_embeddings = self.text_config.tie_word_embeddings

        # Tell `TransformersBase.load_weights` to skip
        # `lm_head` if the model has tied word embeddings
        if tie_embeddings:
            self.skip_prefixes.append("lm_head.")

        # Only the last pipeline parallel rank owns a real LM head
        if not self.pp_group.is_last_rank:
            self.lm_head = PPMissingLayer()
            return

        vocab_size = self.text_config.vocab_size
        self.unpadded_vocab_size = vocab_size
        self.lm_head = ParallelLMHead(
            vocab_size,
            self.text_config.hidden_size,
            quant_config=self.quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        if tie_embeddings:
            input_embeddings = self.model.get_input_embeddings()
            self.lm_head = self.lm_head.tie_weights(input_embeddings)

        # `logit_scale` defaults to 1.0 when the text config does not set it
        self.logits_processor = LogitsProcessor(
            self.unpadded_vocab_size,
            vocab_size,
            getattr(self.text_config, "logit_scale", 1.0),
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed `input_ids` with the wrapped model's input embedding layer."""
        embed_tokens = self.model.get_input_embeddings()
        return embed_tokens(input_ids)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        """Run the logits processor on `hidden_states` using the LM head."""
        return self.logits_processor(self.lm_head, hidden_states)
@MULTIMODAL_REGISTRY.register_processor(
    MultiModalProcessor,
    info=MultiModalProcessingInfo,
    dummy_inputs=MultiModalDummyInputsBuilder,
)
@support_torch_compile(
    # set `positions` to last dim to support Qwen-mrope
    dynamic_arg_dims={
        "input_ids": 0,
        "positions": -1,
        "intermediate_tensors": 0,
        "inputs_embeds": 0,
    },
    enable_if=can_enable_torch_compile,
)
class TransformersForMultimodalLM(TransformersForCausalLM, SupportsMultiModal):
    """Transformers backend model for image + text inputs."""

    merge_by_field_config = True
    # Backwards compatibility for prev released models. State dicts back then
    # had different formats and cannot be loaded with `AutoModel` mapping as is
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "language_model.model": "model.language_model",
            "text_model.model": "model.text_model",
            "vision_tower": "model.vision_tower",
            "vqmodel": "model.vqmodel",
            "visual": "model.visual",
            "vision_model": "model.vision_model",
            "vision_embed_tokens": "model.vision_embed_tokens",
            "image_newline": "model.image_newline",
            "multi_modal_projector": "model.multi_modal_projector",
            "text_model.lm_head": "lm_head",
            "language_model.lm_head": "lm_head",
            # Qwen models used "model" as the name for the language model.
            # Therefore, we must map each of submodule explicitly to avoid
            # conflicts with newer models that use "model.language_model".
            "model.embed_tokens": "model.language_model.embed_tokens",
            "model.layers": "model.language_model.layers",
            "model.norm": "model.language_model.norm",
        }
    )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the causal LM and remember the configured model dtype."""
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        self.dtype = vllm_config.model_config.dtype

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Delegate to the parent `forward`.

        Extra keyword arguments are accepted but not forwarded.
        """
        return super().forward(
            input_ids,
            positions,
            intermediate_tensors,
            inputs_embeds,
        )

    def get_language_model(self) -> torch.nn.Module:
        """Return the wrapped Transformers model."""
        return self.model

    def get_multimodal_embeddings(self, **kwargs):
        """Compute image embeddings from the multi-modal keyword arguments.

        Returns:
            - `image_embeds` unchanged when it is provided.
            - `None` when neither `pixel_values` nor `image_patches` is given.
            - Otherwise the result of the model's `get_image_features`. A
              tensor result is split by `num_image_patches` into a list with
              one 2D tensor per split; any other result is returned as is.
        """
        pixel_values: Optional[torch.Tensor] = kwargs.pop("pixel_values", None)
        image_embeds: Optional[torch.Tensor] = kwargs.pop("image_embeds", None)
        # Model might use `image_patches` instead of `pixel_values`
        if pixel_values is None:
            pixel_values = kwargs.pop("image_patches", None)

        # Precomputed embeddings take priority over raw pixels
        if image_embeds is not None:
            return image_embeds

        if pixel_values is None:
            return None

        num_image_patches = kwargs.pop("num_image_patches")
        # Remaining kwargs are passed through to the vision feature extractor
        features = self.model.get_image_features(pixel_values, **kwargs)
        if not isinstance(features, torch.Tensor):
            return features

        if features.ndim == 2:
            features = features.unsqueeze(0)

        # Embeddings have to be 2D tensors of length `num_images`
        # but transformers returns concat tensors if each patch
        # is of different size. We split it back to make vLLM happy
        per_image = torch.split(features, num_image_patches.flatten().tolist())
        return [chunk.flatten(start_dim=0, end_dim=-2) for chunk in per_image]

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
        *,
        is_multimodal: Optional[torch.Tensor] = None,
        handle_oov_mm_token: bool = False,
    ) -> torch.Tensor:
        """
        Apply token embeddings to `input_ids`.

        If `multimodal_embeddings` is passed, scatter them into
        `input_ids` according to the mask `is_multimodal`.

        `handle_oov_mm_token` is forwarded to `_get_text_embeddings`. It is
        meant for the case where multi-modal token IDs exceed the vocabulary
        size of the language model, so that the language model's
        `get_input_embeddings` is not called on those tokens.
        NOTE(review): the default is already `False`, so opting in presumably
        means passing `True` - confirm against `_get_text_embeddings`.

        Raises:
            ValueError: If non-empty `multimodal_embeddings` are given without
                `is_multimodal`.
        """
        from .utils import _merge_multimodal_embeddings

        text_embeds = self._get_text_embeddings(
            input_ids,
            self.model.get_input_embeddings(),
            is_multimodal=is_multimodal,
            handle_oov_mm_token=handle_oov_mm_token,
        )

        has_mm_embeddings = (
            multimodal_embeddings is not None and len(multimodal_embeddings) != 0
        )
        if not has_mm_embeddings:
            return text_embeds

        if is_multimodal is None:
            raise ValueError(
                "`get_input_embeddings` now requires `is_multimodal` arg, "
                "please update your model runner according to "
                "https://github.com/vllm-project/vllm/pull/16229."
            )

        return _merge_multimodal_embeddings(
            inputs_embeds=text_embeds,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=is_multimodal,
        )