From fb5e10d3fbb79323d5d9444543a143a864fe29cc Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Thu, 16 Oct 2025 22:50:39 +0100 Subject: [PATCH] Refactor Transformers backend to use mixins (#26906) Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- .github/CODEOWNERS | 2 +- tests/models/registry.py | 8 +- tests/models/test_initialization.py | 2 +- tests/models/test_transformers.py | 6 +- vllm/config/model.py | 25 +- vllm/model_executor/models/deepseek_vl2.py | 2 +- vllm/model_executor/models/registry.py | 40 +- vllm/model_executor/models/transformers.py | 961 ------------------ .../models/transformers/__init__.py | 127 +++ .../models/transformers/base.py | 435 ++++++++ .../models/transformers/causal.py | 66 ++ .../models/transformers/legacy.py | 97 ++ .../moe.py} | 51 +- .../models/transformers/multimodal.py | 396 ++++++++ .../models/transformers/pooling.py | 118 +++ .../models/transformers/utils.py | 207 ++++ .../models/transformers_pooling.py | 215 ---- 17 files changed, 1510 insertions(+), 1248 deletions(-) delete mode 100644 vllm/model_executor/models/transformers.py create mode 100644 vllm/model_executor/models/transformers/__init__.py create mode 100644 vllm/model_executor/models/transformers/base.py create mode 100644 vllm/model_executor/models/transformers/causal.py create mode 100644 vllm/model_executor/models/transformers/legacy.py rename vllm/model_executor/models/{transformers_moe.py => transformers/moe.py} (90%) create mode 100644 vllm/model_executor/models/transformers/multimodal.py create mode 100644 vllm/model_executor/models/transformers/pooling.py create mode 100644 vllm/model_executor/models/transformers/utils.py delete mode 100644 vllm/model_executor/models/transformers_pooling.py diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 3fbc38d9a2..14301fe8d8 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -57,7 +57,7 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson /tests/v1/offloading @ApostaC # Transformers backend -/vllm/model_executor/models/transformers.py @hmellor +/vllm/model_executor/models/transformers @hmellor /tests/models/test_transformers.py @hmellor # Docs diff --git a/tests/models/registry.py b/tests/models/registry.py index 1d3d7fe659..6bcb12e9e2 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -912,11 +912,11 @@ _TRANSFORMERS_BACKEND_MODELS = { "TransformersForCausalLM": _HfExamplesInfo( "hmellor/Ilama-3.2-1B", trust_remote_code=True ), - "TransformersForMultimodalLM": _HfExamplesInfo("BAAI/Emu3-Chat-hf"), + "TransformersMultiModalForCausalLM": _HfExamplesInfo("BAAI/Emu3-Chat-hf"), "TransformersMoEForCausalLM": _HfExamplesInfo( "allenai/OLMoE-1B-7B-0924", min_transformers_version="4.57.0.dev0" ), - "TransformersMoEForMultimodalLM": _HfExamplesInfo( + "TransformersMultiModalMoEForCausalLM": _HfExamplesInfo( "Qwen/Qwen3-VL-30B-A3B-Instruct", min_transformers_version="4.57.0.dev0" ), "TransformersMoEEmbeddingModel": _HfExamplesInfo( @@ -925,6 +925,10 @@ _TRANSFORMERS_BACKEND_MODELS = { "TransformersMoEForSequenceClassification": _HfExamplesInfo( "Qwen/Qwen3-30B-A3B", min_transformers_version="4.57.0.dev0" ), + "TransformersMultiModalEmbeddingModel": _HfExamplesInfo("google/gemma-3-4b-it"), + "TransformersMultiModalForSequenceClassification": _HfExamplesInfo( + "google/gemma-3-4b-it" + ), } _EXAMPLE_MODELS = { diff --git a/tests/models/test_initialization.py b/tests/models/test_initialization.py index 80bee3d8cf..6d1f67c39f 100644 --- 
a/tests/models/test_initialization.py +++ b/tests/models/test_initialization.py @@ -37,7 +37,7 @@ MINIMAL_MODEL_ARCH_LIST = [ "JinaVLForRanking", "InternVLChatModel", "InternLM2ForRewardModel", - "TransformersForMultimodalLM", + "TransformersMultiModalForCausalLM", "PrithviGeoSpatialMAE", "UltravoxModel", "DeepSeekMTPModel", diff --git a/tests/models/test_transformers.py b/tests/models/test_transformers.py index f9e252a23b..d8a1aace83 100644 --- a/tests/models/test_transformers.py +++ b/tests/models/test_transformers.py @@ -211,11 +211,7 @@ def test_embed_loading(vllm_runner, model): def test_pooling(hf_runner, vllm_runner, example_prompts, arch): model = get_model(arch) - vllm_kwargs = dict( - max_model_len=None, - model_impl="transformers", - compilation_config=dict(cudagraph_capture_sizes=[8]), - ) + vllm_kwargs = dict(max_model_len=None, model_impl="transformers") hf_kwargs = dict() if arch == "TransformersEmbeddingModel": diff --git a/vllm/config/model.py b/vllm/config/model.py index b572967d36..6602f7c0a5 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -147,6 +147,10 @@ class ModelConfig: seed: int | None = None """Random seed for reproducibility. Initialized to None in V0, but initialized to 0 in V1.""" + hf_config: PretrainedConfig = field(init=False) + """The Hugging Face config of the model.""" + hf_text_config: PretrainedConfig = field(init=False) + """The Hugging Face config of the text model (same as hf_config for text models).""" hf_config_path: str | None = None """Name or path of the Hugging Face config to use. If unspecified, model name or path will be used.""" @@ -771,8 +775,10 @@ class ModelConfig: def _get_transformers_backend_cls(self) -> str: """Determine which Transformers backend class will be used if `model_impl` is set to `transformers` or `auto`.""" - prefix = "Transformers" - prefix += "MoE" if self.get_num_experts() > 1 else "" + cls = "Transformers" + # If 'hf_config != hf_text_config' it's a nested config, i.e. multimodal + cls += "MultiModal" if self.hf_config != self.hf_text_config else "" + cls += "MoE" if self.get_num_experts() > 1 else "" # Check if the architecture we're wrapping has defaults runner = None convert = None @@ -788,18 +794,15 @@ class ModelConfig: runner = "generate" if convert in {None, "none"}: convert = "embed" - # Resolve Transformers backend pooling classes + # Resolve Transformers backend task if runner == "pooling": if convert == "embed": - return prefix + "EmbeddingModel" + return cls + "EmbeddingModel" if convert == "classify": - return prefix + "ForSequenceClassification" - # Resolve Transformers backend generate classes - if self.hf_config != self.hf_text_config: - # If 'hf_text_config' is the same as 'hf_config'. If not, it is - # probably a composite config, i.e. 
multimodal - return prefix + "ForMultimodalLM" - return prefix + "ForCausalLM" + return cls + "ForSequenceClassification" + else: + cls += "ForCausalLM" + return cls def using_transformers_backend(self) -> bool: """Check if the model is using the Transformers backend class.""" diff --git a/vllm/model_executor/models/deepseek_vl2.py b/vllm/model_executor/models/deepseek_vl2.py index d9e1523b04..759f2a18d3 100644 --- a/vllm/model_executor/models/deepseek_vl2.py +++ b/vllm/model_executor/models/deepseek_vl2.py @@ -19,7 +19,7 @@ from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.utils import set_default_torch_dtype -from vllm.model_executor.models.transformers import replace_linear_class +from vllm.model_executor.models.transformers.utils import replace_linear_class from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( MultiModalDataDict, diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index d119c161f6..4171ebdbde 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -401,32 +401,44 @@ _TRANSFORMERS_SUPPORTED_MODELS = { # Text generation models "SmolLM3ForCausalLM": ("transformers", "TransformersForCausalLM"), # Multimodal models - "Emu3ForConditionalGeneration": ("transformers", "TransformersForMultimodalLM"), # noqa: E501 + "Emu3ForConditionalGeneration": ( + "transformers", + "TransformersMultiModalForCausalLM", + ), } _TRANSFORMERS_BACKEND_MODELS = { + # Text generation models "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"), - "TransformersForMultimodalLM": ("transformers", "TransformersForMultimodalLM"), # noqa: E501 - "TransformersMoEForCausalLM": ("transformers_moe", "TransformersMoEForCausalLM"), # noqa: E501 - "TransformersMoEForMultimodalLM": ( - "transformers_moe", - "TransformersMoEForMultimodalLM", + "TransformersMoEForCausalLM": ("transformers", "TransformersMoEForCausalLM"), + # Multimodal models + "TransformersMultiModalForCausalLM": ( + "transformers", + "TransformersMultiModalForCausalLM", ), - "TransformersEmbeddingModel": ( - "transformers_pooling", - "TransformersEmbeddingModel", + "TransformersMultiModalMoEForCausalLM": ( + "transformers", + "TransformersMultiModalMoEForCausalLM", ), + # Embedding models + "TransformersEmbeddingModel": ("transformers", "TransformersEmbeddingModel"), + "TransformersMoEEmbeddingModel": ("transformers", "TransformersMoEEmbeddingModel"), + "TransformersMultiModalEmbeddingModel": ( + "transformers", + "TransformersMultiModalEmbeddingModel", + ), + # Sequence classification models "TransformersForSequenceClassification": ( - "transformers_pooling", + "transformers", "TransformersForSequenceClassification", ), "TransformersMoEForSequenceClassification": ( - "transformers_pooling", + "transformers", "TransformersMoEForSequenceClassification", ), - "TransformersMoEEmbeddingModel": ( - "transformers_pooling", - "TransformersMoEEmbeddingModel", + "TransformersMultiModalForSequenceClassification": ( + "transformers", + "TransformersMultiModalForSequenceClassification", ), } diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py deleted file mode 100644 index a8709ea426..0000000000 --- a/vllm/model_executor/models/transformers.py +++ /dev/null @@ -1,961 +0,0 @@ -# SPDX-License-Identifier: 
Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -# Copyright 2024 The vLLM team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Wrapper around `transformers` models""" - -from collections.abc import Iterable, Mapping -from contextlib import contextmanager -from pathlib import Path -from typing import Literal - -import regex as re -import torch -import transformers -from packaging.version import Version -from torch import nn -from transformers import AutoModel, BatchFeature, PretrainedConfig, PreTrainedModel -from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS - -from vllm.attention import Attention, AttentionType -from vllm.compilation.decorators import support_torch_compile -from vllm.config import ( - CacheConfig, - DeviceConfig, - ModelConfig, - ParallelConfig, - VllmConfig, -) -from vllm.config.multimodal import BaseDummyOptions -from vllm.config.utils import getattr_iter -from vllm.distributed import get_pp_group, get_tp_group -from vllm.distributed.utils import get_pp_indices -from vllm.logger import init_logger -from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm -from vllm.model_executor.layers.linear import ( - ColumnParallelLinear, - ReplicatedLinear, - RowParallelLinear, -) -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, - VocabParallelEmbedding, -) -from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems -from vllm.multimodal.inputs import ( - MultiModalDataDict, - MultiModalFieldConfig, - MultiModalInputs, - MultiModalUUIDDict, - PlaceholderRange, -) -from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataItems -from vllm.multimodal.processing import BaseMultiModalProcessor, BaseProcessingInfo -from vllm.multimodal.profiling import BaseDummyInputsBuilder -from vllm.sequence import IntermediateTensors - -from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP, SupportsQuant -from .utils import ( - AutoWeightsLoader, - PPMissingLayer, - WeightsMapper, - make_empty_intermediate_tensors_factory, - maybe_prefix, -) - -logger = init_logger(__name__) - - -def get_feature_request_tip( - model: str, - trust_remote_code: bool, -) -> str: - hf_url = f"a discussion at https://huggingface.co/{model}/discussions/new" - gh_url = "an issue at https://github.com/huggingface/transformers/issues/new/choose" - url = hf_url if trust_remote_code else gh_url - prefix = f"Please open {url} to request support for this feature. " - if Path(model).exists(): - prefix = "" - doc_url = "https://docs.vllm.ai/en/latest/models/supported_models.html#writing-custom-models" - tip = f"See {doc_url} for instructions on how to add support yourself." 
- return f"{prefix}{tip}" - - -def vllm_flash_attention_forward( - # Transformers args - module: torch.nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: torch.Tensor, - # Transformers kwargs - scaling: float | None = None, - # vLLM kwargs - attention_instances: dict[Attention] | None = None, - **kwargs, -): - self_attn = attention_instances[module.layer_idx] - if scaling is not None: - self_attn.impl.scale = float(scaling) - hidden = query.shape[-2] - query, key, value = (x.transpose(1, 2) for x in (query, key, value)) - query, key, value = (x.reshape(hidden, -1) for x in (query, key, value)) - return self_attn.forward(query, key, value), None - - -ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward - - -def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module): - logger.debug("%s: %s -> %s", name, old_module, new_module) - - -def can_enable_torch_compile(vllm_config: VllmConfig) -> bool: - """ - Callable to be passed to `@support_torch_compile`'s `enable_if` argument. - - Defaults to `True` but is disabled in the following situations: - - - The model uses dynamic rope scaling. - """ - enable = True - text_config = vllm_config.model_config.hf_config.get_text_config() - # Dynamic rope scaling is not compatible with torch.compile - rope_scaling: dict = getattr(text_config, "rope_scaling", None) or {} - if rope_scaling.get("rope_type") == "dynamic": - enable = False - return enable - - -Style = Literal["colwise", "colwise_rep", "rowwise", "rowwise_rep", "replicate"] - - -def replace_linear_class( - linear: nn.Linear, - style: Style = "replicate", - quant_config: QuantizationConfig | None = None, - *, - prefix: str = "", -) -> ColumnParallelLinear | RowParallelLinear | ReplicatedLinear: - """ - Replace nn.Linear with one of vLLM's tensor parallel linear classes. - - Args: - linear: `nn.Linear` to be replaced. - style: Tensor parallel style of the new linear, e.g. "colwise". - quant_config: Quantization config for the new linear. - Returns: - The new linear. - """ - - if not isinstance(style, str): - raise ValueError(f"Unsupported parallel style type {type(style)}, expected str") - - vllm_linear_cls, vllm_linear_kwargs = { - "colwise": (ColumnParallelLinear, {}), - "colwise_rep": (ColumnParallelLinear, {"gather_output": True}), - "rowwise": (RowParallelLinear, {}), - "rowwise_rep": (RowParallelLinear, {"input_is_parallel": False}), - "replicate": (ReplicatedLinear, {}), - }.get(style, (ReplicatedLinear, {})) - - return vllm_linear_cls( - input_size=linear.in_features, - output_size=linear.out_features, - bias=linear.bias is not None, - quant_config=quant_config, - prefix=prefix, - return_bias=False, - **vllm_linear_kwargs, - ) - - -def replace_rms_norm_class(rms_norm: nn.Module, hidden_size: int) -> RMSNorm: - """Replace a Transformers RMSNorm with vLLM's RMSNorm. - - This method assumes: - - Weight is stored as `weight`. - - Epsilon is stored as `eps` or `variance_epsilon`. - - `with_scale` indicates whether the layer has a weight (Gemma3n only). - - `var_hidden_size` is only ever used for Intern vision encoder in vLLM - and Transformers doesn't appear to have the same concept. 
- """ - eps = getattr_iter(rms_norm, ("eps", "variance_epsilon"), 1e-6) - kwargs = {"hidden_size": hidden_size, "eps": eps} - # Update hidden size if weight is available - weight_meta = getattr(rms_norm, "weight", None) - if weight_meta is not None: - kwargs["hidden_size"] = weight_meta.size(0) - # Check if weight is all zeros, which indicates GemmaRMSNorm - # We must create a new instance because rms_norm is on meta - try: - with torch.device("cpu"): - weight_test = getattr(rms_norm.__class__(1), "weight", None) - except Exception: - logger.warning( - "Failed to determine if RMSNorm weight is centered on zero or one. " - "Defaulting to one." - ) - weight_test = None - if weight_test is not None and torch.all(weight_test == 0): - return GemmaRMSNorm(**kwargs) - # Otherwise assume it's a regular RMSNorm - kwargs["has_weight"] = getattr(rms_norm, "with_scale", True) - if weight_meta is not None: - kwargs["dtype"] = weight_meta.dtype - else: - # No weight, fall back to weightless RMSNorm - kwargs["has_weight"] = False - return RMSNorm(**kwargs) - - -# Copied from `accelerate` -@contextmanager -def init_on_device_without_buffers(device: torch.device): - """ - A context manager under which models are initialized with all - parameters on the specified device. However buffers are not - initialized on specified device. - - Args: - device (`torch.device`): - Device to initialize all parameters on. - """ - - old_register_parameter = nn.Module.register_parameter - - def register_empty_parameter(module, name, param): - old_register_parameter(module, name, param) - if param is not None: - param_cls = type(module._parameters[name]) - kwargs = module._parameters[name].__dict__ - kwargs["requires_grad"] = param.requires_grad - module._parameters[name] = param_cls( - module._parameters[name].to(device), **kwargs - ) - - tensor_constructors_to_patch = {} - - def patch_tensor_constructor(fn): - def wrapper(*args, **kwargs): - kwargs["device"] = device - return fn(*args, **kwargs) - - return wrapper - - try: - nn.Module.register_parameter = register_empty_parameter - for torch_function_name in tensor_constructors_to_patch: - setattr( - torch, - torch_function_name, - patch_tensor_constructor(getattr(torch, torch_function_name)), - ) - yield - finally: - nn.Module.register_parameter = old_register_parameter - for ( - torch_function_name, - old_torch_function, - ) in tensor_constructors_to_patch.items(): - setattr(torch, torch_function_name, old_torch_function) - - -class MultiModalProcessingInfo(BaseProcessingInfo): - def get_supported_mm_limits(self): - return {"image": None} - - def get_mm_max_tokens_per_item(self, seq_len, mm_counts): - return {"image": self.get_max_image_tokens()} - - def get_max_image_tokens(self) -> int: - width, height = self.get_max_image_size() - processor = self.get_hf_processor() - multimodal_config = self.ctx.model_config.multimodal_config - mm_processor_kwargs = multimodal_config.mm_processor_kwargs or {} - mm_tokens = processor._get_num_multimodal_tokens( - image_sizes=([height, width],), **mm_processor_kwargs - ) - image_tokens = mm_tokens["num_image_tokens"][0] - return image_tokens - - def get_max_image_size(self): - return 10_000, 10_000 # hardcode for arbitrary very large size - - -class MultiModalDummyInputsBuilder(BaseDummyInputsBuilder[MultiModalProcessingInfo]): - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: - num_images = mm_counts.get("image", 0) - - processor = self.info.get_hf_processor() - if "gemma3" in processor.__class__.__name__.lower(): - 
image_token = processor.boi_token - else: - image_token = getattr(processor, "image_token", "") - return image_token * num_images - - def get_dummy_mm_data( - self, - seq_len: int, - mm_counts: Mapping[str, int], - mm_options: Mapping[str, BaseDummyOptions] | None = None, - ) -> MultiModalDataDict: - num_images = mm_counts.get("image", 0) - - target_width, target_height = self.info.get_max_image_size() - - image_overrides = mm_options.get("image") if mm_options else None - - return { - "image": self._get_dummy_images( - width=target_width, - height=target_height, - num_images=num_images, - overrides=image_overrides, - ), - } - - -class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]): - def _get_prompt_updates( - self, - mm_items: MultiModalDataItems, - hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargsItems, - ): - """ - Given the original multi-modal items for this modality - and HF-processed data, output the updates to perform. - - The information returned by this method is used to update token inputs - which bypass the HF processor. It is also used to update the output of - HF processor if the HF process does not apply prompt updates to text - inputs. - - Moreover, this information is critical to determine the token positions - in order to construct :class:`~vllm-multimodal.input.PlaceholderRange` - for each multi-modal item. - """ - return None - - def _get_mm_fields_config( - self, - hf_inputs: BatchFeature, - hf_processor_mm_kwargs: Mapping[str, object], - ) -> Mapping[str, MultiModalFieldConfig]: - # HF Processors always return a mask but vLLM doesn't need it - hf_inputs.pop("attention_mask", None) - num_image_patches = hf_inputs.get("num_image_patches") - mm_fields = { - key: MultiModalFieldConfig.flat_from_sizes("image", num_image_patches) - for key in hf_inputs - } - mm_fields["image_embeds"] = MultiModalFieldConfig.flat_from_sizes( - "image", num_image_patches - ) - - # Keep these as batched, as they always have batch size as first dim - mm_fields["image_grid_thw"] = MultiModalFieldConfig.batched("image") - mm_fields["video_grid_thw"] = MultiModalFieldConfig.batched("image") - mm_fields["num_image_patches"] = MultiModalFieldConfig.batched("image") - return mm_fields - - def _get_hf_mm_data( - self, - mm_items: MultiModalDataItems, - ) -> tuple[Mapping[str, object], Mapping[str, object]]: - """ - In contrast to the base class, this method always adds - `return_mm_token_type_ids` to the processor data - """ - processor_data, passthrough_data = super()._get_hf_mm_data(mm_items) - processor_data["return_mm_token_type_ids"] = True - return processor_data, passthrough_data - - def apply( - self, - prompt: str | list[int], - mm_data: MultiModalDataDict, - hf_processor_mm_kwargs: Mapping[str, object], - tokenization_kwargs: Mapping[str, object] | None = None, - mm_uuids: MultiModalUUIDDict | None = None, - ) -> MultiModalInputs: - """ - Process multi-modal inputs to be used in vLLM. - - Apply HF Processor on prompt text and multi-modal data together, - outputting token IDs and processed tensors. 
- """ - if tokenization_kwargs is None: - tokenization_kwargs = {} - - mm_items = self._to_mm_items(mm_data) - hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) - if not isinstance(prompt, str): - # the prompt is the tokenized ids which is not supported - # by the hf_processor, which is why we would need to decode the ids - # into string - prompt = hf_processor.decode(prompt) - - # Bypass cached processor and always apply to the full set of mm inputs - # NOTE: we can't just set caching=False because base class method - # transforms outputs to `MultiModalKwargs` which is not going to - # work for Transformers. We have a lot of logic tied to - # `mm_tokens_per_modality` below - prompt_ids, processed_data, _ = self._apply_hf_processor_text_mm( - prompt_text=prompt, - mm_items=mm_items, - hf_processor_mm_kwargs=hf_processor_mm_kwargs, - tokenization_kwargs=tokenization_kwargs, - ) - - # For gemma3 we check `token_type_ids` as the key - token_type_key = ( - "mm_token_type_ids" - if "mm_token_type_ids" in processed_data - else "token_type_ids" - ) - mm_token_type_ids = processed_data.pop(token_type_key) - - # We can infer vLLM style placeholder from token type ids, if we split - # it for each input `mm_data`. - mm_positions = torch.where(mm_token_type_ids == 1)[1] - images = mm_items.get_items("image", ImageProcessorItems) - multimodal_config = self.info.ctx.model_config.multimodal_config - mm_processor_kwargs = multimodal_config.mm_processor_kwargs or {} - image_sizes = [] - for item_idx in range(len(images)): - image_size = images.get_image_size(item_idx) - image_sizes.append((image_size.height, image_size.width)) - - mm_tokens_per_modality = hf_processor._get_num_multimodal_tokens( - image_sizes=image_sizes, **mm_processor_kwargs - ) - - mm_placeholders = {} - split_sizes = mm_tokens_per_modality["num_image_tokens"] - if split_sizes: - chunked_mm_positions = torch.split(mm_positions, split_sizes) - mm_tokens = torch.tensor(prompt_ids)[mm_token_type_ids[0].bool()] - chunked_mm_tokens = torch.split(mm_tokens, split_sizes) - ranges = [ - PlaceholderRange( - offset=positions[0].item(), - length=positions.shape[0], - is_embed=(mm_tokens == hf_processor.image_token_id).bool(), - ) - for positions, mm_tokens in zip(chunked_mm_positions, chunked_mm_tokens) - ] - mm_placeholders = {"image": ranges} - - processed_data["num_image_patches"] = torch.tensor( - mm_tokens_per_modality["num_image_patches"] - ) - mm_kwargs = MultiModalKwargsItems.from_hf_inputs( - processed_data, - self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs), - ) - - # Use overrides if provided; fallback to data-dependent hashing. 
- mm_hashes = self._hash_mm_items( - mm_items, hf_processor_mm_kwargs, tokenization_kwargs, mm_uuids=mm_uuids - ) - - return MultiModalInputs( - type="multimodal", - prompt_token_ids=prompt_ids, - mm_kwargs=mm_kwargs, - mm_hashes=mm_hashes, - mm_placeholders=mm_placeholders, - ) - - -class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP): - embedding_padding_modules = ["lm_head"] - embedding_modules = ["embed_tokens"] # TODO transformers will have a util to get it - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - logger.info("Using Transformers backend.") - - self.config: PretrainedConfig = vllm_config.model_config.hf_config - self.text_config: PretrainedConfig = self.config.get_text_config() - self.cache_config: CacheConfig = vllm_config.cache_config - self.device_config: DeviceConfig = vllm_config.device_config - self.model_config: ModelConfig = vllm_config.model_config - self.parallel_config: ParallelConfig = vllm_config.parallel_config - self.quant_config: QuantizationConfig | None = vllm_config.quant_config - - self.pp_group = get_pp_group() - self.tp_group = get_tp_group() - - # Weights to skip in `self.load_weights` - self.skip_prefixes: list[str] = [] - """Skip loading weights whose qualname starts with these prefixes.""" - self.skip_substrs: list[str] = [] - """Skip loading weights whose qualname contains these substrings.""" - self.ignore_unexpected_prefixes: list[str] = [] - """Ignore unexpected weights whose qualname starts with these prefixes. - """ - self.ignore_unexpected_suffixes: list[str] = [] - """Ignore unexpected weights whose qualname ends with these suffixes.""" - - if self.quant_config: - quant_method_name = self.quant_config.get_name() - # Check for unsupported quantization methods. - if quant_method_name == "mxfp4": - raise NotImplementedError( - "Transformers backend does not support MXFP4 quantization yet." - ) - # Skip loading extra bias for GPTQ models. 
- if "gptq" in quant_method_name: - self.ignore_unexpected_suffixes.append(".bias") - - # Set correct attn and init on "meta" to delay allocating GPU tensors - self.text_config._attn_implementation = "vllm" - with init_on_device_without_buffers("meta"): - self.model: PreTrainedModel = AutoModel.from_config( - self.config, - dtype=self.model_config.dtype, - trust_remote_code=self.model_config.trust_remote_code, - ) - - # Remove layers not on this pipeline parallel rank - self.pipeline_parallel() - # Substitute remaining layers with vLLM's layers as needed - self.recursive_replace() - # Create attention instances for KV cache allocation - self.attention_instances = self.create_attention_instances() - - # Input embeddings - input_embeddings = self.model.get_input_embeddings() - if not isinstance(input_embeddings, PPMissingLayer): - # Some models use embedding scales - self.embed_scale = getattr(input_embeddings, "embed_scale", None) - names = ("embedding_size", "hidden_size") - embedding_dim = getattr_iter(self.text_config, names, None) - assert embedding_dim is not None - self.model.set_input_embeddings( - VocabParallelEmbedding( - self.text_config.vocab_size, - embedding_dim=embedding_dim, - org_num_embeddings=self.text_config.vocab_size, - quant_config=self.quant_config, - ) - ) - - # Initialize any parameters that have not had their modules replaced - self.init_parameters(self.model) - - # Pipeline parallel intermediate tensors - self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( - ["hidden_states"], self.text_config.hidden_size - ) - - def pipeline_parallel(self): - """ - Apply the model's pipeline parallelization plan. - """ - if self.pp_group.world_size <= 1: - return - - if not self.model.supports_pp_plan: - tip = get_feature_request_tip( - self.model_config.model, self.model_config.trust_remote_code - ) - raise ValueError( - f"{type(self.model)} does not support pipeline parallel. {tip}" - ) - - module_lists = [] - module_list_idx = None - pp_plan = list(self.model._pp_plan.keys()) - for i, name in enumerate(pp_plan): - if isinstance(getattr(self.model, name), nn.ModuleList): - module_lists.append(name) - module_list_idx = i - - if len(module_lists) > 1: - raise ValueError( - "Pipeline parallel of models with multiple `ModuleList`s " - "in the base model are not supported yet!" - ) - if module_list_idx is None: - raise ValueError(f"Could not find `ModuleList` in {type(self.model)}") - - # Layers before module list - for name in pp_plan[:module_list_idx]: - if self.pp_group.is_first_rank or ( - self.text_config.tie_word_embeddings and self.pp_group.is_last_rank - ): - continue - setattr(self.model, name, PPMissingLayer()) - - # Module list - start_layer, end_layer = get_pp_indices( - self.text_config.num_hidden_layers, - self.pp_group.rank_in_group, - self.pp_group.world_size, - ) - layers_name = pp_plan[module_list_idx] - layers = getattr(self.model, layers_name) - for i in range(len(layers)): - if start_layer <= i and i < end_layer: - continue - layers[i] = PPMissingLayer() - - # Layers after module list - for name in pp_plan[module_list_idx + 1 :]: - # Modules that should be on last rank - if not self.pp_group.is_last_rank: - setattr(self.model, name, PPMissingLayer()) - - def recursive_replace(self): - """Recursively replace modules in the model as needed. 
- - Currently, this replaces: - - - `nn.Linear` with vLLM's tensor parallel linear classes - - `*RMSNorm` with vLLM's `RMSNorm` - """ - tp_plan = self.model.tp_plan - - if not tp_plan and self.tp_group.world_size > 1: - tip = get_feature_request_tip( - self.model_config.model, self.model_config.trust_remote_code - ) - raise ValueError( - f"{type(self.model)} does not support tensor parallel. {tip}" - ) - - # Prefix the patterns because we always start from `self.model` - tp_plan = {maybe_prefix("model", k): v for k, v in tp_plan.items()} - - def _recursive_replace(module: nn.Module, prefix: str): - for child_name, child_module in module.named_children(): - new_module = child_module - qual_name = maybe_prefix(prefix, child_name) - if isinstance(child_module, nn.Linear): - generator = (p for p in tp_plan if re.match(p, qual_name)) - pattern = next(generator, None) - # Some weight loaders expect all linear layers to inherit - # LinearBase, so we set a default style which causes any - # unspecified layers to be replaced with ReplicatedLinear - style = tp_plan.get(pattern, "replicate") - new_module = replace_linear_class( - child_module, style, self.quant_config, prefix=qual_name - ) - elif child_module.__class__.__name__.endswith("RMSNorm"): - new_module = replace_rms_norm_class( - child_module, self.text_config.hidden_size - ) - else: - _recursive_replace(child_module, prefix=qual_name) - - if new_module is not child_module: - setattr(module, child_name, new_module) - log_replacement(qual_name, child_module, new_module) - - _recursive_replace(self.model, prefix="model") - - def create_attention_instances( - self, attn_type: AttentionType = AttentionType.DECODER - ) -> dict[int, Attention]: - """ - Create `Attention` instances to inform KV cache allocation. - """ - num_heads = self.model_config.get_num_attention_heads(self.parallel_config) - head_size = self.model_config.get_head_size() - num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config) - logits_soft_cap = getattr(self.text_config, "attn_logit_softcapping", None) - start, end = get_pp_indices( - self.text_config.num_hidden_layers, - self.pp_group.rank_in_group, - self.pp_group.world_size, - ) - - attention_instances = {} - for i in range(start, end): - # Handle interleaved sliding window attention - per_layer_sliding_window = None - if ( - hasattr(self.config, "layer_types") - and self.config.layer_types[i] == "sliding_attention" - ): - per_layer_sliding_window = self.config.sliding_window - - attention_instances[i] = Attention( - num_heads=num_heads, - head_size=head_size, - # NOTE: We use Llama scale as default, if it's set by - # Transformers, it's updated in vllm_flash_attention_forward - scale=head_size**-0.5, - num_kv_heads=num_kv_heads, - cache_config=self.cache_config, - quant_config=self.quant_config, - logits_soft_cap=logits_soft_cap, - per_layer_sliding_window=per_layer_sliding_window, - prefix=f"{i}.attn", - attn_type=attn_type, - ) - return attention_instances - - def init_parameters(self, module: nn.Module, dtype: torch.dtype | None = None): - """ - If a `parameter` is on the `meta` device, then its parent - `module` is the original module created by: - - ```python - with torch.device("meta"): - self.model: PreTrainedModel = AutoModel.from_config(...) 
- ``` - """ - - def _init_parameters(module: nn.Module, dtype: torch.dtype | None): - for name, param in module.named_parameters(recurse=False): - if param.device == torch.device("meta"): - new_param = nn.Parameter( - torch.empty_like( - param.data, - dtype=dtype or self.model_config.dtype, - device=self.device_config.device, - ) - ) - setattr(module, name, new_param) - for child in module.children(): - _init_parameters(child, dtype) - - _init_parameters(module, dtype) - - def forward( - self, - input_ids: torch.Tensor | None, - positions: torch.Tensor, - intermediate_tensors: IntermediateTensors | None = None, - inputs_embeds: torch.Tensor | None = None, - **kwargs, - ) -> torch.Tensor | IntermediateTensors: - if not self.pp_group.is_first_rank: - assert intermediate_tensors is not None - input_ids = None - inputs_embeds = intermediate_tensors["hidden_states"] - - if input_ids is not None: - input_ids = input_ids[None, ...] - if inputs_embeds is not None: - inputs_embeds = inputs_embeds[None, ...] - - if self.model_config.uses_mrope: - position_ids = positions[:, None] - else: - position_ids = positions[None, ...] - - hidden_states = self.model( - input_ids=input_ids, - inputs_embeds=inputs_embeds, - use_cache=False, - position_ids=position_ids, - attention_instances=self.attention_instances, - return_dict=False, - **kwargs, - )[0][0, ...] # we remove batch dimension for now - - if not self.pp_group.is_last_rank: - return IntermediateTensors({"hidden_states": hidden_states}) - - return hidden_states - - def load_weights( - self, - weights: Iterable[tuple[str, torch.Tensor]], - ) -> set[str]: - loader = AutoWeightsLoader( - self, - skip_prefixes=self.skip_prefixes, - skip_substrs=self.skip_substrs, - ignore_unexpected_prefixes=self.ignore_unexpected_prefixes, - ignore_unexpected_suffixes=self.ignore_unexpected_suffixes, - ) - return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) - - def check_version(self, min_version: str, feature: str): - installed = Version(transformers.__version__) - required = Version(min_version) - if installed < required: - raise ImportError( - f"Transformers backend requires transformers>={required} " - f"for {feature}, but got {installed}" - ) - - -@support_torch_compile(enable_if=can_enable_torch_compile) -class TransformersForCausalLM(TransformersBase): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__(vllm_config=vllm_config, prefix=prefix) - - # Tell `TransformersBase.load_weights` to skip - # `lm_head` if the model has tied word embeddings - if self.text_config.tie_word_embeddings: - self.skip_prefixes.append("lm_head.") - - if self.pp_group.is_last_rank: - self.unpadded_vocab_size = self.text_config.vocab_size - self.lm_head = ParallelLMHead( - self.text_config.vocab_size, - self.text_config.hidden_size, - quant_config=self.quant_config, - prefix=maybe_prefix(prefix, "lm_head"), - ) - if self.text_config.tie_word_embeddings: - self.lm_head = self.lm_head.tie_weights( - self.model.get_input_embeddings() - ) - - logit_scale = getattr(self.text_config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, self.text_config.vocab_size, logit_scale - ) - else: - self.lm_head = PPMissingLayer() - - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - inputs_embeds = self.model.get_input_embeddings()(input_ids) - if self.embed_scale is not None: - inputs_embeds *= self.embed_scale - return inputs_embeds - - def compute_logits( - self, - hidden_states: 
torch.Tensor, - ) -> torch.Tensor | None: - logits = self.logits_processor(self.lm_head, hidden_states) - return logits - - -@MULTIMODAL_REGISTRY.register_processor( - MultiModalProcessor, - info=MultiModalProcessingInfo, - dummy_inputs=MultiModalDummyInputsBuilder, -) -@support_torch_compile( - # set `positions` to last dim to support Qwen-mrope - dynamic_arg_dims={ - "input_ids": 0, - "positions": -1, - "intermediate_tensors": 0, - "inputs_embeds": 0, - }, - enable_if=can_enable_torch_compile, -) -class TransformersForMultimodalLM(TransformersForCausalLM, SupportsMultiModal): - supports_multimodal_raw_input_only = True - merge_by_field_config = True - # Backwards compatibility for prev released models. State dicts back then - # had different formats and cannot be loaded with `AutoModel` mapping as is - hf_to_vllm_mapper = WeightsMapper( - orig_to_new_prefix={ - "language_model.model": "model.language_model", - "text_model.model": "model.text_model", - "vision_tower": "model.vision_tower", - "vqmodel": "model.vqmodel", - "visual": "model.visual", - "vision_model": "model.vision_model", - "vision_embed_tokens": "model.vision_embed_tokens", - "image_newline": "model.image_newline", - "multi_modal_projector": "model.multi_modal_projector", - "text_model.lm_head": "lm_head", - "language_model.lm_head": "lm_head", - # Qwen models used "model" as the name for the language model. - # Therefore, we must map each of submodule explicitly to avoid - # conflicts with newer models that use "model.language_model". - "model.embed_tokens": "model.language_model.embed_tokens", - "model.layers": "model.language_model.layers", - "model.norm": "model.language_model.norm", - } - ) - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__(vllm_config=vllm_config, prefix=prefix) - - self.dtype = vllm_config.model_config.dtype - - def forward( - self, - input_ids: torch.Tensor | None, - positions: torch.Tensor, - intermediate_tensors: IntermediateTensors | None = None, - inputs_embeds: torch.Tensor | None = None, - **kwargs: object, - ) -> torch.Tensor | IntermediateTensors: - # Gemma3 and PaliGemma needs `token_type_ids` to work correctly - # Other models will not have `token_type_ids` in kwargs - kwargs = {k: v for k, v in kwargs.items() if k == "token_type_ids"} - model_output = super().forward( - input_ids, positions, intermediate_tensors, inputs_embeds, **kwargs - ) - return model_output - - def get_language_model(self) -> torch.nn.Module: - """`TransformersForMultimodalLM` does not contain a vLLM language model class. 
- Therefore, in order to return a language model vLLM class, we use a wrapper to - give `self` the same interface as `TransformersForCausalLM`.""" - - class LanguageModelWrapper(TransformersForCausalLM): - def __init__(self, multimodal_model): - # Don't call super().__init__() to avoid re-initialization - self.__dict__.update(multimodal_model.__dict__) - - model = getattr_iter(self.model, ("language_model", "text_model"), None) - - return LanguageModelWrapper(self) - - def get_multimodal_embeddings(self, **kwargs): - pixel_values: torch.Tensor | None = kwargs.pop("pixel_values", None) - image_embeds: torch.Tensor | None = kwargs.pop("image_embeds", None) - # Model might use `image_patches` instead of `pixel_values` - if pixel_values is None: - pixel_values = kwargs.pop("image_patches", None) - - if image_embeds is not None: - return image_embeds - - if pixel_values is None: - return None - - num_image_patches = kwargs.pop("num_image_patches") - kwargs.pop("token_type_ids", None) # used only in `forward` - if pixel_values is not None: - vision_embeddings = self.model.get_image_features(pixel_values, **kwargs) - - if isinstance(vision_embeddings, torch.Tensor): - if vision_embeddings.ndim == 2: - vision_embeddings = vision_embeddings.unsqueeze(0) - - # Embeddings have to be 2D tensors of length `num_images` - # but transformers returns concat tensors if each patch - # is of different size. We split it back to make vLLM happy - vision_embeddings = torch.split( - vision_embeddings, num_image_patches.flatten().tolist() - ) - vision_embeddings = [ - embed.flatten(start_dim=0, end_dim=-2) - for embed in vision_embeddings - ] - - return vision_embeddings - - get_input_embeddings = SupportsMultiModal.get_input_embeddings diff --git a/vllm/model_executor/models/transformers/__init__.py b/vllm/model_executor/models/transformers/__init__.py new file mode 100644 index 0000000000..365b5eb088 --- /dev/null +++ b/vllm/model_executor/models/transformers/__init__.py @@ -0,0 +1,127 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2024 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Wrapper around `transformers` models""" + +from vllm.compilation.decorators import support_torch_compile +from vllm.model_executor.models.transformers.base import Base +from vllm.model_executor.models.transformers.causal import CausalMixin +from vllm.model_executor.models.transformers.legacy import LegacyMixin +from vllm.model_executor.models.transformers.moe import MoEMixin +from vllm.model_executor.models.transformers.multimodal import ( + DYNAMIC_ARG_DIMS, + MultiModalDummyInputsBuilder, + MultiModalMixin, + MultiModalProcessingInfo, + MultiModalProcessor, +) +from vllm.model_executor.models.transformers.pooling import ( + EmbeddingMixin, + SequenceClassificationMixin, +) +from vllm.model_executor.models.transformers.utils import can_enable_torch_compile +from vllm.multimodal import MULTIMODAL_REGISTRY + + +# Text only models +@support_torch_compile(enable_if=can_enable_torch_compile) +class TransformersForCausalLM(CausalMixin, Base): ... + + +@support_torch_compile(enable_if=can_enable_torch_compile) +class TransformersMoEForCausalLM(MoEMixin, CausalMixin, Base): ... + + +# Multimodal models +@MULTIMODAL_REGISTRY.register_processor( + MultiModalProcessor, + info=MultiModalProcessingInfo, + dummy_inputs=MultiModalDummyInputsBuilder, +) +@support_torch_compile( + dynamic_arg_dims=DYNAMIC_ARG_DIMS, enable_if=can_enable_torch_compile +) +class TransformersMultiModalForCausalLM(MultiModalMixin, CausalMixin, Base): ... + + +@MULTIMODAL_REGISTRY.register_processor( + MultiModalProcessor, + info=MultiModalProcessingInfo, + dummy_inputs=MultiModalDummyInputsBuilder, +) +@support_torch_compile( + dynamic_arg_dims=DYNAMIC_ARG_DIMS, enable_if=can_enable_torch_compile +) +class TransformersMultiModalMoEForCausalLM( + MoEMixin, MultiModalMixin, CausalMixin, Base +): ... + + +# Embedding models +@support_torch_compile(enable_if=can_enable_torch_compile) +class TransformersEmbeddingModel(EmbeddingMixin, LegacyMixin, Base): ... + + +@support_torch_compile(enable_if=can_enable_torch_compile) +class TransformersMoEEmbeddingModel(EmbeddingMixin, MoEMixin, Base): ... + + +@MULTIMODAL_REGISTRY.register_processor( + MultiModalProcessor, + info=MultiModalProcessingInfo, + dummy_inputs=MultiModalDummyInputsBuilder, +) +@support_torch_compile( + dynamic_arg_dims=DYNAMIC_ARG_DIMS, enable_if=can_enable_torch_compile +) +class TransformersMultiModalEmbeddingModel(EmbeddingMixin, MultiModalMixin, Base): ... + + +# Sequence classification models +@support_torch_compile(enable_if=can_enable_torch_compile) +class TransformersForSequenceClassification( + SequenceClassificationMixin, LegacyMixin, Base +): ... + + +@support_torch_compile(enable_if=can_enable_torch_compile) +class TransformersMoEForSequenceClassification( + SequenceClassificationMixin, MoEMixin, Base +): ... + + +@MULTIMODAL_REGISTRY.register_processor( + MultiModalProcessor, + info=MultiModalProcessingInfo, + dummy_inputs=MultiModalDummyInputsBuilder, +) +@support_torch_compile( + dynamic_arg_dims=DYNAMIC_ARG_DIMS, enable_if=can_enable_torch_compile +) +class TransformersMultiModalForSequenceClassification( + SequenceClassificationMixin, MultiModalMixin, Base +): ... + + +def __getattr__(name: str): + """Handle imports of non-existent classes with a helpful error message.""" + if name not in globals(): + raise AttributeError( + "The Transformers backend does not currently have a class to handle " + f"the requested model type: {name}. 
Please open an issue at " + "https://github.com/vllm-project/vllm/issues/new" + ) + return globals()[name] diff --git a/vllm/model_executor/models/transformers/base.py b/vllm/model_executor/models/transformers/base.py new file mode 100644 index 0000000000..d940bb9739 --- /dev/null +++ b/vllm/model_executor/models/transformers/base.py @@ -0,0 +1,435 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2024 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Transformers backend base class.""" + +from collections.abc import Iterable +from typing import TYPE_CHECKING + +import regex as re +import torch +import transformers +from packaging.version import Version +from torch import nn +from transformers import AutoModel +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS + +from vllm.attention import Attention, AttentionType +from vllm.config.utils import getattr_iter +from vllm.distributed import get_pp_group, get_tp_group +from vllm.distributed.utils import get_pp_indices +from vllm.logger import init_logger +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding +from vllm.model_executor.models.interfaces import ( + SupportsLoRA, + SupportsPP, + SupportsQuant, +) +from vllm.model_executor.models.interfaces_base import VllmModel +from vllm.model_executor.models.transformers.utils import ( + get_feature_request_tip, + init_on_device_without_buffers, + log_replacement, + replace_linear_class, + replace_rms_norm_class, +) +from vllm.model_executor.models.utils import ( + AutoWeightsLoader, + PPMissingLayer, + make_empty_intermediate_tensors_factory, + maybe_prefix, +) +from vllm.sequence import IntermediateTensors + +if TYPE_CHECKING: + from transformers import PreTrainedModel + + from vllm.config import VllmConfig +else: + PreTrainedModel = object + +logger = init_logger(__name__) + + +def vllm_flash_attention_forward( + # Transformers args + module: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: torch.Tensor, + # Transformers kwargs + scaling: float | None = None, + # vLLM kwargs + attention_instances: dict[int, Attention] | None = None, + **kwargs, +): + self_attn = attention_instances[module.layer_idx] + if scaling is not None: + self_attn.impl.scale = float(scaling) + hidden = query.shape[-2] + query, key, value = (x.transpose(1, 2) for x in (query, key, value)) + query, key, value = (x.reshape(hidden, -1) for x in (query, key, value)) + return self_attn.forward(query, key, value), None + + +ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward + + +class Base(nn.Module, VllmModel, SupportsQuant, SupportsLoRA, SupportsPP): + embedding_padding_modules = ["lm_head"] + embedding_modules = ["embed_tokens"] # TODO transformers will have a util to get it + + def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""): + super().__init__() + logger.info("Using Transformers backend.") + + self.config 
= vllm_config.model_config.hf_config + self.text_config = self.config.get_text_config() + self.cache_config = vllm_config.cache_config + self.device_config = vllm_config.device_config + self.model_config = vllm_config.model_config + self.parallel_config = vllm_config.parallel_config + self.quant_config = vllm_config.quant_config + + self.pp_group = get_pp_group() + self.tp_group = get_tp_group() + + # Weights to skip in `self.load_weights` + self.skip_prefixes: list[str] = [] + """Skip loading weights whose qualname starts with these prefixes.""" + self.skip_substrs: list[str] = [] + """Skip loading weights whose qualname contains these substrings.""" + self.ignore_unexpected_prefixes: list[str] = [] + """Ignore unexpected weights whose qualname starts with these prefixes. + """ + self.ignore_unexpected_suffixes: list[str] = [] + """Ignore unexpected weights whose qualname ends with these suffixes.""" + + if self.quant_config: + quant_method_name = self.quant_config.get_name() + # Check for unsupported quantization methods. + if quant_method_name == "mxfp4": + raise NotImplementedError( + "Transformers backend does not support MXFP4 quantization yet." + ) + # Skip loading extra bias for GPTQ models. + if "gptq" in quant_method_name: + self.ignore_unexpected_suffixes.append(".bias") + + # Set correct attn and init on "meta" to delay allocating GPU tensors + self.text_config._attn_implementation = "vllm" + with init_on_device_without_buffers("meta"): + self.model: PreTrainedModel = AutoModel.from_config( + self.config, + dtype=self.model_config.dtype, + trust_remote_code=self.model_config.trust_remote_code, + ) + + # Remove layers not on this pipeline parallel rank + self.pipeline_parallel() + # Substitute remaining layers with vLLM's layers as needed + self.recursive_replace() + # Create attention instances for KV cache allocation + self.attention_instances = self.create_attention_instances() + + # Input embeddings + input_embeddings = self.model.get_input_embeddings() + if not isinstance(input_embeddings, PPMissingLayer): + # Some models scale embeddings inside the input embedding layer + self.embed_scale = getattr(input_embeddings, "embed_scale", None) + names = ("embedding_size", "hidden_size") + embedding_dim = getattr_iter(self.text_config, names, None) + assert embedding_dim is not None + self.model.set_input_embeddings( + VocabParallelEmbedding( + self.text_config.vocab_size, + embedding_dim=embedding_dim, + org_num_embeddings=self.text_config.vocab_size, + quant_config=self.quant_config, + ) + ) + + # Initialize any parameters that have not had their modules replaced + self.init_parameters(self.model) + + # Pipeline parallel intermediate tensors + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states"], self.text_config.hidden_size + ) + + def pipeline_parallel(self): + """ + Apply the model's pipeline parallelization plan. + """ + if self.pp_group.world_size <= 1: + return + + if not self.model.supports_pp_plan: + tip = get_feature_request_tip( + self.model_config.model, self.model_config.trust_remote_code + ) + raise ValueError( + f"{type(self.model)} does not support pipeline parallel. 
{tip}" + ) + + module_lists = [] + module_list_idx = None + pp_plan = list(self.model._pp_plan.keys()) + for i, name in enumerate(pp_plan): + if isinstance(getattr(self.model, name), nn.ModuleList): + module_lists.append(name) + module_list_idx = i + + if len(module_lists) > 1: + raise ValueError( + "Pipeline parallel of models with multiple `ModuleList`s " + "in the base model are not supported yet!" + ) + if module_list_idx is None: + raise ValueError(f"Could not find `ModuleList` in {type(self.model)}") + + # Layers before module list + for name in pp_plan[:module_list_idx]: + if self.pp_group.is_first_rank or ( + self.text_config.tie_word_embeddings and self.pp_group.is_last_rank + ): + continue + setattr(self.model, name, PPMissingLayer()) + + # Module list + start_layer, end_layer = get_pp_indices( + self.text_config.num_hidden_layers, + self.pp_group.rank_in_group, + self.pp_group.world_size, + ) + layers_name = pp_plan[module_list_idx] + layers = getattr(self.model, layers_name) + for i in range(len(layers)): + if start_layer <= i and i < end_layer: + continue + layers[i] = PPMissingLayer() + + # Layers after module list + for name in pp_plan[module_list_idx + 1 :]: + # Modules that should be on last rank + if not self.pp_group.is_last_rank: + setattr(self.model, name, PPMissingLayer()) + + def recursive_replace(self): + """Recursively replace modules in the model as needed. + + Currently, this replaces: + + - `nn.Linear` with vLLM's tensor parallel linear classes + - `*RMSNorm` with vLLM's `RMSNorm` + """ + tp_plan = self.model.tp_plan + + if not tp_plan and self.tp_group.world_size > 1: + tip = get_feature_request_tip( + self.model_config.model, self.model_config.trust_remote_code + ) + raise ValueError( + f"{type(self.model)} does not support tensor parallel. {tip}" + ) + + # Prefix the patterns because we always start from `self.model` + tp_plan = {maybe_prefix("model", k): v for k, v in tp_plan.items()} + + def _recursive_replace(module: nn.Module, prefix: str): + for child_name, child_module in module.named_children(): + new_module = child_module + qual_name = maybe_prefix(prefix, child_name) + if isinstance(child_module, nn.Linear): + generator = (p for p in tp_plan if re.match(p, qual_name)) + pattern = next(generator, None) + # Some weight loaders expect all linear layers to inherit + # LinearBase, so we set a default style which causes any + # unspecified layers to be replaced with ReplicatedLinear + style = tp_plan.get(pattern, "replicate") + new_module = replace_linear_class( + child_module, style, self.quant_config, prefix=qual_name + ) + elif child_module.__class__.__name__.endswith("RMSNorm"): + new_module = replace_rms_norm_class( + child_module, self.text_config.hidden_size + ) + else: + _recursive_replace(child_module, prefix=qual_name) + + if new_module is not child_module: + setattr(module, child_name, new_module) + log_replacement(qual_name, child_module, new_module) + + _recursive_replace(self.model, prefix="model") + + def create_attention_instances(self) -> dict[int, Attention]: + """ + Create `Attention` instances to inform KV cache allocation. 
+ """ + text_config = self.text_config + + num_heads = self.model_config.get_num_attention_heads(self.parallel_config) + head_size = self.model_config.get_head_size() + num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config) + logits_soft_cap = getattr(text_config, "attn_logit_softcapping", None) + + # In encoder models, the attention layers will have `is_causal=False` + is_encoder = lambda module: not getattr(module, "is_causal", True) + has_encoder = lambda model: any(is_encoder(m) for m in model.modules()) + is_multimodal = lambda config: config != config.get_text_config() + # vLLM does not support encoder-decoder models, so if any encoder layer is + # found in a text only model, we assume the whole model is an encoder model + if has_encoder(self.model) and not is_multimodal(self.config): + self.check_version("4.57.0.dev0", "encoder models support") + attn_type = AttentionType.ENCODER_ONLY + else: + attn_type = AttentionType.DECODER + + pp_rank = self.pp_group.rank_in_group + pp_size = self.pp_group.world_size + start, end = get_pp_indices(text_config.num_hidden_layers, pp_rank, pp_size) + + attention_instances = {} + for i in range(start, end): + # Handle interleaved sliding window attention + per_layer_sliding_window = None + if ( + hasattr(self.config, "layer_types") + and self.config.layer_types[i] == "sliding_attention" + ): + per_layer_sliding_window = self.config.sliding_window + + attention_instances[i] = Attention( + num_heads=num_heads, + head_size=head_size, + # NOTE: We use Llama scale as default, if it's set by + # Transformers, it's updated in vllm_flash_attention_forward + scale=head_size**-0.5, + num_kv_heads=num_kv_heads, + cache_config=self.cache_config, + quant_config=self.quant_config, + logits_soft_cap=logits_soft_cap, + per_layer_sliding_window=per_layer_sliding_window, + prefix=f"{i}.attn", + attn_type=attn_type, + ) + return attention_instances + + def init_parameters(self, module: nn.Module, dtype: torch.dtype | None = None): + """ + If a `parameter` is on the `meta` device, then its parent + `module` is the original module created by: + + ```python + with torch.device("meta"): + self.model: "PreTrainedModel" = AutoModel.from_config(...) + ``` + """ + + def _init_parameters(module: nn.Module, dtype: torch.dtype | None): + for name, param in module.named_parameters(recurse=False): + if param.device == torch.device("meta"): + new_param = nn.Parameter( + torch.empty_like( + param.data, + dtype=dtype or self.model_config.dtype, + device=self.device_config.device, + ) + ) + setattr(module, name, new_param) + for child in module.children(): + _init_parameters(child, dtype) + + _init_parameters(module, dtype) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + inputs_embeds = self.model.get_input_embeddings()(input_ids) + if self.embed_scale is not None: + inputs_embeds *= self.embed_scale + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor | None, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + **kwargs, + ) -> torch.Tensor | IntermediateTensors: + if not self.pp_group.is_first_rank: + assert intermediate_tensors is not None + input_ids = None + inputs_embeds = intermediate_tensors["hidden_states"] + + if input_ids is not None: + input_ids = input_ids[None, ...] + if inputs_embeds is not None: + inputs_embeds = inputs_embeds[None, ...] 
+ + # If the model scales embeddings inside the input embedding layer we must + # ensure they are scaled here since VocabParallelEmbedding will not do it + if ( + self.embed_scale is not None + and input_ids is not None + and inputs_embeds is None + ): + inputs_embeds = self.get_input_embeddings(input_ids) + input_ids = None + + if self.model_config.uses_mrope: + position_ids = positions[:, None] + else: + position_ids = positions[None, ...] + + hidden_states = self.model( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + use_cache=False, + position_ids=position_ids, + attention_instances=self.attention_instances, + return_dict=False, + **kwargs, + )[0][0, ...] # we remove batch dimension for now + + if not self.pp_group.is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) + + return hidden_states + + def load_weights( + self, + weights: Iterable[tuple[str, torch.Tensor]], + ) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=self.skip_prefixes, + skip_substrs=self.skip_substrs, + ignore_unexpected_prefixes=self.ignore_unexpected_prefixes, + ignore_unexpected_suffixes=self.ignore_unexpected_suffixes, + ) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) + + @staticmethod + def check_version(min_version: str, feature: str): + installed = Version(transformers.__version__) + required = Version(min_version) + if installed < required: + raise ImportError( + f"Transformers backend requires transformers>={required} " + f"for {feature}, but got {installed}" + ) diff --git a/vllm/model_executor/models/transformers/causal.py b/vllm/model_executor/models/transformers/causal.py new file mode 100644 index 0000000000..7f7b15a567 --- /dev/null +++ b/vllm/model_executor/models/transformers/causal.py @@ -0,0 +1,66 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2024 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
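Note: the mixins added in this file (and in moe.py, multimodal.py and pooling.py below) rely on cooperative multiple inheritance. Calling `super(SomeInterface, self).__init__(...)` resumes the `__init__` chain at the class *after* that interface in the MRO, so the interface's own `__init__` is skipped while the backend `Base.__init__` still runs. A minimal standalone sketch of the pattern, using illustrative class names rather than the real vLLM classes:

```python
class Interface:
    """Stands in for an interface class such as VllmModelForTextGeneration."""

    def __init__(self, *args, **kwargs):
        raise RuntimeError("interface __init__ should be skipped")


class Base:
    """Stands in for the Transformers backend Base class."""

    def __init__(self, *, vllm_config, prefix=""):
        self.vllm_config = vllm_config
        self.prefix = prefix


class Mixin(Interface):
    def __init__(self, *, vllm_config, prefix=""):
        # Resume the cooperative __init__ chain *after* Interface in the MRO,
        # so Interface.__init__ is never called.
        super(Interface, self).__init__(vllm_config=vllm_config, prefix=prefix)


class Model(Mixin, Base):
    pass


model = Model(vllm_config="cfg")  # MRO: Model -> Mixin -> Interface -> Base
assert model.vllm_config == "cfg"
```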
+"""Transformers backend mixin for causal language models.""" + +from typing import TYPE_CHECKING + +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.models.interfaces_base import VllmModelForTextGeneration +from vllm.model_executor.models.utils import PPMissingLayer, maybe_prefix + +if TYPE_CHECKING: + import torch + + from vllm.config import VllmConfig + + +class CausalMixin(VllmModelForTextGeneration): + def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""): + # Skip VllmModelForTextGeneration.__init__ and call the next class in MRO + super(VllmModelForTextGeneration, self).__init__( + vllm_config=vllm_config, prefix=prefix + ) + + # Tell `Base.load_weights` to skip + # `lm_head` if the model has tied word embeddings + if self.text_config.tie_word_embeddings: + self.skip_prefixes.append("lm_head.") + + if self.pp_group.is_last_rank: + self.unpadded_vocab_size = self.text_config.vocab_size + self.lm_head = ParallelLMHead( + self.text_config.vocab_size, + self.text_config.hidden_size, + quant_config=self.quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) + if self.text_config.tie_word_embeddings: + self.lm_head = self.lm_head.tie_weights( + self.model.get_input_embeddings() + ) + + logit_scale = getattr(self.text_config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, self.text_config.vocab_size, logit_scale + ) + else: + self.lm_head = PPMissingLayer() + + def compute_logits(self, hidden_states: "torch.Tensor") -> "torch.Tensor | None": + logits = self.logits_processor(self.lm_head, hidden_states) + return logits diff --git a/vllm/model_executor/models/transformers/legacy.py b/vllm/model_executor/models/transformers/legacy.py new file mode 100644 index 0000000000..5d4dcf0556 --- /dev/null +++ b/vllm/model_executor/models/transformers/legacy.py @@ -0,0 +1,97 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2024 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Transformers backend mixin for legacy models.""" + +from typing import TYPE_CHECKING + +import torch + +from vllm.model_executor.models.utils import WeightsMapper +from vllm.sequence import IntermediateTensors + +if TYPE_CHECKING: + from vllm.config import VllmConfig + + +class LegacyMixin: + hf_to_vllm_mapper = WeightsMapper( + # These are applied in order, so the order matters! 
+ orig_to_new_prefix={ + # Handle BERT-like models + "roberta": "model", + "bert": "model", + # Add `model.` prefix for base model checkpoints + "": "model.", + # Remove `model.` prefix if it was already there + "model.model.": "model.", + # Classifier/scoring heads will be adjacent to `model` + "model.score": "classifier", + "model.classifier": "classifier", + }, + orig_to_new_suffix={ + # Replace legacy suffixes used for norms + ".gamma": ".weight", + ".beta": ".bias", + }, + ) + + def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + + # Skip unsupported/unwanted output embeddings layers + self.skip_prefixes.extend( + [ + "model.lm_head.", + "model.predictions.", + "model.qa_outputs.", + "model.embeddings_project.", + "model.discriminator_predictions.", + ] + ) + + # Some encoder models have the position_ids buffer in the checkpoint. + # vLLM will always pass position_ids as an argument, so we skip loading + # the buffer if it exists + self.skip_substrs.append("position_ids") + + # Some encoder models have the bias of the final classifier layer + # in the checkpoint. vLLM does not use this bias, so we skip loading + # it if it exists + self.skip_substrs.append("score.bias") + + # roberta-like models an extra padding in positions. + # FIXME(Isotr0py): This is quite hacky for roberta edge case, + # we should find a better way to handle this. + self.is_roberta = "roberta" in self.text_config.model_type + self.padding_idx = self.text_config.pad_token_id + + def forward( + self, + input_ids: torch.Tensor | None, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: + if self.is_roberta: + # RoBERTa-specific positions padding + positions += self.padding_idx + 1 + return super().forward( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + ) diff --git a/vllm/model_executor/models/transformers_moe.py b/vllm/model_executor/models/transformers/moe.py similarity index 90% rename from vllm/model_executor/models/transformers_moe.py rename to vllm/model_executor/models/transformers/moe.py index 5267e44790..ed56fd7399 100644 --- a/vllm/model_executor/models/transformers_moe.py +++ b/vllm/model_executor/models/transformers/moe.py @@ -14,31 +14,27 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
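Note on the position offset in `LegacyMixin.forward` above: HF RoBERTa-style models reserve position ids `[0, padding_idx]` for padding and start real positions at `padding_idx + 1`, while vLLM always passes 0-based positions. A minimal sketch of the shift; the value assumes roberta-base, whose `pad_token_id` is 1:

```python
import torch

padding_idx = 1                 # roberta-base pad_token_id
positions = torch.arange(4)     # positions as vLLM passes them: 0, 1, 2, 3
hf_positions = positions + padding_idx + 1
print(hf_positions.tolist())    # [2, 3, 4, 5], matching HF RoBERTa's
                                # create_position_ids_from_input_ids
```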
-"""Wrapper around `transformers` MoE models.""" +"""Transformers backend mixin for Mixture of Experts (MoE) models.""" -from typing import Any +from typing import TYPE_CHECKING, Any import torch import torch.nn as nn -from vllm.compilation.decorators import support_torch_compile from vllm.config.utils import getattr_iter from vllm.distributed import get_dp_group, get_ep_group from vllm.forward_context import ForwardContext, get_forward_context from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.models.interfaces import MixtureOfExperts +from vllm.model_executor.models.utils import maybe_prefix from vllm.platforms import current_platform from vllm.utils import direct_register_custom_op -from .interfaces import MixtureOfExperts, SupportsMultiModal -from .transformers import ( - TransformersBase, - TransformersForCausalLM, - TransformersForMultimodalLM, - can_enable_torch_compile, - log_replacement, -) -from .utils import maybe_prefix +from .utils import log_replacement + +if TYPE_CHECKING: + from vllm.config import VllmConfig @CustomOp.register("transformers_fused_moe") @@ -117,11 +113,11 @@ direct_register_custom_op( ) -class TransformersMoEBase(TransformersBase, MixtureOfExperts): - def __init__(self, *, vllm_config, prefix=""): +class MoEMixin(MixtureOfExperts): + def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""): self.check_version("4.57.0.dev0", "MoE models support") - self.ep_group = get_ep_group() - super().__init__(vllm_config=vllm_config, prefix=prefix) + # Skip MixtureOfExperts.__init__ and call the next class in MRO + super(MixtureOfExperts, self).__init__(vllm_config=vllm_config, prefix=prefix) def set_eplb_state( self, @@ -242,7 +238,7 @@ class TransformersMoEBase(TransformersBase, MixtureOfExperts): num_redundant_experts = self.parallel_config.eplb_config.num_redundant_experts # MixtureOfExperts mixin settings - ep_size = self.ep_group.world_size + ep_size = get_ep_group().world_size self.mlp_layers = [] # Used for MixtureOfExperts methods self.expert_weights = [] @@ -316,24 +312,5 @@ class TransformersMoEBase(TransformersBase, MixtureOfExperts): _recursive_replace(child_module, prefix=qual_name) _recursive_replace(self.model, prefix="model") - # Continue with the replacement of layers in TransformersBase + # Continue with the replacement of layers in Base super().recursive_replace() - - -@support_torch_compile(enable_if=can_enable_torch_compile) -class TransformersMoEForCausalLM(TransformersMoEBase, TransformersForCausalLM): - pass - - -@support_torch_compile( - # set `positions` to last dim to support Qwen-mrope - dynamic_arg_dims={ - "input_ids": 0, - "positions": -1, - "intermediate_tensors": 0, - "inputs_embeds": 0, - }, - enable_if=can_enable_torch_compile, -) -class TransformersMoEForMultimodalLM(TransformersMoEBase, TransformersForMultimodalLM): - get_input_embeddings = SupportsMultiModal.get_input_embeddings diff --git a/vllm/model_executor/models/transformers/multimodal.py b/vllm/model_executor/models/transformers/multimodal.py new file mode 100644 index 0000000000..10abd86595 --- /dev/null +++ b/vllm/model_executor/models/transformers/multimodal.py @@ -0,0 +1,396 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2024 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Transformers backend mixin for multi-modal models.""" + +from collections.abc import Mapping +from typing import TYPE_CHECKING + +import torch + +from vllm.config.utils import getattr_iter +from vllm.model_executor.models.interfaces import SupportsMRoPE, SupportsMultiModal +from vllm.model_executor.models.utils import WeightsMapper +from vllm.multimodal import MultiModalKwargsItems +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalInputs, + MultiModalUUIDDict, + PlaceholderRange, +) +from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataItems +from vllm.multimodal.processing import BaseMultiModalProcessor, BaseProcessingInfo +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.sequence import IntermediateTensors + +if TYPE_CHECKING: + from transformers import BatchFeature, PretrainedConfig + + from vllm.config import VllmConfig + from vllm.config.multimodal import BaseDummyOptions + +DYNAMIC_ARG_DIMS = { + "input_ids": 0, + # set `positions` to last dim to support Qwen-mrope + "positions": -1, + "intermediate_tensors": 0, + "inputs_embeds": 0, +} + + +class MultiModalProcessingInfo(BaseProcessingInfo): + def get_supported_mm_limits(self): + return {"image": None} + + def get_mm_max_tokens_per_item(self, seq_len, mm_counts): + return {"image": self.get_max_image_tokens()} + + def get_max_image_tokens(self) -> int: + width, height = self.get_max_image_size() + processor = self.get_hf_processor() + multimodal_config = self.ctx.model_config.multimodal_config + mm_processor_kwargs = multimodal_config.mm_processor_kwargs or {} + mm_tokens = processor._get_num_multimodal_tokens( + image_sizes=([height, width],), **mm_processor_kwargs + ) + image_tokens = mm_tokens["num_image_tokens"][0] + return image_tokens + + def get_max_image_size(self): + return 10_000, 10_000 # hardcode for arbitrary very large size + + +class MultiModalDummyInputsBuilder(BaseDummyInputsBuilder[MultiModalProcessingInfo]): + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + + processor = self.info.get_hf_processor() + if "gemma3" in processor.__class__.__name__.lower(): + image_token = processor.boi_token + else: + image_token = getattr(processor, "image_token", "") + return image_token * num_images + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + mm_options: Mapping[str, "BaseDummyOptions"] | None = None, + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + + target_width, target_height = self.info.get_max_image_size() + + image_overrides = mm_options.get("image") if mm_options else None + + return { + "image": self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides, + ), + } + + +class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]): + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargsItems, + ): + """ + Given the original multi-modal items 
for this modality + and HF-processed data, output the updates to perform. + + The information returned by this method is used to update token inputs + which bypass the HF processor. It is also used to update the output of + HF processor if the HF process does not apply prompt updates to text + inputs. + + Moreover, this information is critical to determine the token positions + in order to construct :class:`~vllm-multimodal.input.PlaceholderRange` + for each multi-modal item. + """ + return None + + def _get_mm_fields_config( + self, + hf_inputs: "BatchFeature", + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + # HF Processors always return a mask but vLLM doesn't need it + hf_inputs.pop("attention_mask", None) + num_image_patches = hf_inputs.get("num_image_patches") + mm_fields = { + key: MultiModalFieldConfig.flat_from_sizes("image", num_image_patches) + for key in hf_inputs + } + mm_fields["image_embeds"] = MultiModalFieldConfig.flat_from_sizes( + "image", num_image_patches + ) + + # Keep these as batched, as they always have batch size as first dim + mm_fields["image_grid_thw"] = MultiModalFieldConfig.batched("image") + mm_fields["video_grid_thw"] = MultiModalFieldConfig.batched("image") + mm_fields["num_image_patches"] = MultiModalFieldConfig.batched("image") + return mm_fields + + def _get_hf_mm_data( + self, + mm_items: MultiModalDataItems, + ) -> tuple[Mapping[str, object], Mapping[str, object]]: + """ + In contrast to the base class, this method always adds + `return_mm_token_type_ids` to the processor data + """ + processor_data, passthrough_data = super()._get_hf_mm_data(mm_items) + processor_data["return_mm_token_type_ids"] = True + return processor_data, passthrough_data + + def apply( + self, + prompt: str | list[int], + mm_data: MultiModalDataDict, + hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object] | None = None, + mm_uuids: MultiModalUUIDDict | None = None, + ) -> MultiModalInputs: + """ + Process multi-modal inputs to be used in vLLM. + + Apply HF Processor on prompt text and multi-modal data together, + outputting token IDs and processed tensors. + """ + if tokenization_kwargs is None: + tokenization_kwargs = {} + + mm_items = self._to_mm_items(mm_data) + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + if not isinstance(prompt, str): + # the prompt is the tokenized ids which is not supported + # by the hf_processor, which is why we would need to decode the ids + # into string + prompt = hf_processor.decode(prompt) + + # Bypass cached processor and always apply to the full set of mm inputs + # NOTE: we can't just set caching=False because base class method + # transforms outputs to `MultiModalKwargs` which is not going to + # work for Transformers. We have a lot of logic tied to + # `mm_tokens_per_modality` below + prompt_ids, processed_data, _ = self._apply_hf_processor_text_mm( + prompt_text=prompt, + mm_items=mm_items, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + tokenization_kwargs=tokenization_kwargs, + ) + + # For gemma3 we check `token_type_ids` as the key + token_type_key = ( + "mm_token_type_ids" + if "mm_token_type_ids" in processed_data + else "token_type_ids" + ) + mm_token_type_ids = processed_data.pop(token_type_key) + + # We can infer vLLM style placeholder from token type ids, if we split + # it for each input `mm_data`. 
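A toy illustration of the placeholder inference described in the comment above (made-up tensors, not vLLM code): positions where the token-type mask equals 1 are multi-modal tokens, and splitting them by the per-image token counts yields one contiguous range per image:

```python
import torch

mm_token_type_ids = torch.tensor([[0, 0, 1, 1, 1, 0, 1, 1, 0]])  # 1 = image token
split_sizes = [3, 2]                            # image tokens per input image
mm_positions = torch.where(mm_token_type_ids == 1)[1]  # tensor([2, 3, 4, 6, 7])
for positions in torch.split(mm_positions, split_sizes):
    print({"offset": positions[0].item(), "length": positions.shape[0]})
# {'offset': 2, 'length': 3}
# {'offset': 6, 'length': 2}
```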
+ mm_positions = torch.where(mm_token_type_ids == 1)[1] + images = mm_items.get_items("image", ImageProcessorItems) + multimodal_config = self.info.ctx.model_config.multimodal_config + mm_processor_kwargs = multimodal_config.mm_processor_kwargs or {} + image_sizes = [] + for item_idx in range(len(images)): + image_size = images.get_image_size(item_idx) + image_sizes.append((image_size.height, image_size.width)) + + mm_tokens_per_modality = hf_processor._get_num_multimodal_tokens( + image_sizes=image_sizes, **mm_processor_kwargs + ) + + mm_placeholders = {} + split_sizes = mm_tokens_per_modality["num_image_tokens"] + if split_sizes: + chunked_mm_positions = torch.split(mm_positions, split_sizes) + mm_tokens = torch.tensor(prompt_ids)[mm_token_type_ids[0].bool()] + chunked_mm_tokens = torch.split(mm_tokens, split_sizes) + ranges = [ + PlaceholderRange( + offset=positions[0].item(), + length=positions.shape[0], + is_embed=(mm_tokens == hf_processor.image_token_id).bool(), + ) + for positions, mm_tokens in zip(chunked_mm_positions, chunked_mm_tokens) + ] + mm_placeholders = {"image": ranges} + + processed_data["num_image_patches"] = torch.tensor( + mm_tokens_per_modality["num_image_patches"] + ) + mm_kwargs = MultiModalKwargsItems.from_hf_inputs( + processed_data, + self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs), + ) + + # Use overrides if provided; fallback to data-dependent hashing. + mm_hashes = self._hash_mm_items( + mm_items, hf_processor_mm_kwargs, tokenization_kwargs, mm_uuids=mm_uuids + ) + + return MultiModalInputs( + type="multimodal", + prompt_token_ids=prompt_ids, + mm_kwargs=mm_kwargs, + mm_hashes=mm_hashes, + mm_placeholders=mm_placeholders, + ) + + +class MultiModalMixin(SupportsMultiModal, SupportsMRoPE): + supports_multimodal_raw_input_only = True + merge_by_field_config = True + # Backwards compatibility for prev released models. State dicts back then + # had different formats and cannot be loaded with `AutoModel` mapping as is + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "language_model.model": "model.language_model", + "text_model.model": "model.text_model", + "vision_tower": "model.vision_tower", + "vqmodel": "model.vqmodel", + "visual": "model.visual", + "vision_model": "model.vision_model", + "vision_embed_tokens": "model.vision_embed_tokens", + "image_newline": "model.image_newline", + "multi_modal_projector": "model.multi_modal_projector", + "text_model.lm_head": "lm_head", + "language_model.lm_head": "lm_head", + # Qwen models used "model" as the name for the language model. + # Therefore, we must map each of submodule explicitly to avoid + # conflicts with newer models that use "model.language_model". 
+ "model.embed_tokens": "model.language_model.embed_tokens", + "model.layers": "model.language_model.layers", + "model.norm": "model.language_model.norm", + } + ) + + def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""): + # Skip SupportsMRoPE.__init__ and call the next class in MRO + super(SupportsMRoPE, self).__init__(vllm_config=vllm_config, prefix=prefix) + + def forward( + self, + input_ids: torch.Tensor | None, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + **kwargs: object, + ) -> torch.Tensor | IntermediateTensors: + # Gemma3 and PaliGemma needs `token_type_ids` to work correctly + # Other models will not have `token_type_ids` in kwargs + kwargs = {k: v for k, v in kwargs.items() if k == "token_type_ids"} + model_output = super().forward( + input_ids, positions, intermediate_tensors, inputs_embeds, **kwargs + ) + return model_output + + def get_language_model(self) -> torch.nn.Module: + """Transformers backend multimodal classes do not contain a separate vLLM + language model class. Therefore, in order to return a language model vLLM class, + we use a wrapper to give `self` the same interface as a text model.""" + + # Exclude self and object + bases = self.__class__.mro()[1:-1] + # Keep only classes defined in `vllm.model_executor.models.transformers` + bases = [b for b in bases if ".transformers." in b.__module__] + # Exclude MultiModalMixin itself + bases = [b for b in bases if b is not MultiModalMixin] + + class LanguageModel(*bases): + def __init__(self, multimodal_model): + # Don't call super().__init__() to avoid re-initialization + self.__dict__.update(multimodal_model.__dict__) + + model = getattr_iter(self.model, ("language_model", "text_model"), None) + + return LanguageModel(self) + + def get_multimodal_embeddings(self, **kwargs): + pixel_values: torch.Tensor | None = kwargs.pop("pixel_values", None) + image_embeds: torch.Tensor | None = kwargs.pop("image_embeds", None) + # Model might use `image_patches` instead of `pixel_values` + if pixel_values is None: + pixel_values = kwargs.pop("image_patches", None) + + if image_embeds is not None: + return image_embeds + + if pixel_values is None: + return None + + num_image_patches = kwargs.pop("num_image_patches") + kwargs.pop("token_type_ids", None) # used only in `forward` + if pixel_values is not None: + vision_embeddings = self.model.get_image_features(pixel_values, **kwargs) + + if isinstance(vision_embeddings, torch.Tensor): + if vision_embeddings.ndim == 2: + vision_embeddings = vision_embeddings.unsqueeze(0) + + # Embeddings have to be 2D tensors of length `num_images` + # but transformers returns concat tensors if each patch + # is of different size. 
We split it back to make vLLM happy + vision_embeddings = torch.split( + vision_embeddings, num_image_patches.flatten().tolist() + ) + vision_embeddings = [ + embed.flatten(start_dim=0, end_dim=-2) + for embed in vision_embeddings + ] + + return vision_embeddings + + def get_mrope_input_positions( + self, + input_tokens: list[int], + hf_config: "PretrainedConfig", + image_grid_thw: list[list[int]] | torch.Tensor | None, + video_grid_thw: list[list[int]] | torch.Tensor | None, + second_per_grid_ts: list[float] | None = None, + context_len: int = 0, + seq_len: int | None = None, + audio_feature_lengths: torch.Tensor | None = None, + use_audio_in_video: bool = False, + ) -> tuple[torch.Tensor, int]: + if any((second_per_grid_ts, audio_feature_lengths, use_audio_in_video)): + raise NotImplementedError("Transformers backend only supports images.") + + if isinstance(image_grid_thw, list): + image_grid_thw = torch.tensor(image_grid_thw) + if isinstance(video_grid_thw, list): + video_grid_thw = torch.tensor(video_grid_thw) + + mrope_positions, mrope_position_delta = self.model.get_rope_index( + input_ids=torch.tensor(input_tokens).unsqueeze(0), + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + ) + + mrope_positions = mrope_positions[:, 0, context_len:seq_len] + mrope_position_delta = mrope_position_delta[0].item() + + return mrope_positions, mrope_position_delta diff --git a/vllm/model_executor/models/transformers/pooling.py b/vllm/model_executor/models/transformers/pooling.py new file mode 100644 index 0000000000..32aec49066 --- /dev/null +++ b/vllm/model_executor/models/transformers/pooling.py @@ -0,0 +1,118 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2024 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
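Note: `SequenceClassificationMixin` below instantiates the `ForSequenceClassification` model under the "meta" device purely to inspect its structure (e.g. the classifier head) without allocating any real memory. A self-contained sketch of that trick using toy modules, not the real classes:

```python
import torch
import torch.nn as nn

with torch.device("meta"):
    # Stand-in for AutoModelForSequenceClassification.from_config(...)
    probe = nn.Sequential(nn.Linear(4096, 4096), nn.Linear(4096, 2))

classifier = probe[-1]
print(type(classifier).__name__)  # Linear: the structure is fully inspectable
print(classifier.weight.shape)    # torch.Size([2, 4096])
print(classifier.weight.device)   # meta: no real storage was allocated
```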
+"""Transformers backend mixins for pooling models.""" + +from typing import TYPE_CHECKING + +import torch +from transformers import AutoModelForSequenceClassification + +from vllm.model_executor.layers.pooler import ( + ClassifierPooler, + CLSPool, + DispatchPooler, + Pooler, +) +from vllm.model_executor.models.interfaces import SupportsCrossEncoding +from vllm.model_executor.models.interfaces_base import VllmModelForPooling + +if TYPE_CHECKING: + from vllm.config import VllmConfig + + +class EmbeddingMixin(VllmModelForPooling): + default_pooling_type = "CLS" + + def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""): + # Skip VllmModelForPooling.__init__ and call the next class in MRO + super(VllmModelForPooling, self).__init__( + vllm_config=vllm_config, prefix=prefix + ) + + pooler_config = vllm_config.model_config.pooler_config + assert pooler_config is not None + + self.pooler = DispatchPooler( + { + "token_embed": Pooler.for_token_embed(pooler_config), + "embed": Pooler.for_embed(pooler_config), + } + ) + + +class SequenceClassificationMixin(SupportsCrossEncoding, VllmModelForPooling): + default_pooling_type = "CLS" + + def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""): + # Skip VllmModelForPooling.__init__ and call the next class in MRO + super(VllmModelForPooling, self).__init__( + vllm_config=vllm_config, prefix=prefix + ) + + pooler_config = vllm_config.model_config.pooler_config + assert pooler_config is not None + + # Certain information about the the model and classifier can only be + # inferred from the `ForSequenceClassification` class. Therefore, we + # instantiate it on the "meta" device to avoid allocating GPU memory. + with torch.device("meta"): + seq_cls_model = AutoModelForSequenceClassification.from_config( + self.config, + dtype=self.model_config.dtype, + trust_remote_code=self.model_config.trust_remote_code, + ) + + # When used for sequence classification, some models have their + # pooling layers removed. Make sure this is reflected in vLLM. + for module in seq_cls_model.modules(): + if hasattr(module, "pooler") and module.pooler is None: + self.model.pooler = None + break + if self.model.pooler is not None: + raise ValueError( + "Sequence classification models with pooling layers are not " + "supported yet in the Transformers backend." + ) + + # Unlike `lm_head`, `classifier` is not always `nn.Linear`. + self.classifier = seq_cls_model.classifier + self.init_parameters(self.classifier, dtype=self.model_config.head_dtype) + + class ClassifierWithReshape(self.classifier.__class__): + """CLSPool has already been applied in `pooling`. 
+ Add dim to match expected input shape of `classifier.forward`.""" + + def forward(self, *args, **kwargs): + if len(args) > 0: + args = (args[0].unsqueeze(1), *args[1:]) + return super().forward(*args, **kwargs) + + self.classifier.__class__ = ClassifierWithReshape + + self.pooler = DispatchPooler( + { + "token_classify": Pooler.for_token_classify( + pooler_config, classifier=self.classifier + ), + "classify": ClassifierPooler( + pooling=CLSPool(), classifier=self.classifier, act_fn="classify" + ), + "score": ClassifierPooler( + pooling=CLSPool(), classifier=self.classifier, act_fn="score" + ), + } + ) diff --git a/vllm/model_executor/models/transformers/utils.py b/vllm/model_executor/models/transformers/utils.py new file mode 100644 index 0000000000..267a6e06e6 --- /dev/null +++ b/vllm/model_executor/models/transformers/utils.py @@ -0,0 +1,207 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2024 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Transformers backend utilities.""" + +from contextlib import contextmanager +from pathlib import Path +from typing import TYPE_CHECKING, Literal + +import torch +from torch import nn + +from vllm.config.utils import getattr_iter +from vllm.logger import init_logger +from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) + +if TYPE_CHECKING: + from vllm.config import VllmConfig + from vllm.model_executor.layers.quantization import QuantizationConfig + + +logger = init_logger(__name__) + + +# Copied from `accelerate` +@contextmanager +def init_on_device_without_buffers(device: torch.device): + """ + A context manager under which models are initialized with all + parameters on the specified device. However buffers are not + initialized on specified device. + + Args: + device (`torch.device`): + Device to initialize all parameters on. 
+ """ + + old_register_parameter = nn.Module.register_parameter + + def register_empty_parameter(module, name, param): + old_register_parameter(module, name, param) + if param is not None: + param_cls = type(module._parameters[name]) + kwargs = module._parameters[name].__dict__ + kwargs["requires_grad"] = param.requires_grad + module._parameters[name] = param_cls( + module._parameters[name].to(device), **kwargs + ) + + tensor_constructors_to_patch = {} + + def patch_tensor_constructor(fn): + def wrapper(*args, **kwargs): + kwargs["device"] = device + return fn(*args, **kwargs) + + return wrapper + + try: + nn.Module.register_parameter = register_empty_parameter + for torch_function_name in tensor_constructors_to_patch: + setattr( + torch, + torch_function_name, + patch_tensor_constructor(getattr(torch, torch_function_name)), + ) + yield + finally: + nn.Module.register_parameter = old_register_parameter + for ( + torch_function_name, + old_torch_function, + ) in tensor_constructors_to_patch.items(): + setattr(torch, torch_function_name, old_torch_function) + + +Style = Literal["colwise", "colwise_rep", "rowwise", "rowwise_rep", "replicate"] + + +def replace_linear_class( + linear: nn.Linear, + style: Style = "replicate", + quant_config: "QuantizationConfig | None" = None, + *, + prefix: str = "", +) -> ColumnParallelLinear | RowParallelLinear | ReplicatedLinear: + """ + Replace nn.Linear with one of vLLM's tensor parallel linear classes. + + Args: + linear: `nn.Linear` to be replaced. + style: Tensor parallel style of the new linear, e.g. "colwise". + quant_config: Quantization config for the new linear. + Returns: + The new linear. + """ + + if not isinstance(style, str): + raise ValueError(f"Unsupported parallel style type {type(style)}, expected str") + + vllm_linear_cls, vllm_linear_kwargs = { + "colwise": (ColumnParallelLinear, {}), + "colwise_rep": (ColumnParallelLinear, {"gather_output": True}), + "rowwise": (RowParallelLinear, {}), + "rowwise_rep": (RowParallelLinear, {"input_is_parallel": False}), + "replicate": (ReplicatedLinear, {}), + }.get(style, (ReplicatedLinear, {})) + + return vllm_linear_cls( + input_size=linear.in_features, + output_size=linear.out_features, + bias=linear.bias is not None, + quant_config=quant_config, + prefix=prefix, + return_bias=False, + **vllm_linear_kwargs, + ) + + +def replace_rms_norm_class(rms_norm: nn.Module, hidden_size: int) -> RMSNorm: + """Replace a Transformers RMSNorm with vLLM's RMSNorm. + + This method assumes: + - Weight is stored as `weight`. + - Epsilon is stored as `eps` or `variance_epsilon`. + - `with_scale` indicates whether the layer has a weight (Gemma3n only). + - `var_hidden_size` is only ever used for Intern vision encoder in vLLM + and Transformers doesn't appear to have the same concept. + """ + eps = getattr_iter(rms_norm, ("eps", "variance_epsilon"), 1e-6) + kwargs = {"hidden_size": hidden_size, "eps": eps} + # Update hidden size if weight is available + weight_meta = getattr(rms_norm, "weight", None) + if weight_meta is not None: + kwargs["hidden_size"] = weight_meta.size(0) + # Check if weight is all zeros, which indicates GemmaRMSNorm + # We must create a new instance because rms_norm is on meta + try: + with torch.device("cpu"): + weight_test = getattr(rms_norm.__class__(1), "weight", None) + except Exception: + logger.warning( + "Failed to determine if RMSNorm weight is centered on zero or one. " + "Defaulting to one." 
+ ) + weight_test = None + if weight_test is not None and torch.all(weight_test == 0): + return GemmaRMSNorm(**kwargs) + # Otherwise assume it's a regular RMSNorm + kwargs["has_weight"] = getattr(rms_norm, "with_scale", True) + if weight_meta is not None: + kwargs["dtype"] = weight_meta.dtype + else: + # No weight, fall back to weightless RMSNorm + kwargs["has_weight"] = False + return RMSNorm(**kwargs) + + +def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module): + logger.debug("%s: %s -> %s", name, old_module, new_module) + + +def get_feature_request_tip( + model: str, + trust_remote_code: bool, +) -> str: + hf_url = f"a discussion at https://huggingface.co/{model}/discussions/new" + gh_url = "an issue at https://github.com/huggingface/transformers/issues/new/choose" + url = hf_url if trust_remote_code else gh_url + prefix = f"Please open {url} to request support for this feature. " + if Path(model).exists(): + prefix = "" + doc_url = "https://docs.vllm.ai/en/latest/models/supported_models.html#writing-custom-models" + tip = f"See {doc_url} for instructions on how to add support yourself." + return f"{prefix}{tip}" + + +def can_enable_torch_compile(vllm_config: "VllmConfig") -> bool: + """ + Callable to be passed to `@support_torch_compile`'s `enable_if` argument. + + Defaults to `True` but is disabled in the following situations: + + - The model uses dynamic rope scaling. + """ + text_config = vllm_config.model_config.hf_config.get_text_config() + # Dynamic rope scaling is not compatible with torch.compile + rope_scaling: dict = getattr(text_config, "rope_scaling", None) or {} + return rope_scaling.get("rope_type") != "dynamic" diff --git a/vllm/model_executor/models/transformers_pooling.py b/vllm/model_executor/models/transformers_pooling.py deleted file mode 100644 index 7063a72748..0000000000 --- a/vllm/model_executor/models/transformers_pooling.py +++ /dev/null @@ -1,215 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -# Copyright 2024 The vLLM team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Wrapper around `transformers` models for pooling tasks.""" - -import torch -from transformers import AutoModelForSequenceClassification - -from vllm.attention import Attention, AttentionType -from vllm.compilation.decorators import support_torch_compile -from vllm.config import VllmConfig -from vllm.model_executor.layers.pooler import ( - ClassifierPooler, - CLSPool, - DispatchPooler, - Pooler, -) -from vllm.sequence import IntermediateTensors - -from .interfaces_base import VllmModelForPooling -from .transformers import TransformersBase, can_enable_torch_compile -from .transformers_moe import TransformersMoEBase -from .utils import WeightsMapper - - -class TransformersPoolingBase(TransformersBase, VllmModelForPooling): - hf_to_vllm_mapper = WeightsMapper( - # These are applied in order, so the order matters! 
- orig_to_new_prefix={ - # Handle BERT-like models - "roberta": "model", - "bert": "model", - # Add `model.` prefix for base model checkpoints - "": "model.", - # Remove `model.` prefix if it was already there - "model.model.": "model.", - # Classifier/scoring heads will be adjacent to `model` - "model.score": "classifier", - "model.classifier": "classifier", - }, - orig_to_new_suffix={ - # Replace legacy suffixes used for norms - ".gamma": ".weight", - ".beta": ".bias", - }, - ) - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__(vllm_config=vllm_config, prefix=prefix) - - # Skip unsupported/unwanted output embeddings layers - self.skip_prefixes.extend( - [ - "model.lm_head.", - "model.predictions.", - "model.qa_outputs.", - "model.embeddings_project.", - "model.discriminator_predictions.", - ] - ) - - # Some encoder models have the position_ids buffer in the checkpoint. - # vLLM will always pass position_ids as an argument, so we skip loading - # the buffer if it exists - self.skip_substrs.append("position_ids") - - # Some encoder models have the bias of the final classifier layer - # in the checkpoint. vLLM does not use this bias, so we skip loading - # it if it exists - self.skip_substrs.append("score.bias") - - # roberta-like models an extra padding in positions. - # FIXME(Isotr0py): This is quite hacky for roberta edge case, - # we should find a better way to handle this. - self.is_roberta = "roberta" in self.text_config.model_type - self.padding_idx = self.text_config.pad_token_id - - def create_attention_instances( - self, attn_type: AttentionType = AttentionType.DECODER - ) -> dict[int, Attention]: - # TODO(hmellor): Better way to detect encoder models - # In encoder models, the attention layers will have `is_causal=False` - is_encoder = lambda m: not getattr(m, "is_causal", True) - # vLLM does not support encoder-decoder models, so if any encoder layer - # is found, we assume the whole model is an encoder model - if any(is_encoder(m) for m in self.model.modules()): - attn_type = AttentionType.ENCODER_ONLY - - # Check minimum transformers version for encoder models support - if attn_type == AttentionType.ENCODER_ONLY: - self.check_version("4.57.0.dev0", "encoder models support") - - return super().create_attention_instances(attn_type) - - def forward( - self, - input_ids: torch.Tensor | None, - positions: torch.Tensor, - intermediate_tensors: IntermediateTensors | None = None, - inputs_embeds: torch.Tensor | None = None, - ) -> torch.Tensor | IntermediateTensors: - if self.is_roberta: - # RoBERTa-specific positions padding - positions += self.padding_idx + 1 - return super().forward( - input_ids=input_ids, - positions=positions, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds, - ) - - -@support_torch_compile(enable_if=can_enable_torch_compile) -class TransformersEmbeddingModel(TransformersPoolingBase): - default_pooling_type = "CLS" - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__(vllm_config=vllm_config, prefix=prefix) - - pooler_config = vllm_config.model_config.pooler_config - assert pooler_config is not None - - self.pooler = DispatchPooler( - { - "token_embed": Pooler.for_token_embed(pooler_config), - "embed": Pooler.for_embed(pooler_config), - } - ) - - -@support_torch_compile(enable_if=can_enable_torch_compile) -class TransformersForSequenceClassification(TransformersPoolingBase): - default_pooling_type = "CLS" - - def __init__(self, *, vllm_config: VllmConfig, prefix: 
str = ""): - super().__init__(vllm_config=vllm_config, prefix=prefix) - - pooler_config = vllm_config.model_config.pooler_config - assert pooler_config is not None - - # Certain information about the the model and classifier can only be - # inferred from the `ForSequenceClassification` class. Therefore, we - # instantiate it on the "meta" device to avoid allocating GPU memory. - with torch.device("meta"): - seq_cls_model = AutoModelForSequenceClassification.from_config( - self.config, - dtype=self.model_config.dtype, - trust_remote_code=self.model_config.trust_remote_code, - ) - - # When used for sequence classification, some models have their - # pooling layers removed. Make sure this is reflected in vLLM. - for module in seq_cls_model.modules(): - if hasattr(module, "pooler") and module.pooler is None: - self.model.pooler = None - break - if self.model.pooler is not None: - raise ValueError( - "Sequence classification models with pooling layers are not " - "supported yet in the Transformers backend." - ) - - # Unlike `lm_head`, `classifier` is not always `nn.Linear`. - self.classifier = seq_cls_model.classifier - self.init_parameters(self.classifier, dtype=self.model_config.head_dtype) - - class ClassifierWithReshape(self.classifier.__class__): - """CLSPool has already been applied in `pooling`. - Add dim to match expected input shape of `classifier.forward`.""" - - def forward(self, *args, **kwargs): - if len(args) > 0: - args = (args[0].unsqueeze(1), *args[1:]) - return super().forward(*args, **kwargs) - - self.classifier.__class__ = ClassifierWithReshape - - self.pooler = DispatchPooler( - { - "token_classify": Pooler.for_token_classify( - pooler_config, classifier=self.classifier - ), - "classify": ClassifierPooler( - pooling=CLSPool(), classifier=self.classifier, act_fn="classify" - ), - "score": ClassifierPooler( - pooling=CLSPool(), classifier=self.classifier, act_fn="score" - ), - } - ) - - -@support_torch_compile(enable_if=can_enable_torch_compile) -class TransformersMoEEmbeddingModel(TransformersMoEBase, TransformersEmbeddingModel): - pass - - -@support_torch_compile(enable_if=can_enable_torch_compile) -class TransformersMoEForSequenceClassification( - TransformersMoEBase, TransformersForSequenceClassification -): - pass
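Note on the tensor-parallel plan handling earlier in this patch (`recursive_replace` in transformers/base.py and `replace_linear_class` in transformers/utils.py): a Transformers `tp_plan` maps regex patterns over qualified module names to parallel styles, and any `nn.Linear` without a matching pattern falls back to "replicate", i.e. `ReplicatedLinear`. A simplified, self-contained sketch of that lookup; the patterns and module names are made up for illustration:

```python
import re

tp_plan = {
    r"model\.layers\.\d+\.self_attn\.q_proj": "colwise",
    r"model\.layers\.\d+\.self_attn\.o_proj": "rowwise",
    r"model\.layers\.\d+\.mlp\.gate_proj": "colwise",
    r"model\.layers\.\d+\.mlp\.down_proj": "rowwise",
}


def style_for(qual_name: str) -> str:
    # Mirrors the lookup in recursive_replace: the first matching pattern
    # wins; unmatched linear layers default to "replicate".
    pattern = next((p for p in tp_plan if re.match(p, qual_name)), None)
    return tp_plan.get(pattern, "replicate")


print(style_for("model.layers.0.self_attn.q_proj"))  # colwise
print(style_for("model.layers.7.mlp.down_proj"))     # rowwise
print(style_for("model.embed_tokens"))               # replicate (default)
```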