[Model] Explicit default_pooling_type interface (#23736)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-08-27 21:24:09 +08:00
committed by GitHub
parent 704432af3c
commit 5eeef1b908
11 changed files with 51 additions and 33 deletions

View File

@ -28,8 +28,8 @@ from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.tasks import PoolingTask from vllm.tasks import PoolingTask
from .interfaces import (SupportsCrossEncoding, SupportsQuant, from .interfaces import SupportsCrossEncoding, SupportsQuant
default_pooling_type) from .interfaces_base import default_pooling_type
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix

View File

@ -27,13 +27,14 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import (SupportsQuant,
default_pooling_type)
from vllm.model_executor.models.utils import WeightsMapper from vllm.model_executor.models.utils import WeightsMapper
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsQuant
from .interfaces_base import default_pooling_type
class BertWithRopeEmbedding(nn.Module): class BertWithRopeEmbedding(nn.Module):

View File

@ -20,7 +20,7 @@ from vllm.sequence import PoolerOutput
from vllm.tasks import PoolingTask from vllm.tasks import PoolingTask
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
from .interfaces import default_pooling_type from .interfaces_base import default_pooling_type
logger = init_logger(__name__) logger = init_logger(__name__)

View File

@ -3,7 +3,7 @@
from collections.abc import Iterable, Mapping, MutableSequence from collections.abc import Iterable, Mapping, MutableSequence
from typing import (TYPE_CHECKING, ClassVar, Literal, Optional, Protocol, from typing import (TYPE_CHECKING, ClassVar, Literal, Optional, Protocol,
TypeVar, Union, overload, runtime_checkable) Union, overload, runtime_checkable)
import numpy as np import numpy as np
import torch import torch
@ -641,23 +641,6 @@ def supports_cross_encoding(
return is_pooling_model(model) and _supports_cross_encoding(model) return is_pooling_model(model) and _supports_cross_encoding(model)
_T = TypeVar("_T", bound=type[torch.nn.Module])
def default_pooling_type(pooling_type: str):
"""Set default_pooling_type decorator. """
def func(model: _T) -> _T:
model.default_pooling_type = pooling_type # type: ignore
return model
return func
def get_default_pooling_type(model: Union[type[object], object]) -> str:
return getattr(model, "default_pooling_type", "LAST")
class SupportsQuant: class SupportsQuant:
"""The interface required for all models that support quantization.""" """The interface required for all models that support quantization."""

View File

@ -144,6 +144,17 @@ class VllmModelForPooling(VllmModel[T_co], Protocol[T_co]):
MRO of your model class. MRO of your model class.
""" """
default_pooling_type: ClassVar[str] = "LAST"
"""
Indicates the
[vllm.model_executor.layers.pooler.PoolerConfig.pooling_type][]
to use by default.
You can use the
[vllm.model_executor.models.interfaces_base.default_pooling_type][]
decorator to conveniently set this field.
"""
pooler: Pooler pooler: Pooler
"""The pooler is only called on TP rank 0.""" """The pooler is only called on TP rank 0."""
@ -165,3 +176,20 @@ def is_pooling_model(
return False return False
return getattr(model, "is_pooling_model", False) return getattr(model, "is_pooling_model", False)
_T = TypeVar("_T", bound=type[nn.Module])
def default_pooling_type(pooling_type: str):
"""Decorator to set `VllmModelForPooling.default_pooling_type`."""
def func(model: _T) -> _T:
model.default_pooling_type = pooling_type # type: ignore
return model
return func
def get_default_pooling_type(model: Union[type[object], object]) -> str:
return getattr(model, "default_pooling_type", "LAST")

View File

@ -31,7 +31,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP, default_pooling_type from .interfaces import SupportsLoRA, SupportsPP
from .interfaces_base import default_pooling_type
from .utils import (is_pp_missing_parameter, from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix) maybe_prefix)

View File

@ -26,7 +26,8 @@ from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.tasks import PoolingTask from vllm.tasks import PoolingTask
from .interfaces import SupportsCrossEncoding, default_pooling_type from .interfaces import SupportsCrossEncoding
from .interfaces_base import default_pooling_type
from .utils import WeightsMapper, maybe_prefix from .utils import WeightsMapper, maybe_prefix

View File

@ -27,9 +27,6 @@ from transformers import BatchFeature
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import (
IsAttentionFree, MultiModalEmbeddings, SupportsMultiModalWithRawInput,
default_pooling_type)
from vllm.model_executor.models.utils import AutoWeightsLoader from vllm.model_executor.models.utils import AutoWeightsLoader
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (ImageItem, ModalityData, from vllm.multimodal.inputs import (ImageItem, ModalityData,
@ -43,6 +40,10 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import (IsAttentionFree, MultiModalEmbeddings,
SupportsMultiModalWithRawInput)
from .interfaces_base import default_pooling_type
def _prithvi_field_config(hf_inputs: Mapping[str, torch.Tensor]): def _prithvi_field_config(hf_inputs: Mapping[str, torch.Tensor]):
# This model receives in input a multi-dimensional tensor representing # This model receives in input a multi-dimensional tensor representing

View File

@ -18,7 +18,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP, default_pooling_type from .interfaces import SupportsLoRA, SupportsPP
from .interfaces_base import default_pooling_type
from .qwen2 import Qwen2Model from .qwen2 import Qwen2Model
from .utils import AutoWeightsLoader, maybe_prefix from .utils import AutoWeightsLoader, maybe_prefix

View File

@ -25,11 +25,12 @@ from vllm.logger import init_logger
from vllm.transformers_utils.dynamic_module import ( from vllm.transformers_utils.dynamic_module import (
try_get_class_from_dynamic_module) try_get_class_from_dynamic_module)
from .interfaces import (get_default_pooling_type, has_inner_state, has_noops, from .interfaces import (has_inner_state, has_noops, is_attention_free,
is_attention_free, is_hybrid, supports_cross_encoding, is_hybrid, supports_cross_encoding,
supports_multimodal, supports_multimodal_raw_input, supports_multimodal, supports_multimodal_raw_input,
supports_pp, supports_transcription, supports_v0_only) supports_pp, supports_transcription, supports_v0_only)
from .interfaces_base import is_pooling_model, is_text_generation_model from .interfaces_base import (get_default_pooling_type, is_pooling_model,
is_text_generation_model)
logger = init_logger(__name__) logger = init_logger(__name__)

View File

@ -22,7 +22,8 @@ from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper,
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .bert_with_rope import BertWithRope, JinaRobertaModel from .bert_with_rope import BertWithRope, JinaRobertaModel
from .interfaces import SupportsCrossEncoding, default_pooling_type from .interfaces import SupportsCrossEncoding
from .interfaces_base import default_pooling_type
class RobertaEmbedding(nn.Module): class RobertaEmbedding(nn.Module):