Refactor Transformers backend to use mixins (#26906)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-16 22:50:39 +01:00
committed by GitHub
parent b2f78cbad4
commit fb5e10d3fb
17 changed files with 1510 additions and 1248 deletions

.github/CODEOWNERS vendored
View File

@ -57,7 +57,7 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
/tests/v1/offloading @ApostaC
# Transformers backend
/vllm/model_executor/models/transformers.py @hmellor
/vllm/model_executor/models/transformers @hmellor
/tests/models/test_transformers.py @hmellor
# Docs

View File

@ -912,11 +912,11 @@ _TRANSFORMERS_BACKEND_MODELS = {
"TransformersForCausalLM": _HfExamplesInfo(
"hmellor/Ilama-3.2-1B", trust_remote_code=True
),
"TransformersForMultimodalLM": _HfExamplesInfo("BAAI/Emu3-Chat-hf"),
"TransformersMultiModalForCausalLM": _HfExamplesInfo("BAAI/Emu3-Chat-hf"),
"TransformersMoEForCausalLM": _HfExamplesInfo(
"allenai/OLMoE-1B-7B-0924", min_transformers_version="4.57.0.dev0"
),
"TransformersMoEForMultimodalLM": _HfExamplesInfo(
"TransformersMultiModalMoEForCausalLM": _HfExamplesInfo(
"Qwen/Qwen3-VL-30B-A3B-Instruct", min_transformers_version="4.57.0.dev0"
),
"TransformersMoEEmbeddingModel": _HfExamplesInfo(
@ -925,6 +925,10 @@ _TRANSFORMERS_BACKEND_MODELS = {
"TransformersMoEForSequenceClassification": _HfExamplesInfo(
"Qwen/Qwen3-30B-A3B", min_transformers_version="4.57.0.dev0"
),
"TransformersMultiModalEmbeddingModel": _HfExamplesInfo("google/gemma-3-4b-it"),
"TransformersMultiModalForSequenceClassification": _HfExamplesInfo(
"google/gemma-3-4b-it"
),
}
_EXAMPLE_MODELS = {

View File

@ -37,7 +37,7 @@ MINIMAL_MODEL_ARCH_LIST = [
"JinaVLForRanking",
"InternVLChatModel",
"InternLM2ForRewardModel",
"TransformersForMultimodalLM",
"TransformersMultiModalForCausalLM",
"PrithviGeoSpatialMAE",
"UltravoxModel",
"DeepSeekMTPModel",

View File

@ -211,11 +211,7 @@ def test_embed_loading(vllm_runner, model):
def test_pooling(hf_runner, vllm_runner, example_prompts, arch):
model = get_model(arch)
vllm_kwargs = dict(
max_model_len=None,
model_impl="transformers",
compilation_config=dict(cudagraph_capture_sizes=[8]),
)
vllm_kwargs = dict(max_model_len=None, model_impl="transformers")
hf_kwargs = dict()
if arch == "TransformersEmbeddingModel":

View File

@ -147,6 +147,10 @@ class ModelConfig:
seed: int | None = None
"""Random seed for reproducibility. Initialized to None in V0, but
initialized to 0 in V1."""
hf_config: PretrainedConfig = field(init=False)
"""The Hugging Face config of the model."""
hf_text_config: PretrainedConfig = field(init=False)
"""The Hugging Face config of the text model (same as hf_config for text models)."""
hf_config_path: str | None = None
"""Name or path of the Hugging Face config to use. If unspecified, model
name or path will be used."""
@ -771,8 +775,10 @@ class ModelConfig:
def _get_transformers_backend_cls(self) -> str:
"""Determine which Transformers backend class will be used if
`model_impl` is set to `transformers` or `auto`."""
prefix = "Transformers"
prefix += "MoE" if self.get_num_experts() > 1 else ""
cls = "Transformers"
# If 'hf_config != hf_text_config' it's a nested config, i.e. multimodal
cls += "MultiModal" if self.hf_config != self.hf_text_config else ""
cls += "MoE" if self.get_num_experts() > 1 else ""
# Check if the architecture we're wrapping has defaults
runner = None
convert = None
@ -788,18 +794,15 @@ class ModelConfig:
runner = "generate"
if convert in {None, "none"}:
convert = "embed"
# Resolve Transformers backend pooling classes
# Resolve Transformers backend task
if runner == "pooling":
if convert == "embed":
return prefix + "EmbeddingModel"
return cls + "EmbeddingModel"
if convert == "classify":
return prefix + "ForSequenceClassification"
# Resolve Transformers backend generate classes
if self.hf_config != self.hf_text_config:
# If 'hf_text_config' is the same as 'hf_config'. If not, it is
# probably a composite config, i.e. multimodal
return prefix + "ForMultimodalLM"
return prefix + "ForCausalLM"
return cls + "ForSequenceClassification"
else:
cls += "ForCausalLM"
return cls
def using_transformers_backend(self) -> bool:
"""Check if the model is using the Transformers backend class."""

View File

@ -19,7 +19,7 @@ from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.model_executor.models.transformers import replace_linear_class
from vllm.model_executor.models.transformers.utils import replace_linear_class
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
MultiModalDataDict,

View File

@ -401,32 +401,44 @@ _TRANSFORMERS_SUPPORTED_MODELS = {
# Text generation models
"SmolLM3ForCausalLM": ("transformers", "TransformersForCausalLM"),
# Multimodal models
"Emu3ForConditionalGeneration": ("transformers", "TransformersForMultimodalLM"), # noqa: E501
"Emu3ForConditionalGeneration": (
"transformers",
"TransformersMultiModalForCausalLM",
),
}
_TRANSFORMERS_BACKEND_MODELS = {
# Text generation models
"TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
"TransformersForMultimodalLM": ("transformers", "TransformersForMultimodalLM"), # noqa: E501
"TransformersMoEForCausalLM": ("transformers_moe", "TransformersMoEForCausalLM"), # noqa: E501
"TransformersMoEForMultimodalLM": (
"transformers_moe",
"TransformersMoEForMultimodalLM",
"TransformersMoEForCausalLM": ("transformers", "TransformersMoEForCausalLM"),
# Multimodal models
"TransformersMultiModalForCausalLM": (
"transformers",
"TransformersMultiModalForCausalLM",
),
"TransformersEmbeddingModel": (
"transformers_pooling",
"TransformersEmbeddingModel",
"TransformersMultiModalMoEForCausalLM": (
"transformers",
"TransformersMultiModalMoEForCausalLM",
),
# Embedding models
"TransformersEmbeddingModel": ("transformers", "TransformersEmbeddingModel"),
"TransformersMoEEmbeddingModel": ("transformers", "TransformersMoEEmbeddingModel"),
"TransformersMultiModalEmbeddingModel": (
"transformers",
"TransformersMultiModalEmbeddingModel",
),
# Sequence classification models
"TransformersForSequenceClassification": (
"transformers_pooling",
"transformers",
"TransformersForSequenceClassification",
),
"TransformersMoEForSequenceClassification": (
"transformers_pooling",
"transformers",
"TransformersMoEForSequenceClassification",
),
"TransformersMoEEmbeddingModel": (
"transformers_pooling",
"TransformersMoEEmbeddingModel",
"TransformersMultiModalForSequenceClassification": (
"transformers",
"TransformersMultiModalForSequenceClassification",
),
}
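# Illustrative sketch (an assumption about the registry, not its actual code):
# each entry above is a (module, class) pair, with the module name resolved
# relative to vllm.model_executor.models. After this refactor every
# Transformers backend class is served from the "transformers" package.
import importlib

def resolve(entry: tuple[str, str]) -> type:
    module_name, class_name = entry
    module = importlib.import_module(f"vllm.model_executor.models.{module_name}")
    return getattr(module, class_name)

# resolve(("transformers", "TransformersMoEForCausalLM")) would import the
# transformers package's __init__ and return the composed class.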

View File

@ -1,961 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2024 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wrapper around `transformers` models"""
from collections.abc import Iterable, Mapping
from contextlib import contextmanager
from pathlib import Path
from typing import Literal
import regex as re
import torch
import transformers
from packaging.version import Version
from torch import nn
from transformers import AutoModel, BatchFeature, PretrainedConfig, PreTrainedModel
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from vllm.attention import Attention, AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (
CacheConfig,
DeviceConfig,
ModelConfig,
ParallelConfig,
VllmConfig,
)
from vllm.config.multimodal import BaseDummyOptions
from vllm.config.utils import getattr_iter
from vllm.distributed import get_pp_group, get_tp_group
from vllm.distributed.utils import get_pp_indices
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems
from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalFieldConfig,
MultiModalInputs,
MultiModalUUIDDict,
PlaceholderRange,
)
from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataItems
from vllm.multimodal.processing import BaseMultiModalProcessor, BaseProcessingInfo
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP, SupportsQuant
from .utils import (
AutoWeightsLoader,
PPMissingLayer,
WeightsMapper,
make_empty_intermediate_tensors_factory,
maybe_prefix,
)
logger = init_logger(__name__)
def get_feature_request_tip(
model: str,
trust_remote_code: bool,
) -> str:
hf_url = f"a discussion at https://huggingface.co/{model}/discussions/new"
gh_url = "an issue at https://github.com/huggingface/transformers/issues/new/choose"
url = hf_url if trust_remote_code else gh_url
prefix = f"Please open {url} to request support for this feature. "
if Path(model).exists():
prefix = ""
doc_url = "https://docs.vllm.ai/en/latest/models/supported_models.html#writing-custom-models"
tip = f"See {doc_url} for instructions on how to add support yourself."
return f"{prefix}{tip}"
def vllm_flash_attention_forward(
# Transformers args
module: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: torch.Tensor,
# Transformers kwargs
scaling: float | None = None,
# vLLM kwargs
attention_instances: dict[Attention] | None = None,
**kwargs,
):
self_attn = attention_instances[module.layer_idx]
if scaling is not None:
self_attn.impl.scale = float(scaling)
hidden = query.shape[-2]
query, key, value = (x.transpose(1, 2) for x in (query, key, value))
query, key, value = (x.reshape(hidden, -1) for x in (query, key, value))
return self_attn.forward(query, key, value), None
ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward
def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module):
logger.debug("%s: %s -> %s", name, old_module, new_module)
def can_enable_torch_compile(vllm_config: VllmConfig) -> bool:
"""
Callable to be passed to `@support_torch_compile`'s `enable_if` argument.
Defaults to `True` but is disabled in the following situations:
- The model uses dynamic rope scaling.
"""
enable = True
text_config = vllm_config.model_config.hf_config.get_text_config()
# Dynamic rope scaling is not compatible with torch.compile
rope_scaling: dict = getattr(text_config, "rope_scaling", None) or {}
if rope_scaling.get("rope_type") == "dynamic":
enable = False
return enable
Style = Literal["colwise", "colwise_rep", "rowwise", "rowwise_rep", "replicate"]
def replace_linear_class(
linear: nn.Linear,
style: Style = "replicate",
quant_config: QuantizationConfig | None = None,
*,
prefix: str = "",
) -> ColumnParallelLinear | RowParallelLinear | ReplicatedLinear:
"""
Replace nn.Linear with one of vLLM's tensor parallel linear classes.
Args:
linear: `nn.Linear` to be replaced.
style: Tensor parallel style of the new linear, e.g. "colwise".
quant_config: Quantization config for the new linear.
Returns:
The new linear.
"""
if not isinstance(style, str):
raise ValueError(f"Unsupported parallel style type {type(style)}, expected str")
vllm_linear_cls, vllm_linear_kwargs = {
"colwise": (ColumnParallelLinear, {}),
"colwise_rep": (ColumnParallelLinear, {"gather_output": True}),
"rowwise": (RowParallelLinear, {}),
"rowwise_rep": (RowParallelLinear, {"input_is_parallel": False}),
"replicate": (ReplicatedLinear, {}),
}.get(style, (ReplicatedLinear, {}))
return vllm_linear_cls(
input_size=linear.in_features,
output_size=linear.out_features,
bias=linear.bias is not None,
quant_config=quant_config,
prefix=prefix,
return_bias=False,
**vllm_linear_kwargs,
)
def replace_rms_norm_class(rms_norm: nn.Module, hidden_size: int) -> RMSNorm:
"""Replace a Transformers RMSNorm with vLLM's RMSNorm.
This method assumes:
- Weight is stored as `weight`.
- Epsilon is stored as `eps` or `variance_epsilon`.
- `with_scale` indicates whether the layer has a weight (Gemma3n only).
- `var_hidden_size` is only ever used for Intern vision encoder in vLLM
and Transformers doesn't appear to have the same concept.
"""
eps = getattr_iter(rms_norm, ("eps", "variance_epsilon"), 1e-6)
kwargs = {"hidden_size": hidden_size, "eps": eps}
# Update hidden size if weight is available
weight_meta = getattr(rms_norm, "weight", None)
if weight_meta is not None:
kwargs["hidden_size"] = weight_meta.size(0)
# Check if weight is all zeros, which indicates GemmaRMSNorm
# We must create a new instance because rms_norm is on meta
try:
with torch.device("cpu"):
weight_test = getattr(rms_norm.__class__(1), "weight", None)
except Exception:
logger.warning(
"Failed to determine if RMSNorm weight is centered on zero or one. "
"Defaulting to one."
)
weight_test = None
if weight_test is not None and torch.all(weight_test == 0):
return GemmaRMSNorm(**kwargs)
# Otherwise assume it's a regular RMSNorm
kwargs["has_weight"] = getattr(rms_norm, "with_scale", True)
if weight_meta is not None:
kwargs["dtype"] = weight_meta.dtype
else:
# No weight, fall back to weightless RMSNorm
kwargs["has_weight"] = False
return RMSNorm(**kwargs)
# Copied from `accelerate`
@contextmanager
def init_on_device_without_buffers(device: torch.device):
"""
A context manager under which models are initialized with all
parameters on the specified device. However buffers are not
initialized on specified device.
Args:
device (`torch.device`):
Device to initialize all parameters on.
"""
old_register_parameter = nn.Module.register_parameter
def register_empty_parameter(module, name, param):
old_register_parameter(module, name, param)
if param is not None:
param_cls = type(module._parameters[name])
kwargs = module._parameters[name].__dict__
kwargs["requires_grad"] = param.requires_grad
module._parameters[name] = param_cls(
module._parameters[name].to(device), **kwargs
)
tensor_constructors_to_patch = {}
def patch_tensor_constructor(fn):
def wrapper(*args, **kwargs):
kwargs["device"] = device
return fn(*args, **kwargs)
return wrapper
try:
nn.Module.register_parameter = register_empty_parameter
for torch_function_name in tensor_constructors_to_patch:
setattr(
torch,
torch_function_name,
patch_tensor_constructor(getattr(torch, torch_function_name)),
)
yield
finally:
nn.Module.register_parameter = old_register_parameter
for (
torch_function_name,
old_torch_function,
) in tensor_constructors_to_patch.items():
setattr(torch, torch_function_name, old_torch_function)
class MultiModalProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self):
return {"image": None}
def get_mm_max_tokens_per_item(self, seq_len, mm_counts):
return {"image": self.get_max_image_tokens()}
def get_max_image_tokens(self) -> int:
width, height = self.get_max_image_size()
processor = self.get_hf_processor()
multimodal_config = self.ctx.model_config.multimodal_config
mm_processor_kwargs = multimodal_config.mm_processor_kwargs or {}
mm_tokens = processor._get_num_multimodal_tokens(
image_sizes=([height, width],), **mm_processor_kwargs
)
image_tokens = mm_tokens["num_image_tokens"][0]
return image_tokens
def get_max_image_size(self):
return 10_000, 10_000 # hardcode for arbitrary very large size
class MultiModalDummyInputsBuilder(BaseDummyInputsBuilder[MultiModalProcessingInfo]):
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0)
processor = self.info.get_hf_processor()
if "gemma3" in processor.__class__.__name__.lower():
image_token = processor.boi_token
else:
image_token = getattr(processor, "image_token", "")
return image_token * num_images
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions] | None = None,
) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0)
target_width, target_height = self.info.get_max_image_size()
image_overrides = mm_options.get("image") if mm_options else None
return {
"image": self._get_dummy_images(
width=target_width,
height=target_height,
num_images=num_images,
overrides=image_overrides,
),
}
class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargsItems,
):
"""
Given the original multi-modal items for this modality
and HF-processed data, output the updates to perform.
The information returned by this method is used to update token inputs
which bypass the HF processor. It is also used to update the output of
HF processor if the HF process does not apply prompt updates to text
inputs.
Moreover, this information is critical to determine the token positions
in order to construct :class:`~vllm-multimodal.input.PlaceholderRange`
for each multi-modal item.
"""
return None
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
# HF Processors always return a mask but vLLM doesn't need it
hf_inputs.pop("attention_mask", None)
num_image_patches = hf_inputs.get("num_image_patches")
mm_fields = {
key: MultiModalFieldConfig.flat_from_sizes("image", num_image_patches)
for key in hf_inputs
}
mm_fields["image_embeds"] = MultiModalFieldConfig.flat_from_sizes(
"image", num_image_patches
)
# Keep these as batched, as they always have batch size as first dim
mm_fields["image_grid_thw"] = MultiModalFieldConfig.batched("image")
mm_fields["video_grid_thw"] = MultiModalFieldConfig.batched("image")
mm_fields["num_image_patches"] = MultiModalFieldConfig.batched("image")
return mm_fields
def _get_hf_mm_data(
self,
mm_items: MultiModalDataItems,
) -> tuple[Mapping[str, object], Mapping[str, object]]:
"""
In contrast to the base class, this method always adds
`return_mm_token_type_ids` to the processor data
"""
processor_data, passthrough_data = super()._get_hf_mm_data(mm_items)
processor_data["return_mm_token_type_ids"] = True
return processor_data, passthrough_data
def apply(
self,
prompt: str | list[int],
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object] | None = None,
mm_uuids: MultiModalUUIDDict | None = None,
) -> MultiModalInputs:
"""
Process multi-modal inputs to be used in vLLM.
Apply HF Processor on prompt text and multi-modal data together,
outputting token IDs and processed tensors.
"""
if tokenization_kwargs is None:
tokenization_kwargs = {}
mm_items = self._to_mm_items(mm_data)
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
if not isinstance(prompt, str):
# the prompt is the tokenized ids which is not supported
# by the hf_processor, which is why we would need to decode the ids
# into string
prompt = hf_processor.decode(prompt)
# Bypass cached processor and always apply to the full set of mm inputs
# NOTE: we can't just set caching=False because base class method
# transforms outputs to `MultiModalKwargs` which is not going to
# work for Transformers. We have a lot of logic tied to
# `mm_tokens_per_modality` below
prompt_ids, processed_data, _ = self._apply_hf_processor_text_mm(
prompt_text=prompt,
mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
)
# For gemma3 we check `token_type_ids` as the key
token_type_key = (
"mm_token_type_ids"
if "mm_token_type_ids" in processed_data
else "token_type_ids"
)
mm_token_type_ids = processed_data.pop(token_type_key)
# We can infer vLLM style placeholder from token type ids, if we split
# it for each input `mm_data`.
mm_positions = torch.where(mm_token_type_ids == 1)[1]
images = mm_items.get_items("image", ImageProcessorItems)
multimodal_config = self.info.ctx.model_config.multimodal_config
mm_processor_kwargs = multimodal_config.mm_processor_kwargs or {}
image_sizes = []
for item_idx in range(len(images)):
image_size = images.get_image_size(item_idx)
image_sizes.append((image_size.height, image_size.width))
mm_tokens_per_modality = hf_processor._get_num_multimodal_tokens(
image_sizes=image_sizes, **mm_processor_kwargs
)
mm_placeholders = {}
split_sizes = mm_tokens_per_modality["num_image_tokens"]
if split_sizes:
chunked_mm_positions = torch.split(mm_positions, split_sizes)
mm_tokens = torch.tensor(prompt_ids)[mm_token_type_ids[0].bool()]
chunked_mm_tokens = torch.split(mm_tokens, split_sizes)
ranges = [
PlaceholderRange(
offset=positions[0].item(),
length=positions.shape[0],
is_embed=(mm_tokens == hf_processor.image_token_id).bool(),
)
for positions, mm_tokens in zip(chunked_mm_positions, chunked_mm_tokens)
]
mm_placeholders = {"image": ranges}
processed_data["num_image_patches"] = torch.tensor(
mm_tokens_per_modality["num_image_patches"]
)
mm_kwargs = MultiModalKwargsItems.from_hf_inputs(
processed_data,
self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs),
)
# Use overrides if provided; fallback to data-dependent hashing.
mm_hashes = self._hash_mm_items(
mm_items, hf_processor_mm_kwargs, tokenization_kwargs, mm_uuids=mm_uuids
)
return MultiModalInputs(
type="multimodal",
prompt_token_ids=prompt_ids,
mm_kwargs=mm_kwargs,
mm_hashes=mm_hashes,
mm_placeholders=mm_placeholders,
)
class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
embedding_padding_modules = ["lm_head"]
embedding_modules = ["embed_tokens"] # TODO transformers will have a util to get it
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
logger.info("Using Transformers backend.")
self.config: PretrainedConfig = vllm_config.model_config.hf_config
self.text_config: PretrainedConfig = self.config.get_text_config()
self.cache_config: CacheConfig = vllm_config.cache_config
self.device_config: DeviceConfig = vllm_config.device_config
self.model_config: ModelConfig = vllm_config.model_config
self.parallel_config: ParallelConfig = vllm_config.parallel_config
self.quant_config: QuantizationConfig | None = vllm_config.quant_config
self.pp_group = get_pp_group()
self.tp_group = get_tp_group()
# Weights to skip in `self.load_weights`
self.skip_prefixes: list[str] = []
"""Skip loading weights whose qualname starts with these prefixes."""
self.skip_substrs: list[str] = []
"""Skip loading weights whose qualname contains these substrings."""
self.ignore_unexpected_prefixes: list[str] = []
"""Ignore unexpected weights whose qualname starts with these prefixes.
"""
self.ignore_unexpected_suffixes: list[str] = []
"""Ignore unexpected weights whose qualname ends with these suffixes."""
if self.quant_config:
quant_method_name = self.quant_config.get_name()
# Check for unsupported quantization methods.
if quant_method_name == "mxfp4":
raise NotImplementedError(
"Transformers backend does not support MXFP4 quantization yet."
)
# Skip loading extra bias for GPTQ models.
if "gptq" in quant_method_name:
self.ignore_unexpected_suffixes.append(".bias")
# Set correct attn and init on "meta" to delay allocating GPU tensors
self.text_config._attn_implementation = "vllm"
with init_on_device_without_buffers("meta"):
self.model: PreTrainedModel = AutoModel.from_config(
self.config,
dtype=self.model_config.dtype,
trust_remote_code=self.model_config.trust_remote_code,
)
# Remove layers not on this pipeline parallel rank
self.pipeline_parallel()
# Substitute remaining layers with vLLM's layers as needed
self.recursive_replace()
# Create attention instances for KV cache allocation
self.attention_instances = self.create_attention_instances()
# Input embeddings
input_embeddings = self.model.get_input_embeddings()
if not isinstance(input_embeddings, PPMissingLayer):
# Some models use embedding scales
self.embed_scale = getattr(input_embeddings, "embed_scale", None)
names = ("embedding_size", "hidden_size")
embedding_dim = getattr_iter(self.text_config, names, None)
assert embedding_dim is not None
self.model.set_input_embeddings(
VocabParallelEmbedding(
self.text_config.vocab_size,
embedding_dim=embedding_dim,
org_num_embeddings=self.text_config.vocab_size,
quant_config=self.quant_config,
)
)
# Initialize any parameters that have not had their modules replaced
self.init_parameters(self.model)
# Pipeline parallel intermediate tensors
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
["hidden_states"], self.text_config.hidden_size
)
def pipeline_parallel(self):
"""
Apply the model's pipeline parallelization plan.
"""
if self.pp_group.world_size <= 1:
return
if not self.model.supports_pp_plan:
tip = get_feature_request_tip(
self.model_config.model, self.model_config.trust_remote_code
)
raise ValueError(
f"{type(self.model)} does not support pipeline parallel. {tip}"
)
module_lists = []
module_list_idx = None
pp_plan = list(self.model._pp_plan.keys())
for i, name in enumerate(pp_plan):
if isinstance(getattr(self.model, name), nn.ModuleList):
module_lists.append(name)
module_list_idx = i
if len(module_lists) > 1:
raise ValueError(
"Pipeline parallel of models with multiple `ModuleList`s "
"in the base model are not supported yet!"
)
if module_list_idx is None:
raise ValueError(f"Could not find `ModuleList` in {type(self.model)}")
# Layers before module list
for name in pp_plan[:module_list_idx]:
if self.pp_group.is_first_rank or (
self.text_config.tie_word_embeddings and self.pp_group.is_last_rank
):
continue
setattr(self.model, name, PPMissingLayer())
# Module list
start_layer, end_layer = get_pp_indices(
self.text_config.num_hidden_layers,
self.pp_group.rank_in_group,
self.pp_group.world_size,
)
layers_name = pp_plan[module_list_idx]
layers = getattr(self.model, layers_name)
for i in range(len(layers)):
if start_layer <= i and i < end_layer:
continue
layers[i] = PPMissingLayer()
# Layers after module list
for name in pp_plan[module_list_idx + 1 :]:
# Modules that should be on last rank
if not self.pp_group.is_last_rank:
setattr(self.model, name, PPMissingLayer())
def recursive_replace(self):
"""Recursively replace modules in the model as needed.
Currently, this replaces:
- `nn.Linear` with vLLM's tensor parallel linear classes
- `*RMSNorm` with vLLM's `RMSNorm`
"""
tp_plan = self.model.tp_plan
if not tp_plan and self.tp_group.world_size > 1:
tip = get_feature_request_tip(
self.model_config.model, self.model_config.trust_remote_code
)
raise ValueError(
f"{type(self.model)} does not support tensor parallel. {tip}"
)
# Prefix the patterns because we always start from `self.model`
tp_plan = {maybe_prefix("model", k): v for k, v in tp_plan.items()}
def _recursive_replace(module: nn.Module, prefix: str):
for child_name, child_module in module.named_children():
new_module = child_module
qual_name = maybe_prefix(prefix, child_name)
if isinstance(child_module, nn.Linear):
generator = (p for p in tp_plan if re.match(p, qual_name))
pattern = next(generator, None)
# Some weight loaders expect all linear layers to inherit
# LinearBase, so we set a default style which causes any
# unspecified layers to be replaced with ReplicatedLinear
style = tp_plan.get(pattern, "replicate")
new_module = replace_linear_class(
child_module, style, self.quant_config, prefix=qual_name
)
elif child_module.__class__.__name__.endswith("RMSNorm"):
new_module = replace_rms_norm_class(
child_module, self.text_config.hidden_size
)
else:
_recursive_replace(child_module, prefix=qual_name)
if new_module is not child_module:
setattr(module, child_name, new_module)
log_replacement(qual_name, child_module, new_module)
_recursive_replace(self.model, prefix="model")
def create_attention_instances(
self, attn_type: AttentionType = AttentionType.DECODER
) -> dict[int, Attention]:
"""
Create `Attention` instances to inform KV cache allocation.
"""
num_heads = self.model_config.get_num_attention_heads(self.parallel_config)
head_size = self.model_config.get_head_size()
num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
logits_soft_cap = getattr(self.text_config, "attn_logit_softcapping", None)
start, end = get_pp_indices(
self.text_config.num_hidden_layers,
self.pp_group.rank_in_group,
self.pp_group.world_size,
)
attention_instances = {}
for i in range(start, end):
# Handle interleaved sliding window attention
per_layer_sliding_window = None
if (
hasattr(self.config, "layer_types")
and self.config.layer_types[i] == "sliding_attention"
):
per_layer_sliding_window = self.config.sliding_window
attention_instances[i] = Attention(
num_heads=num_heads,
head_size=head_size,
# NOTE: We use Llama scale as default, if it's set by
# Transformers, it's updated in vllm_flash_attention_forward
scale=head_size**-0.5,
num_kv_heads=num_kv_heads,
cache_config=self.cache_config,
quant_config=self.quant_config,
logits_soft_cap=logits_soft_cap,
per_layer_sliding_window=per_layer_sliding_window,
prefix=f"{i}.attn",
attn_type=attn_type,
)
return attention_instances
def init_parameters(self, module: nn.Module, dtype: torch.dtype | None = None):
"""
If a `parameter` is on the `meta` device, then its parent
`module` is the original module created by:
```python
with torch.device("meta"):
self.model: PreTrainedModel = AutoModel.from_config(...)
```
"""
def _init_parameters(module: nn.Module, dtype: torch.dtype | None):
for name, param in module.named_parameters(recurse=False):
if param.device == torch.device("meta"):
new_param = nn.Parameter(
torch.empty_like(
param.data,
dtype=dtype or self.model_config.dtype,
device=self.device_config.device,
)
)
setattr(module, name, new_param)
for child in module.children():
_init_parameters(child, dtype)
_init_parameters(module, dtype)
def forward(
self,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
**kwargs,
) -> torch.Tensor | IntermediateTensors:
if not self.pp_group.is_first_rank:
assert intermediate_tensors is not None
input_ids = None
inputs_embeds = intermediate_tensors["hidden_states"]
if input_ids is not None:
input_ids = input_ids[None, ...]
if inputs_embeds is not None:
inputs_embeds = inputs_embeds[None, ...]
if self.model_config.uses_mrope:
position_ids = positions[:, None]
else:
position_ids = positions[None, ...]
hidden_states = self.model(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
use_cache=False,
position_ids=position_ids,
attention_instances=self.attention_instances,
return_dict=False,
**kwargs,
)[0][0, ...] # we remove batch dimension for now
if not self.pp_group.is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
return hidden_states
def load_weights(
self,
weights: Iterable[tuple[str, torch.Tensor]],
) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=self.skip_prefixes,
skip_substrs=self.skip_substrs,
ignore_unexpected_prefixes=self.ignore_unexpected_prefixes,
ignore_unexpected_suffixes=self.ignore_unexpected_suffixes,
)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
def check_version(self, min_version: str, feature: str):
installed = Version(transformers.__version__)
required = Version(min_version)
if installed < required:
raise ImportError(
f"Transformers backend requires transformers>={required} "
f"for {feature}, but got {installed}"
)
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersForCausalLM(TransformersBase):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
# Tell `TransformersBase.load_weights` to skip
# `lm_head` if the model has tied word embeddings
if self.text_config.tie_word_embeddings:
self.skip_prefixes.append("lm_head.")
if self.pp_group.is_last_rank:
self.unpadded_vocab_size = self.text_config.vocab_size
self.lm_head = ParallelLMHead(
self.text_config.vocab_size,
self.text_config.hidden_size,
quant_config=self.quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
if self.text_config.tie_word_embeddings:
self.lm_head = self.lm_head.tie_weights(
self.model.get_input_embeddings()
)
logit_scale = getattr(self.text_config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(
self.unpadded_vocab_size, self.text_config.vocab_size, logit_scale
)
else:
self.lm_head = PPMissingLayer()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
inputs_embeds = self.model.get_input_embeddings()(input_ids)
if self.embed_scale is not None:
inputs_embeds *= self.embed_scale
return inputs_embeds
def compute_logits(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor | None:
logits = self.logits_processor(self.lm_head, hidden_states)
return logits
@MULTIMODAL_REGISTRY.register_processor(
MultiModalProcessor,
info=MultiModalProcessingInfo,
dummy_inputs=MultiModalDummyInputsBuilder,
)
@support_torch_compile(
# set `positions` to last dim to support Qwen-mrope
dynamic_arg_dims={
"input_ids": 0,
"positions": -1,
"intermediate_tensors": 0,
"inputs_embeds": 0,
},
enable_if=can_enable_torch_compile,
)
class TransformersForMultimodalLM(TransformersForCausalLM, SupportsMultiModal):
supports_multimodal_raw_input_only = True
merge_by_field_config = True
# Backwards compatibility for prev released models. State dicts back then
# had different formats and cannot be loaded with `AutoModel` mapping as is
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
"language_model.model": "model.language_model",
"text_model.model": "model.text_model",
"vision_tower": "model.vision_tower",
"vqmodel": "model.vqmodel",
"visual": "model.visual",
"vision_model": "model.vision_model",
"vision_embed_tokens": "model.vision_embed_tokens",
"image_newline": "model.image_newline",
"multi_modal_projector": "model.multi_modal_projector",
"text_model.lm_head": "lm_head",
"language_model.lm_head": "lm_head",
# Qwen models used "model" as the name for the language model.
# Therefore, we must map each of submodule explicitly to avoid
# conflicts with newer models that use "model.language_model".
"model.embed_tokens": "model.language_model.embed_tokens",
"model.layers": "model.language_model.layers",
"model.norm": "model.language_model.norm",
}
)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
self.dtype = vllm_config.model_config.dtype
def forward(
self,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
**kwargs: object,
) -> torch.Tensor | IntermediateTensors:
# Gemma3 and PaliGemma needs `token_type_ids` to work correctly
# Other models will not have `token_type_ids` in kwargs
kwargs = {k: v for k, v in kwargs.items() if k == "token_type_ids"}
model_output = super().forward(
input_ids, positions, intermediate_tensors, inputs_embeds, **kwargs
)
return model_output
def get_language_model(self) -> torch.nn.Module:
"""`TransformersForMultimodalLM` does not contain a vLLM language model class.
Therefore, in order to return a language model vLLM class, we use a wrapper to
give `self` the same interface as `TransformersForCausalLM`."""
class LanguageModelWrapper(TransformersForCausalLM):
def __init__(self, multimodal_model):
# Don't call super().__init__() to avoid re-initialization
self.__dict__.update(multimodal_model.__dict__)
model = getattr_iter(self.model, ("language_model", "text_model"), None)
return LanguageModelWrapper(self)
def get_multimodal_embeddings(self, **kwargs):
pixel_values: torch.Tensor | None = kwargs.pop("pixel_values", None)
image_embeds: torch.Tensor | None = kwargs.pop("image_embeds", None)
# Model might use `image_patches` instead of `pixel_values`
if pixel_values is None:
pixel_values = kwargs.pop("image_patches", None)
if image_embeds is not None:
return image_embeds
if pixel_values is None:
return None
num_image_patches = kwargs.pop("num_image_patches")
kwargs.pop("token_type_ids", None) # used only in `forward`
if pixel_values is not None:
vision_embeddings = self.model.get_image_features(pixel_values, **kwargs)
if isinstance(vision_embeddings, torch.Tensor):
if vision_embeddings.ndim == 2:
vision_embeddings = vision_embeddings.unsqueeze(0)
# Embeddings have to be 2D tensors of length `num_images`
# but transformers returns concat tensors if each patch
# is of different size. We split it back to make vLLM happy
vision_embeddings = torch.split(
vision_embeddings, num_image_patches.flatten().tolist()
)
vision_embeddings = [
embed.flatten(start_dim=0, end_dim=-2)
for embed in vision_embeddings
]
return vision_embeddings
get_input_embeddings = SupportsMultiModal.get_input_embeddings
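# Illustrative summary (derived from the registry and test changes in this
# commit): the monolithic transformers.py deleted above is split into the
# transformers/ package, and the multimodal classes are renamed. The old names
# are not kept as aliases; the package __init__ raises AttributeError for
# unknown class names.
_RENAMES = {
    "TransformersForMultimodalLM": "TransformersMultiModalForCausalLM",
    "TransformersMoEForMultimodalLM": "TransformersMultiModalMoEForCausalLM",
}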

View File

@ -0,0 +1,127 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2024 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wrapper around `transformers` models"""
from vllm.compilation.decorators import support_torch_compile
from vllm.model_executor.models.transformers.base import Base
from vllm.model_executor.models.transformers.causal import CausalMixin
from vllm.model_executor.models.transformers.legacy import LegacyMixin
from vllm.model_executor.models.transformers.moe import MoEMixin
from vllm.model_executor.models.transformers.multimodal import (
DYNAMIC_ARG_DIMS,
MultiModalDummyInputsBuilder,
MultiModalMixin,
MultiModalProcessingInfo,
MultiModalProcessor,
)
from vllm.model_executor.models.transformers.pooling import (
EmbeddingMixin,
SequenceClassificationMixin,
)
from vllm.model_executor.models.transformers.utils import can_enable_torch_compile
from vllm.multimodal import MULTIMODAL_REGISTRY
# Text only models
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersForCausalLM(CausalMixin, Base): ...
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersMoEForCausalLM(MoEMixin, CausalMixin, Base): ...
# Multimodal models
@MULTIMODAL_REGISTRY.register_processor(
MultiModalProcessor,
info=MultiModalProcessingInfo,
dummy_inputs=MultiModalDummyInputsBuilder,
)
@support_torch_compile(
dynamic_arg_dims=DYNAMIC_ARG_DIMS, enable_if=can_enable_torch_compile
)
class TransformersMultiModalForCausalLM(MultiModalMixin, CausalMixin, Base): ...
@MULTIMODAL_REGISTRY.register_processor(
MultiModalProcessor,
info=MultiModalProcessingInfo,
dummy_inputs=MultiModalDummyInputsBuilder,
)
@support_torch_compile(
dynamic_arg_dims=DYNAMIC_ARG_DIMS, enable_if=can_enable_torch_compile
)
class TransformersMultiModalMoEForCausalLM(
MoEMixin, MultiModalMixin, CausalMixin, Base
): ...
# Embedding models
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersEmbeddingModel(EmbeddingMixin, LegacyMixin, Base): ...
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersMoEEmbeddingModel(EmbeddingMixin, MoEMixin, Base): ...
@MULTIMODAL_REGISTRY.register_processor(
MultiModalProcessor,
info=MultiModalProcessingInfo,
dummy_inputs=MultiModalDummyInputsBuilder,
)
@support_torch_compile(
dynamic_arg_dims=DYNAMIC_ARG_DIMS, enable_if=can_enable_torch_compile
)
class TransformersMultiModalEmbeddingModel(EmbeddingMixin, MultiModalMixin, Base): ...
# Sequence classification models
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersForSequenceClassification(
SequenceClassificationMixin, LegacyMixin, Base
): ...
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersMoEForSequenceClassification(
SequenceClassificationMixin, MoEMixin, Base
): ...
@MULTIMODAL_REGISTRY.register_processor(
MultiModalProcessor,
info=MultiModalProcessingInfo,
dummy_inputs=MultiModalDummyInputsBuilder,
)
@support_torch_compile(
dynamic_arg_dims=DYNAMIC_ARG_DIMS, enable_if=can_enable_torch_compile
)
class TransformersMultiModalForSequenceClassification(
SequenceClassificationMixin, MultiModalMixin, Base
): ...
def __getattr__(name: str):
"""Handle imports of non-existent classes with a helpful error message."""
if name not in globals():
raise AttributeError(
"The Transformers backend does not currently have a class to handle "
f"the requested model type: {name}. Please open an issue at "
"https://github.com/vllm-project/vllm/issues/new"
)
return globals()[name]
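# Illustrative sketch (toy classes, not the real mixins): the classes defined
# above rely on Python's MRO, so cooperative __init__ calls run each mixin
# once, left to right, before reaching Base.
class _Base:
    def __init__(self, **kwargs):
        print("Base")

class _CausalMixin:
    def __init__(self, **kwargs):
        print("Causal")
        super().__init__(**kwargs)

class _MoEMixin:
    def __init__(self, **kwargs):
        print("MoE")
        super().__init__(**kwargs)

class _MoEForCausalLM(_MoEMixin, _CausalMixin, _Base): ...

_MoEForCausalLM()  # prints: MoE, Causal, Base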

View File

@ -0,0 +1,435 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2024 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformers backend base class."""
from collections.abc import Iterable
from typing import TYPE_CHECKING
import regex as re
import torch
import transformers
from packaging.version import Version
from torch import nn
from transformers import AutoModel
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from vllm.attention import Attention, AttentionType
from vllm.config.utils import getattr_iter
from vllm.distributed import get_pp_group, get_tp_group
from vllm.distributed.utils import get_pp_indices
from vllm.logger import init_logger
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.models.interfaces import (
SupportsLoRA,
SupportsPP,
SupportsQuant,
)
from vllm.model_executor.models.interfaces_base import VllmModel
from vllm.model_executor.models.transformers.utils import (
get_feature_request_tip,
init_on_device_without_buffers,
log_replacement,
replace_linear_class,
replace_rms_norm_class,
)
from vllm.model_executor.models.utils import (
AutoWeightsLoader,
PPMissingLayer,
make_empty_intermediate_tensors_factory,
maybe_prefix,
)
from vllm.sequence import IntermediateTensors
if TYPE_CHECKING:
from transformers import PreTrainedModel
from vllm.config import VllmConfig
else:
PreTrainedModel = object
logger = init_logger(__name__)
def vllm_flash_attention_forward(
# Transformers args
module: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: torch.Tensor,
# Transformers kwargs
scaling: float | None = None,
# vLLM kwargs
attention_instances: dict[int, Attention] | None = None,
**kwargs,
):
self_attn = attention_instances[module.layer_idx]
if scaling is not None:
self_attn.impl.scale = float(scaling)
hidden = query.shape[-2]
query, key, value = (x.transpose(1, 2) for x in (query, key, value))
query, key, value = (x.reshape(hidden, -1) for x in (query, key, value))
return self_attn.forward(query, key, value), None
ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward
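# Illustrative shape walk-through (standalone snippet with assumed sizes):
# Transformers hands vllm_flash_attention_forward q/k/v shaped
# (batch=1, num_heads, seq_len, head_dim), while vLLM's Attention.forward
# expects (num_tokens, num_heads * head_dim); the transpose + reshape above
# performs exactly this flattening.
import torch

q = torch.randn(1, 8, 16, 64)  # (batch, num_heads, seq_len, head_dim)
hidden = q.shape[-2]           # seq_len == 16
q = q.transpose(1, 2)          # (1, 16, 8, 64)
q = q.reshape(hidden, -1)      # (16, 8 * 64)
assert q.shape == (16, 512)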
class Base(nn.Module, VllmModel, SupportsQuant, SupportsLoRA, SupportsPP):
embedding_padding_modules = ["lm_head"]
embedding_modules = ["embed_tokens"] # TODO transformers will have a util to get it
def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""):
super().__init__()
logger.info("Using Transformers backend.")
self.config = vllm_config.model_config.hf_config
self.text_config = self.config.get_text_config()
self.cache_config = vllm_config.cache_config
self.device_config = vllm_config.device_config
self.model_config = vllm_config.model_config
self.parallel_config = vllm_config.parallel_config
self.quant_config = vllm_config.quant_config
self.pp_group = get_pp_group()
self.tp_group = get_tp_group()
# Weights to skip in `self.load_weights`
self.skip_prefixes: list[str] = []
"""Skip loading weights whose qualname starts with these prefixes."""
self.skip_substrs: list[str] = []
"""Skip loading weights whose qualname contains these substrings."""
self.ignore_unexpected_prefixes: list[str] = []
"""Ignore unexpected weights whose qualname starts with these prefixes.
"""
self.ignore_unexpected_suffixes: list[str] = []
"""Ignore unexpected weights whose qualname ends with these suffixes."""
if self.quant_config:
quant_method_name = self.quant_config.get_name()
# Check for unsupported quantization methods.
if quant_method_name == "mxfp4":
raise NotImplementedError(
"Transformers backend does not support MXFP4 quantization yet."
)
# Skip loading extra bias for GPTQ models.
if "gptq" in quant_method_name:
self.ignore_unexpected_suffixes.append(".bias")
# Set correct attn and init on "meta" to delay allocating GPU tensors
self.text_config._attn_implementation = "vllm"
with init_on_device_without_buffers("meta"):
self.model: PreTrainedModel = AutoModel.from_config(
self.config,
dtype=self.model_config.dtype,
trust_remote_code=self.model_config.trust_remote_code,
)
# Remove layers not on this pipeline parallel rank
self.pipeline_parallel()
# Substitute remaining layers with vLLM's layers as needed
self.recursive_replace()
# Create attention instances for KV cache allocation
self.attention_instances = self.create_attention_instances()
# Input embeddings
input_embeddings = self.model.get_input_embeddings()
if not isinstance(input_embeddings, PPMissingLayer):
# Some models scale embeddings inside the input embedding layer
self.embed_scale = getattr(input_embeddings, "embed_scale", None)
names = ("embedding_size", "hidden_size")
embedding_dim = getattr_iter(self.text_config, names, None)
assert embedding_dim is not None
self.model.set_input_embeddings(
VocabParallelEmbedding(
self.text_config.vocab_size,
embedding_dim=embedding_dim,
org_num_embeddings=self.text_config.vocab_size,
quant_config=self.quant_config,
)
)
# Initialize any parameters that have not had their modules replaced
self.init_parameters(self.model)
# Pipeline parallel intermediate tensors
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
["hidden_states"], self.text_config.hidden_size
)
def pipeline_parallel(self):
"""
Apply the model's pipeline parallelization plan.
"""
if self.pp_group.world_size <= 1:
return
if not self.model.supports_pp_plan:
tip = get_feature_request_tip(
self.model_config.model, self.model_config.trust_remote_code
)
raise ValueError(
f"{type(self.model)} does not support pipeline parallel. {tip}"
)
module_lists = []
module_list_idx = None
pp_plan = list(self.model._pp_plan.keys())
for i, name in enumerate(pp_plan):
if isinstance(getattr(self.model, name), nn.ModuleList):
module_lists.append(name)
module_list_idx = i
if len(module_lists) > 1:
raise ValueError(
"Pipeline parallel of models with multiple `ModuleList`s "
"in the base model are not supported yet!"
)
if module_list_idx is None:
raise ValueError(f"Could not find `ModuleList` in {type(self.model)}")
# Layers before module list
for name in pp_plan[:module_list_idx]:
if self.pp_group.is_first_rank or (
self.text_config.tie_word_embeddings and self.pp_group.is_last_rank
):
continue
setattr(self.model, name, PPMissingLayer())
# Module list
start_layer, end_layer = get_pp_indices(
self.text_config.num_hidden_layers,
self.pp_group.rank_in_group,
self.pp_group.world_size,
)
layers_name = pp_plan[module_list_idx]
layers = getattr(self.model, layers_name)
for i in range(len(layers)):
if start_layer <= i and i < end_layer:
continue
layers[i] = PPMissingLayer()
# Layers after module list
for name in pp_plan[module_list_idx + 1 :]:
# Modules that should be on last rank
if not self.pp_group.is_last_rank:
setattr(self.model, name, PPMissingLayer())
def recursive_replace(self):
"""Recursively replace modules in the model as needed.
Currently, this replaces:
- `nn.Linear` with vLLM's tensor parallel linear classes
- `*RMSNorm` with vLLM's `RMSNorm`
"""
tp_plan = self.model.tp_plan
if not tp_plan and self.tp_group.world_size > 1:
tip = get_feature_request_tip(
self.model_config.model, self.model_config.trust_remote_code
)
raise ValueError(
f"{type(self.model)} does not support tensor parallel. {tip}"
)
# Prefix the patterns because we always start from `self.model`
tp_plan = {maybe_prefix("model", k): v for k, v in tp_plan.items()}
def _recursive_replace(module: nn.Module, prefix: str):
for child_name, child_module in module.named_children():
new_module = child_module
qual_name = maybe_prefix(prefix, child_name)
if isinstance(child_module, nn.Linear):
generator = (p for p in tp_plan if re.match(p, qual_name))
pattern = next(generator, None)
# Some weight loaders expect all linear layers to inherit
# LinearBase, so we set a default style which causes any
# unspecified layers to be replaced with ReplicatedLinear
style = tp_plan.get(pattern, "replicate")
new_module = replace_linear_class(
child_module, style, self.quant_config, prefix=qual_name
)
elif child_module.__class__.__name__.endswith("RMSNorm"):
new_module = replace_rms_norm_class(
child_module, self.text_config.hidden_size
)
else:
_recursive_replace(child_module, prefix=qual_name)
if new_module is not child_module:
setattr(module, child_name, new_module)
log_replacement(qual_name, child_module, new_module)
_recursive_replace(self.model, prefix="model")
def create_attention_instances(self) -> dict[int, Attention]:
"""
Create `Attention` instances to inform KV cache allocation.
"""
text_config = self.text_config
num_heads = self.model_config.get_num_attention_heads(self.parallel_config)
head_size = self.model_config.get_head_size()
num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
logits_soft_cap = getattr(text_config, "attn_logit_softcapping", None)
# In encoder models, the attention layers will have `is_causal=False`
is_encoder = lambda module: not getattr(module, "is_causal", True)
has_encoder = lambda model: any(is_encoder(m) for m in model.modules())
is_multimodal = lambda config: config != config.get_text_config()
# vLLM does not support encoder-decoder models, so if any encoder layer is
# found in a text-only model, we assume the whole model is an encoder model
if has_encoder(self.model) and not is_multimodal(self.config):
self.check_version("4.57.0.dev0", "encoder models support")
attn_type = AttentionType.ENCODER_ONLY
else:
attn_type = AttentionType.DECODER
pp_rank = self.pp_group.rank_in_group
pp_size = self.pp_group.world_size
start, end = get_pp_indices(text_config.num_hidden_layers, pp_rank, pp_size)
attention_instances = {}
for i in range(start, end):
# Handle interleaved sliding window attention
per_layer_sliding_window = None
if (
hasattr(self.config, "layer_types")
and self.config.layer_types[i] == "sliding_attention"
):
per_layer_sliding_window = self.config.sliding_window
attention_instances[i] = Attention(
num_heads=num_heads,
head_size=head_size,
# NOTE: We use the Llama scale as default; if Transformers sets one,
# it's updated in vllm_flash_attention_forward
scale=head_size**-0.5,
num_kv_heads=num_kv_heads,
cache_config=self.cache_config,
quant_config=self.quant_config,
logits_soft_cap=logits_soft_cap,
per_layer_sliding_window=per_layer_sliding_window,
prefix=f"{i}.attn",
attn_type=attn_type,
)
return attention_instances
def init_parameters(self, module: nn.Module, dtype: torch.dtype | None = None):
"""
If a `parameter` is on the `meta` device, then its parent
`module` is the original module created by:
```python
with torch.device("meta"):
self.model: "PreTrainedModel" = AutoModel.from_config(...)
```
"""
def _init_parameters(module: nn.Module, dtype: torch.dtype | None):
for name, param in module.named_parameters(recurse=False):
if param.device == torch.device("meta"):
new_param = nn.Parameter(
torch.empty_like(
param.data,
dtype=dtype or self.model_config.dtype,
device=self.device_config.device,
)
)
setattr(module, name, new_param)
for child in module.children():
_init_parameters(child, dtype)
_init_parameters(module, dtype)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
inputs_embeds = self.model.get_input_embeddings()(input_ids)
if self.embed_scale is not None:
inputs_embeds *= self.embed_scale
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
**kwargs,
) -> torch.Tensor | IntermediateTensors:
if not self.pp_group.is_first_rank:
assert intermediate_tensors is not None
input_ids = None
inputs_embeds = intermediate_tensors["hidden_states"]
if input_ids is not None:
input_ids = input_ids[None, ...]
if inputs_embeds is not None:
inputs_embeds = inputs_embeds[None, ...]
# If the model scales embeddings inside the input embedding layer we must
# ensure they are scaled here since VocabParallelEmbedding will not do it
if (
self.embed_scale is not None
and input_ids is not None
and inputs_embeds is None
):
inputs_embeds = self.get_input_embeddings(input_ids)
input_ids = None
if self.model_config.uses_mrope:
position_ids = positions[:, None]
else:
position_ids = positions[None, ...]
hidden_states = self.model(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
use_cache=False,
position_ids=position_ids,
attention_instances=self.attention_instances,
return_dict=False,
**kwargs,
)[0][0, ...] # we remove batch dimension for now
if not self.pp_group.is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
return hidden_states
def load_weights(
self,
weights: Iterable[tuple[str, torch.Tensor]],
) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=self.skip_prefixes,
skip_substrs=self.skip_substrs,
ignore_unexpected_prefixes=self.ignore_unexpected_prefixes,
ignore_unexpected_suffixes=self.ignore_unexpected_suffixes,
)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
@staticmethod
def check_version(min_version: str, feature: str):
installed = Version(transformers.__version__)
required = Version(min_version)
if installed < required:
raise ImportError(
f"Transformers backend requires transformers>={required} "
f"for {feature}, but got {installed}"
)
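# Illustrative toy example (standalone snippet, hypothetical module) of the
# meta-device pattern used by Base: the model is built on "meta" so no
# accelerator memory is allocated up front, then init_parameters materializes
# any parameter still on "meta" as an empty tensor on the target device.
import torch
from torch import nn

with torch.device("meta"):
    layer = nn.Linear(4, 4)  # parameters allocated on "meta"
assert layer.weight.device.type == "meta"

layer.weight = nn.Parameter(
    torch.empty_like(layer.weight, device="cpu")  # materialize for weight loading
)
assert layer.weight.device.type == "cpu"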

View File

@ -0,0 +1,66 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2024 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformers backend mixin for causal language models."""
from typing import TYPE_CHECKING
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.models.interfaces_base import VllmModelForTextGeneration
from vllm.model_executor.models.utils import PPMissingLayer, maybe_prefix
if TYPE_CHECKING:
import torch
from vllm.config import VllmConfig
class CausalMixin(VllmModelForTextGeneration):
def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""):
# Skip VllmModelForTextGeneration.__init__ and call the next class in MRO
super(VllmModelForTextGeneration, self).__init__(
vllm_config=vllm_config, prefix=prefix
)
# Tell `Base.load_weights` to skip
# `lm_head` if the model has tied word embeddings
if self.text_config.tie_word_embeddings:
self.skip_prefixes.append("lm_head.")
if self.pp_group.is_last_rank:
self.unpadded_vocab_size = self.text_config.vocab_size
self.lm_head = ParallelLMHead(
self.text_config.vocab_size,
self.text_config.hidden_size,
quant_config=self.quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
if self.text_config.tie_word_embeddings:
self.lm_head = self.lm_head.tie_weights(
self.model.get_input_embeddings()
)
logit_scale = getattr(self.text_config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(
self.unpadded_vocab_size, self.text_config.vocab_size, logit_scale
)
else:
self.lm_head = PPMissingLayer()
def compute_logits(self, hidden_states: "torch.Tensor") -> "torch.Tensor | None":
logits = self.logits_processor(self.lm_head, hidden_states)
return logits
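A self-contained sketch of the MRO trick used in `CausalMixin.__init__` above: `super(SomeClass, self).__init__()` skips `SomeClass` and continues with the next class in the MRO. The class names below are illustrative stand-ins, not the real vLLM classes.

class Interface:
    """Stands in for an interface class whose __init__ must be skipped."""
    def __init__(self, **kwargs):
        raise TypeError("Interface.__init__ should be skipped")

class Base:
    """Stands in for the Transformers backend Base class."""
    def __init__(self, **kwargs):
        self.initialised = True

class Mixin(Interface):
    def __init__(self, **kwargs):
        # Skip Interface.__init__ and call the next class in the MRO
        super(Interface, self).__init__(**kwargs)

class Model(Mixin, Base):
    pass

assert Model().initialised  # Base.__init__ ran; Interface.__init__ never did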

View File

@ -0,0 +1,97 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2024 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformers backend mixin for legacy models."""
from typing import TYPE_CHECKING
import torch
from vllm.model_executor.models.utils import WeightsMapper
from vllm.sequence import IntermediateTensors
if TYPE_CHECKING:
from vllm.config import VllmConfig
class LegacyMixin:
hf_to_vllm_mapper = WeightsMapper(
# These are applied in order, so the order matters!
orig_to_new_prefix={
# Handle BERT-like models
"roberta": "model",
"bert": "model",
# Add `model.` prefix for base model checkpoints
"": "model.",
# Remove `model.` prefix if it was already there
"model.model.": "model.",
# Classifier/scoring heads will be adjacent to `model`
"model.score": "classifier",
"model.classifier": "classifier",
},
orig_to_new_suffix={
# Replace legacy suffixes used for norms
".gamma": ".weight",
".beta": ".bias",
},
)
def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
# Skip unsupported/unwanted output embeddings layers
self.skip_prefixes.extend(
[
"model.lm_head.",
"model.predictions.",
"model.qa_outputs.",
"model.embeddings_project.",
"model.discriminator_predictions.",
]
)
# Some encoder models have the position_ids buffer in the checkpoint.
# vLLM will always pass position_ids as an argument, so we skip loading
# the buffer if it exists
self.skip_substrs.append("position_ids")
# Some encoder models have the bias of the final classifier layer
# in the checkpoint. vLLM does not use this bias, so we skip loading
# it if it exists
self.skip_substrs.append("score.bias")
# RoBERTa-like models have extra padding in positions.
# FIXME(Isotr0py): This is quite hacky for roberta edge case,
# we should find a better way to handle this.
self.is_roberta = "roberta" in self.text_config.model_type
self.padding_idx = self.text_config.pad_token_id
def forward(
self,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
) -> torch.Tensor | IntermediateTensors:
if self.is_roberta:
# RoBERTa-specific positions padding
positions += self.padding_idx + 1
return super().forward(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)
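A hedged sketch of how the ordered prefix mapping above is expected to rewrite legacy checkpoint names; this is plain-dict logic for illustration, not the real `WeightsMapper` implementation.

def map_legacy_name(name: str) -> str:
    prefix_map = {
        "roberta": "model",
        "bert": "model",
        "": "model.",              # add `model.` prefix to everything
        "model.model.": "model.",  # undo the double prefix if it was added
        "model.score": "classifier",
        "model.classifier": "classifier",
    }
    # Applied in order, so the order matters!
    for old, new in prefix_map.items():
        if name.startswith(old):
            name = new + name[len(old):]
    return name

assert map_legacy_name("bert.encoder.layer.0.output.dense.weight") == (
    "model.encoder.layer.0.output.dense.weight"
)
assert map_legacy_name("classifier.weight") == "classifier.weight"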

View File

@ -14,31 +14,27 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wrapper around `transformers` MoE models."""
"""Transformers backend mixin for Mixture of Experts (MoE) models."""
from typing import Any
from typing import TYPE_CHECKING, Any
import torch
import torch.nn as nn
from vllm.compilation.decorators import support_torch_compile
from vllm.config.utils import getattr_iter
from vllm.distributed import get_dp_group, get_ep_group
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.models.interfaces import MixtureOfExperts
from vllm.model_executor.models.utils import maybe_prefix
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
from .interfaces import MixtureOfExperts, SupportsMultiModal
from .transformers import (
TransformersBase,
TransformersForCausalLM,
TransformersForMultimodalLM,
can_enable_torch_compile,
log_replacement,
)
from .utils import maybe_prefix
from .utils import log_replacement
if TYPE_CHECKING:
from vllm.config import VllmConfig
@CustomOp.register("transformers_fused_moe")
@ -117,11 +113,11 @@ direct_register_custom_op(
)
class TransformersMoEBase(TransformersBase, MixtureOfExperts):
def __init__(self, *, vllm_config, prefix=""):
class MoEMixin(MixtureOfExperts):
def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""):
self.check_version("4.57.0.dev0", "MoE models support")
self.ep_group = get_ep_group()
super().__init__(vllm_config=vllm_config, prefix=prefix)
# Skip MixtureOfExperts.__init__ and call the next class in MRO
super(MixtureOfExperts, self).__init__(vllm_config=vllm_config, prefix=prefix)
def set_eplb_state(
self,
@ -242,7 +238,7 @@ class TransformersMoEBase(TransformersBase, MixtureOfExperts):
num_redundant_experts = self.parallel_config.eplb_config.num_redundant_experts
# MixtureOfExperts mixin settings
ep_size = self.ep_group.world_size
ep_size = get_ep_group().world_size
self.mlp_layers = [] # Used for MixtureOfExperts methods
self.expert_weights = []
@ -316,24 +312,5 @@ class TransformersMoEBase(TransformersBase, MixtureOfExperts):
_recursive_replace(child_module, prefix=qual_name)
_recursive_replace(self.model, prefix="model")
# Continue with the replacement of layers in TransformersBase
# Continue with the replacement of layers in Base
super().recursive_replace()
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersMoEForCausalLM(TransformersMoEBase, TransformersForCausalLM):
pass
@support_torch_compile(
# set `positions` to last dim to support Qwen-mrope
dynamic_arg_dims={
"input_ids": 0,
"positions": -1,
"intermediate_tensors": 0,
"inputs_embeds": 0,
},
enable_if=can_enable_torch_compile,
)
class TransformersMoEForMultimodalLM(TransformersMoEBase, TransformersForMultimodalLM):
get_input_embeddings = SupportsMultiModal.get_input_embeddings

View File

@ -0,0 +1,396 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2024 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformers backend mixin for multi-modal models."""
from collections.abc import Mapping
from typing import TYPE_CHECKING
import torch
from vllm.config.utils import getattr_iter
from vllm.model_executor.models.interfaces import SupportsMRoPE, SupportsMultiModal
from vllm.model_executor.models.utils import WeightsMapper
from vllm.multimodal import MultiModalKwargsItems
from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalFieldConfig,
MultiModalInputs,
MultiModalUUIDDict,
PlaceholderRange,
)
from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataItems
from vllm.multimodal.processing import BaseMultiModalProcessor, BaseProcessingInfo
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
if TYPE_CHECKING:
from transformers import BatchFeature, PretrainedConfig
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
DYNAMIC_ARG_DIMS = {
"input_ids": 0,
# set `positions` to last dim to support Qwen-mrope
"positions": -1,
"intermediate_tensors": 0,
"inputs_embeds": 0,
}
class MultiModalProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self):
return {"image": None}
def get_mm_max_tokens_per_item(self, seq_len, mm_counts):
return {"image": self.get_max_image_tokens()}
def get_max_image_tokens(self) -> int:
width, height = self.get_max_image_size()
processor = self.get_hf_processor()
multimodal_config = self.ctx.model_config.multimodal_config
mm_processor_kwargs = multimodal_config.mm_processor_kwargs or {}
mm_tokens = processor._get_num_multimodal_tokens(
image_sizes=([height, width],), **mm_processor_kwargs
)
image_tokens = mm_tokens["num_image_tokens"][0]
return image_tokens
def get_max_image_size(self):
return 10_000, 10_000  # hardcoded arbitrary very large size
class MultiModalDummyInputsBuilder(BaseDummyInputsBuilder[MultiModalProcessingInfo]):
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0)
processor = self.info.get_hf_processor()
if "gemma3" in processor.__class__.__name__.lower():
image_token = processor.boi_token
else:
image_token = getattr(processor, "image_token", "")
return image_token * num_images
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Mapping[str, "BaseDummyOptions"] | None = None,
) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0)
target_width, target_height = self.info.get_max_image_size()
image_overrides = mm_options.get("image") if mm_options else None
return {
"image": self._get_dummy_images(
width=target_width,
height=target_height,
num_images=num_images,
overrides=image_overrides,
),
}
class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargsItems,
):
"""
Given the original multi-modal items for this modality
and HF-processed data, output the updates to perform.
The information returned by this method is used to update token inputs
which bypass the HF processor. It is also used to update the output of
the HF processor if the HF processor does not apply prompt updates to text
inputs.
Moreover, this information is critical to determine the token positions
in order to construct :class:`~vllm.multimodal.inputs.PlaceholderRange`
for each multi-modal item.
"""
return None
def _get_mm_fields_config(
self,
hf_inputs: "BatchFeature",
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
# HF Processors always return a mask but vLLM doesn't need it
hf_inputs.pop("attention_mask", None)
num_image_patches = hf_inputs.get("num_image_patches")
mm_fields = {
key: MultiModalFieldConfig.flat_from_sizes("image", num_image_patches)
for key in hf_inputs
}
mm_fields["image_embeds"] = MultiModalFieldConfig.flat_from_sizes(
"image", num_image_patches
)
# Keep these as batched, as they always have batch size as first dim
mm_fields["image_grid_thw"] = MultiModalFieldConfig.batched("image")
mm_fields["video_grid_thw"] = MultiModalFieldConfig.batched("image")
mm_fields["num_image_patches"] = MultiModalFieldConfig.batched("image")
return mm_fields
def _get_hf_mm_data(
self,
mm_items: MultiModalDataItems,
) -> tuple[Mapping[str, object], Mapping[str, object]]:
"""
In contrast to the base class, this method always adds
`return_mm_token_type_ids` to the processor data
"""
processor_data, passthrough_data = super()._get_hf_mm_data(mm_items)
processor_data["return_mm_token_type_ids"] = True
return processor_data, passthrough_data
def apply(
self,
prompt: str | list[int],
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object] | None = None,
mm_uuids: MultiModalUUIDDict | None = None,
) -> MultiModalInputs:
"""
Process multi-modal inputs to be used in vLLM.
Apply HF Processor on prompt text and multi-modal data together,
outputting token IDs and processed tensors.
"""
if tokenization_kwargs is None:
tokenization_kwargs = {}
mm_items = self._to_mm_items(mm_data)
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
if not isinstance(prompt, str):
# the prompt is already tokenized ids, which the hf_processor does not
# support, so we decode the ids back into a string
prompt = hf_processor.decode(prompt)
# Bypass cached processor and always apply to the full set of mm inputs
# NOTE: we can't just set caching=False because the base class method
# transforms outputs to `MultiModalKwargs`, which is not going to
# work for Transformers. We have a lot of logic tied to
# `mm_tokens_per_modality` below
prompt_ids, processed_data, _ = self._apply_hf_processor_text_mm(
prompt_text=prompt,
mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
)
# For gemma3 we check `token_type_ids` as the key
token_type_key = (
"mm_token_type_ids"
if "mm_token_type_ids" in processed_data
else "token_type_ids"
)
mm_token_type_ids = processed_data.pop(token_type_key)
# We can infer vLLM-style placeholders from the token type ids if we split
# them for each input `mm_data`.
mm_positions = torch.where(mm_token_type_ids == 1)[1]
images = mm_items.get_items("image", ImageProcessorItems)
multimodal_config = self.info.ctx.model_config.multimodal_config
mm_processor_kwargs = multimodal_config.mm_processor_kwargs or {}
image_sizes = []
for item_idx in range(len(images)):
image_size = images.get_image_size(item_idx)
image_sizes.append((image_size.height, image_size.width))
mm_tokens_per_modality = hf_processor._get_num_multimodal_tokens(
image_sizes=image_sizes, **mm_processor_kwargs
)
mm_placeholders = {}
split_sizes = mm_tokens_per_modality["num_image_tokens"]
if split_sizes:
chunked_mm_positions = torch.split(mm_positions, split_sizes)
mm_tokens = torch.tensor(prompt_ids)[mm_token_type_ids[0].bool()]
chunked_mm_tokens = torch.split(mm_tokens, split_sizes)
ranges = [
PlaceholderRange(
offset=positions[0].item(),
length=positions.shape[0],
is_embed=(mm_tokens == hf_processor.image_token_id).bool(),
)
for positions, mm_tokens in zip(chunked_mm_positions, chunked_mm_tokens)
]
mm_placeholders = {"image": ranges}
processed_data["num_image_patches"] = torch.tensor(
mm_tokens_per_modality["num_image_patches"]
)
mm_kwargs = MultiModalKwargsItems.from_hf_inputs(
processed_data,
self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs),
)
# Use overrides if provided; fallback to data-dependent hashing.
mm_hashes = self._hash_mm_items(
mm_items, hf_processor_mm_kwargs, tokenization_kwargs, mm_uuids=mm_uuids
)
return MultiModalInputs(
type="multimodal",
prompt_token_ids=prompt_ids,
mm_kwargs=mm_kwargs,
mm_hashes=mm_hashes,
mm_placeholders=mm_placeholders,
)
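A small worked example of the placeholder inference in `apply` above: given per-token multimodal type ids and per-image token counts, recover the offset and length of each image's placeholder run. The tensors below are made up for illustration.

import torch

mm_token_type_ids = torch.tensor([[0, 1, 1, 1, 0, 0, 1, 1, 0]])
split_sizes = [3, 2]  # tokens per image, e.g. num_image_tokens

mm_positions = torch.where(mm_token_type_ids == 1)[1]  # tensor([1, 2, 3, 6, 7])
chunked_mm_positions = torch.split(mm_positions, split_sizes)
ranges = [
    (positions[0].item(), positions.shape[0])  # (offset, length) per image
    for positions in chunked_mm_positions
]
assert ranges == [(1, 3), (6, 2)]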
class MultiModalMixin(SupportsMultiModal, SupportsMRoPE):
supports_multimodal_raw_input_only = True
merge_by_field_config = True
# Backwards compatibility for previously released models. State dicts back
# then had different formats and cannot be loaded with the `AutoModel` mapping as-is
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
"language_model.model": "model.language_model",
"text_model.model": "model.text_model",
"vision_tower": "model.vision_tower",
"vqmodel": "model.vqmodel",
"visual": "model.visual",
"vision_model": "model.vision_model",
"vision_embed_tokens": "model.vision_embed_tokens",
"image_newline": "model.image_newline",
"multi_modal_projector": "model.multi_modal_projector",
"text_model.lm_head": "lm_head",
"language_model.lm_head": "lm_head",
# Qwen models used "model" as the name for the language model.
# Therefore, we must map each of submodule explicitly to avoid
# conflicts with newer models that use "model.language_model".
"model.embed_tokens": "model.language_model.embed_tokens",
"model.layers": "model.language_model.layers",
"model.norm": "model.language_model.norm",
}
)
def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""):
# Skip SupportsMRoPE.__init__ and call the next class in MRO
super(SupportsMRoPE, self).__init__(vllm_config=vllm_config, prefix=prefix)
def forward(
self,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
**kwargs: object,
) -> torch.Tensor | IntermediateTensors:
# Gemma3 and PaliGemma need `token_type_ids` to work correctly
# Other models will not have `token_type_ids` in kwargs
kwargs = {k: v for k, v in kwargs.items() if k == "token_type_ids"}
model_output = super().forward(
input_ids, positions, intermediate_tensors, inputs_embeds, **kwargs
)
return model_output
def get_language_model(self) -> torch.nn.Module:
"""Transformers backend multimodal classes do not contain a separate vLLM
language model class. Therefore, in order to return a language model vLLM class,
we use a wrapper to give `self` the same interface as a text model."""
# Exclude self and object
bases = self.__class__.mro()[1:-1]
# Keep only classes defined in `vllm.model_executor.models.transformers`
bases = [b for b in bases if ".transformers." in b.__module__]
# Exclude MultiModalMixin itself
bases = [b for b in bases if b is not MultiModalMixin]
class LanguageModel(*bases):
def __init__(self, multimodal_model):
# Don't call super().__init__() to avoid re-initialization
self.__dict__.update(multimodal_model.__dict__)
model = getattr_iter(self.model, ("language_model", "text_model"), None)
return LanguageModel(self)
def get_multimodal_embeddings(self, **kwargs):
pixel_values: torch.Tensor | None = kwargs.pop("pixel_values", None)
image_embeds: torch.Tensor | None = kwargs.pop("image_embeds", None)
# Model might use `image_patches` instead of `pixel_values`
if pixel_values is None:
pixel_values = kwargs.pop("image_patches", None)
if image_embeds is not None:
return image_embeds
if pixel_values is None:
return None
num_image_patches = kwargs.pop("num_image_patches")
kwargs.pop("token_type_ids", None) # used only in `forward`
if pixel_values is not None:
vision_embeddings = self.model.get_image_features(pixel_values, **kwargs)
if isinstance(vision_embeddings, torch.Tensor):
if vision_embeddings.ndim == 2:
vision_embeddings = vision_embeddings.unsqueeze(0)
# Embeddings have to be 2D tensors of length `num_images`,
# but transformers returns a single concatenated tensor if the patches
# are of different sizes. We split it back to make vLLM happy
vision_embeddings = torch.split(
vision_embeddings, num_image_patches.flatten().tolist()
)
vision_embeddings = [
embed.flatten(start_dim=0, end_dim=-2)
for embed in vision_embeddings
]
return vision_embeddings
def get_mrope_input_positions(
self,
input_tokens: list[int],
hf_config: "PretrainedConfig",
image_grid_thw: list[list[int]] | torch.Tensor | None,
video_grid_thw: list[list[int]] | torch.Tensor | None,
second_per_grid_ts: list[float] | None = None,
context_len: int = 0,
seq_len: int | None = None,
audio_feature_lengths: torch.Tensor | None = None,
use_audio_in_video: bool = False,
) -> tuple[torch.Tensor, int]:
if any((second_per_grid_ts, audio_feature_lengths, use_audio_in_video)):
raise NotImplementedError("Transformers backend only supports images.")
if isinstance(image_grid_thw, list):
image_grid_thw = torch.tensor(image_grid_thw)
if isinstance(video_grid_thw, list):
video_grid_thw = torch.tensor(video_grid_thw)
mrope_positions, mrope_position_delta = self.model.get_rope_index(
input_ids=torch.tensor(input_tokens).unsqueeze(0),
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
)
mrope_positions = mrope_positions[:, 0, context_len:seq_len]
mrope_position_delta = mrope_position_delta[0].item()
return mrope_positions, mrope_position_delta
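A hedged sketch of the embedding split in `get_multimodal_embeddings` above: when transformers returns one concatenated tensor for all image patches, it is split back into one 2-D tensor per image. Shapes and values are illustrative.

import torch

hidden_size = 8
num_image_patches = torch.tensor([2, 3])         # patches per image
vision_embeddings = torch.randn(5, hidden_size)  # concatenated patches

per_image = torch.split(vision_embeddings, num_image_patches.flatten().tolist())
per_image = [embed.flatten(start_dim=0, end_dim=-2) for embed in per_image]
assert [tuple(e.shape) for e in per_image] == [(2, hidden_size), (3, hidden_size)]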

View File

@ -0,0 +1,118 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2024 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformers backend mixins for pooling models."""
from typing import TYPE_CHECKING
import torch
from transformers import AutoModelForSequenceClassification
from vllm.model_executor.layers.pooler import (
ClassifierPooler,
CLSPool,
DispatchPooler,
Pooler,
)
from vllm.model_executor.models.interfaces import SupportsCrossEncoding
from vllm.model_executor.models.interfaces_base import VllmModelForPooling
if TYPE_CHECKING:
from vllm.config import VllmConfig
class EmbeddingMixin(VllmModelForPooling):
default_pooling_type = "CLS"
def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""):
# Skip VllmModelForPooling.__init__ and call the next class in MRO
super(VllmModelForPooling, self).__init__(
vllm_config=vllm_config, prefix=prefix
)
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.pooler = DispatchPooler(
{
"token_embed": Pooler.for_token_embed(pooler_config),
"embed": Pooler.for_embed(pooler_config),
}
)
class SequenceClassificationMixin(SupportsCrossEncoding, VllmModelForPooling):
default_pooling_type = "CLS"
def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""):
# Skip VllmModelForPooling.__init__ and call the next class in MRO
super(VllmModelForPooling, self).__init__(
vllm_config=vllm_config, prefix=prefix
)
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
# Certain information about the model and classifier can only be
# inferred from the `ForSequenceClassification` class. Therefore, we
# instantiate it on the "meta" device to avoid allocating GPU memory.
with torch.device("meta"):
seq_cls_model = AutoModelForSequenceClassification.from_config(
self.config,
dtype=self.model_config.dtype,
trust_remote_code=self.model_config.trust_remote_code,
)
# When used for sequence classification, some models have their
# pooling layers removed. Make sure this is reflected in vLLM.
for module in seq_cls_model.modules():
if hasattr(module, "pooler") and module.pooler is None:
self.model.pooler = None
break
if self.model.pooler is not None:
raise ValueError(
"Sequence classification models with pooling layers are not "
"supported yet in the Transformers backend."
)
# Unlike `lm_head`, `classifier` is not always `nn.Linear`.
self.classifier = seq_cls_model.classifier
self.init_parameters(self.classifier, dtype=self.model_config.head_dtype)
class ClassifierWithReshape(self.classifier.__class__):
"""CLSPool has already been applied in `pooling`.
Add dim to match expected input shape of `classifier.forward`."""
def forward(self, *args, **kwargs):
if len(args) > 0:
args = (args[0].unsqueeze(1), *args[1:])
return super().forward(*args, **kwargs)
self.classifier.__class__ = ClassifierWithReshape
self.pooler = DispatchPooler(
{
"token_classify": Pooler.for_token_classify(
pooler_config, classifier=self.classifier
),
"classify": ClassifierPooler(
pooling=CLSPool(), classifier=self.classifier, act_fn="classify"
),
"score": ClassifierPooler(
pooling=CLSPool(), classifier=self.classifier, act_fn="score"
),
}
)
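A minimal sketch of the "meta" device trick used by `SequenceClassificationMixin` above: constructing a module under `torch.device("meta")` records structure and shapes without allocating real memory. It assumes a recent PyTorch where `torch.device` works as a context manager, and `nn.Linear` stands in for the real classifier head.

import torch
import torch.nn as nn

with torch.device("meta"):
    probe = nn.Linear(4096, 2)  # stands in for `seq_cls_model.classifier`

assert probe.weight.is_meta                    # no real storage was allocated
assert tuple(probe.weight.shape) == (2, 4096)  # structure is still inspectable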

View File

@ -0,0 +1,207 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2024 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformers backend utilities."""
from contextlib import contextmanager
from pathlib import Path
from typing import TYPE_CHECKING, Literal
import torch
from torch import nn
from vllm.config.utils import getattr_iter
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.model_executor.layers.quantization import QuantizationConfig
logger = init_logger(__name__)
# Copied from `accelerate`
@contextmanager
def init_on_device_without_buffers(device: torch.device):
"""
A context manager under which models are initialized with all
parameters on the specified device. However, buffers are not
initialized on the specified device.
Args:
device (`torch.device`):
Device to initialize all parameters on.
"""
old_register_parameter = nn.Module.register_parameter
def register_empty_parameter(module, name, param):
old_register_parameter(module, name, param)
if param is not None:
param_cls = type(module._parameters[name])
kwargs = module._parameters[name].__dict__
kwargs["requires_grad"] = param.requires_grad
module._parameters[name] = param_cls(
module._parameters[name].to(device), **kwargs
)
tensor_constructors_to_patch = {}
def patch_tensor_constructor(fn):
def wrapper(*args, **kwargs):
kwargs["device"] = device
return fn(*args, **kwargs)
return wrapper
try:
nn.Module.register_parameter = register_empty_parameter
for torch_function_name in tensor_constructors_to_patch:
setattr(
torch,
torch_function_name,
patch_tensor_constructor(getattr(torch, torch_function_name)),
)
yield
finally:
nn.Module.register_parameter = old_register_parameter
for (
torch_function_name,
old_torch_function,
) in tensor_constructors_to_patch.items():
setattr(torch, torch_function_name, old_torch_function)
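A usage sketch of the context manager defined above: parameters land on the target device while buffers keep their default placement. `nn.BatchNorm1d` is chosen only because it has both parameters and buffers.

import torch
import torch.nn as nn

with init_on_device_without_buffers(torch.device("meta")):
    layer = nn.BatchNorm1d(16)

assert layer.weight.is_meta            # parameter: moved to the target device
assert not layer.running_mean.is_meta  # buffer: left on the default device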
Style = Literal["colwise", "colwise_rep", "rowwise", "rowwise_rep", "replicate"]
def replace_linear_class(
linear: nn.Linear,
style: Style = "replicate",
quant_config: "QuantizationConfig | None" = None,
*,
prefix: str = "",
) -> ColumnParallelLinear | RowParallelLinear | ReplicatedLinear:
"""
Replace nn.Linear with one of vLLM's tensor parallel linear classes.
Args:
linear: `nn.Linear` to be replaced.
style: Tensor parallel style of the new linear, e.g. "colwise".
quant_config: Quantization config for the new linear.
Returns:
The new linear.
"""
if not isinstance(style, str):
raise ValueError(f"Unsupported parallel style type {type(style)}, expected str")
vllm_linear_cls, vllm_linear_kwargs = {
"colwise": (ColumnParallelLinear, {}),
"colwise_rep": (ColumnParallelLinear, {"gather_output": True}),
"rowwise": (RowParallelLinear, {}),
"rowwise_rep": (RowParallelLinear, {"input_is_parallel": False}),
"replicate": (ReplicatedLinear, {}),
}.get(style, (ReplicatedLinear, {}))
return vllm_linear_cls(
input_size=linear.in_features,
output_size=linear.out_features,
bias=linear.bias is not None,
quant_config=quant_config,
prefix=prefix,
return_bias=False,
**vllm_linear_kwargs,
)
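A usage sketch of `replace_linear_class`, assuming `ReplicatedLinear` can be constructed outside a full vLLM engine context; the "replicate" style is used so the example does not need vLLM's tensor-parallel groups, and the prefix string is illustrative.

import torch.nn as nn

linear = nn.Linear(1024, 4096, bias=False)
# "replicate" avoids needing vLLM's TP groups; "colwise"/"rowwise" would not
new_linear = replace_linear_class(linear, style="replicate", prefix="mlp.up_proj")
# new_linear is a ReplicatedLinear with input_size=1024, output_size=4096 and
# return_bias=False, so it can be called like the original nn.Linear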
def replace_rms_norm_class(rms_norm: nn.Module, hidden_size: int) -> RMSNorm:
"""Replace a Transformers RMSNorm with vLLM's RMSNorm.
This method assumes:
- Weight is stored as `weight`.
- Epsilon is stored as `eps` or `variance_epsilon`.
- `with_scale` indicates whether the layer has a weight (Gemma3n only).
- `var_hidden_size` is only ever used for the Intern vision encoder in vLLM;
Transformers doesn't appear to have an equivalent concept.
"""
eps = getattr_iter(rms_norm, ("eps", "variance_epsilon"), 1e-6)
kwargs = {"hidden_size": hidden_size, "eps": eps}
# Update hidden size if weight is available
weight_meta = getattr(rms_norm, "weight", None)
if weight_meta is not None:
kwargs["hidden_size"] = weight_meta.size(0)
# Check if weight is all zeros, which indicates GemmaRMSNorm
# We must create a new instance because rms_norm is on meta
try:
with torch.device("cpu"):
weight_test = getattr(rms_norm.__class__(1), "weight", None)
except Exception:
logger.warning(
"Failed to determine if RMSNorm weight is centered on zero or one. "
"Defaulting to one."
)
weight_test = None
if weight_test is not None and torch.all(weight_test == 0):
return GemmaRMSNorm(**kwargs)
# Otherwise assume it's a regular RMSNorm
kwargs["has_weight"] = getattr(rms_norm, "with_scale", True)
if weight_meta is not None:
kwargs["dtype"] = weight_meta.dtype
else:
# No weight, fall back to weightless RMSNorm
kwargs["has_weight"] = False
return RMSNorm(**kwargs)
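A hedged sketch of the zero-weight probe above: Gemma-style RMSNorm stores a zero-initialised scale (applied as `1 + weight`), so instantiating a fresh layer on CPU reveals which convention a class follows. `HfGemmaStyleRMSNorm` is an illustrative stand-in, not a real transformers class.

import torch
import torch.nn as nn

class HfGemmaStyleRMSNorm(nn.Module):
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(hidden_size))  # zero-centred scale
        self.eps = eps

with torch.device("cpu"):
    weight_test = getattr(HfGemmaStyleRMSNorm(1), "weight", None)

assert weight_test is not None and torch.all(weight_test == 0)  # -> GemmaRMSNorm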
def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module):
logger.debug("%s: %s -> %s", name, old_module, new_module)
def get_feature_request_tip(
model: str,
trust_remote_code: bool,
) -> str:
hf_url = f"a discussion at https://huggingface.co/{model}/discussions/new"
gh_url = "an issue at https://github.com/huggingface/transformers/issues/new/choose"
url = hf_url if trust_remote_code else gh_url
prefix = f"Please open {url} to request support for this feature. "
if Path(model).exists():
prefix = ""
doc_url = "https://docs.vllm.ai/en/latest/models/supported_models.html#writing-custom-models"
tip = f"See {doc_url} for instructions on how to add support yourself."
return f"{prefix}{tip}"
def can_enable_torch_compile(vllm_config: "VllmConfig") -> bool:
"""
Callable to be passed to `@support_torch_compile`'s `enable_if` argument.
Defaults to `True` but is disabled in the following situations:
- The model uses dynamic rope scaling.
"""
text_config = vllm_config.model_config.hf_config.get_text_config()
# Dynamic rope scaling is not compatible with torch.compile
rope_scaling: dict = getattr(text_config, "rope_scaling", None) or {}
return rope_scaling.get("rope_type") != "dynamic"
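A tiny sketch of the guard above: a text config whose `rope_scaling` uses dynamic rope disables `torch.compile`. `SimpleNamespace` stands in for the real config objects.

from types import SimpleNamespace

def _compile_ok(text_config) -> bool:
    rope_scaling: dict = getattr(text_config, "rope_scaling", None) or {}
    return rope_scaling.get("rope_type") != "dynamic"

assert _compile_ok(SimpleNamespace(rope_scaling=None))
assert not _compile_ok(SimpleNamespace(rope_scaling={"rope_type": "dynamic"}))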

View File

@ -1,215 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2024 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wrapper around `transformers` models for pooling tasks."""
import torch
from transformers import AutoModelForSequenceClassification
from vllm.attention import Attention, AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import (
ClassifierPooler,
CLSPool,
DispatchPooler,
Pooler,
)
from vllm.sequence import IntermediateTensors
from .interfaces_base import VllmModelForPooling
from .transformers import TransformersBase, can_enable_torch_compile
from .transformers_moe import TransformersMoEBase
from .utils import WeightsMapper
class TransformersPoolingBase(TransformersBase, VllmModelForPooling):
hf_to_vllm_mapper = WeightsMapper(
# These are applied in order, so the order matters!
orig_to_new_prefix={
# Handle BERT-like models
"roberta": "model",
"bert": "model",
# Add `model.` prefix for base model checkpoints
"": "model.",
# Remove `model.` prefix if it was already there
"model.model.": "model.",
# Classifier/scoring heads will be adjacent to `model`
"model.score": "classifier",
"model.classifier": "classifier",
},
orig_to_new_suffix={
# Replace legacy suffixes used for norms
".gamma": ".weight",
".beta": ".bias",
},
)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
# Skip unsupported/unwanted output embeddings layers
self.skip_prefixes.extend(
[
"model.lm_head.",
"model.predictions.",
"model.qa_outputs.",
"model.embeddings_project.",
"model.discriminator_predictions.",
]
)
# Some encoder models have the position_ids buffer in the checkpoint.
# vLLM will always pass position_ids as an argument, so we skip loading
# the buffer if it exists
self.skip_substrs.append("position_ids")
# Some encoder models have the bias of the final classifier layer
# in the checkpoint. vLLM does not use this bias, so we skip loading
# it if it exists
self.skip_substrs.append("score.bias")
# RoBERTa-like models have extra padding in positions.
# FIXME(Isotr0py): This is quite hacky for roberta edge case,
# we should find a better way to handle this.
self.is_roberta = "roberta" in self.text_config.model_type
self.padding_idx = self.text_config.pad_token_id
def create_attention_instances(
self, attn_type: AttentionType = AttentionType.DECODER
) -> dict[int, Attention]:
# TODO(hmellor): Better way to detect encoder models
# In encoder models, the attention layers will have `is_causal=False`
is_encoder = lambda m: not getattr(m, "is_causal", True)
# vLLM does not support encoder-decoder models, so if any encoder layer
# is found, we assume the whole model is an encoder model
if any(is_encoder(m) for m in self.model.modules()):
attn_type = AttentionType.ENCODER_ONLY
# Check minimum transformers version for encoder models support
if attn_type == AttentionType.ENCODER_ONLY:
self.check_version("4.57.0.dev0", "encoder models support")
return super().create_attention_instances(attn_type)
def forward(
self,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
) -> torch.Tensor | IntermediateTensors:
if self.is_roberta:
# RoBERTa-specific positions padding
positions += self.padding_idx + 1
return super().forward(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersEmbeddingModel(TransformersPoolingBase):
default_pooling_type = "CLS"
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.pooler = DispatchPooler(
{
"token_embed": Pooler.for_token_embed(pooler_config),
"embed": Pooler.for_embed(pooler_config),
}
)
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersForSequenceClassification(TransformersPoolingBase):
default_pooling_type = "CLS"
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
# Certain information about the model and classifier can only be
# inferred from the `ForSequenceClassification` class. Therefore, we
# instantiate it on the "meta" device to avoid allocating GPU memory.
with torch.device("meta"):
seq_cls_model = AutoModelForSequenceClassification.from_config(
self.config,
dtype=self.model_config.dtype,
trust_remote_code=self.model_config.trust_remote_code,
)
# When used for sequence classification, some models have their
# pooling layers removed. Make sure this is reflected in vLLM.
for module in seq_cls_model.modules():
if hasattr(module, "pooler") and module.pooler is None:
self.model.pooler = None
break
if self.model.pooler is not None:
raise ValueError(
"Sequence classification models with pooling layers are not "
"supported yet in the Transformers backend."
)
# Unlike `lm_head`, `classifier` is not always `nn.Linear`.
self.classifier = seq_cls_model.classifier
self.init_parameters(self.classifier, dtype=self.model_config.head_dtype)
class ClassifierWithReshape(self.classifier.__class__):
"""CLSPool has already been applied in `pooling`.
Add dim to match expected input shape of `classifier.forward`."""
def forward(self, *args, **kwargs):
if len(args) > 0:
args = (args[0].unsqueeze(1), *args[1:])
return super().forward(*args, **kwargs)
self.classifier.__class__ = ClassifierWithReshape
self.pooler = DispatchPooler(
{
"token_classify": Pooler.for_token_classify(
pooler_config, classifier=self.classifier
),
"classify": ClassifierPooler(
pooling=CLSPool(), classifier=self.classifier, act_fn="classify"
),
"score": ClassifierPooler(
pooling=CLSPool(), classifier=self.classifier, act_fn="score"
),
}
)
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersMoEEmbeddingModel(TransformersMoEBase, TransformersEmbeddingModel):
pass
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersMoEForSequenceClassification(
TransformersMoEBase, TransformersForSequenceClassification
):
pass