[Model] Explicit default_pooling_type
interface (#23736)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@ -28,8 +28,8 @@ from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.tasks import PoolingTask
|
||||
|
||||
from .interfaces import (SupportsCrossEncoding, SupportsQuant,
|
||||
default_pooling_type)
|
||||
from .interfaces import SupportsCrossEncoding, SupportsQuant
|
||||
from .interfaces_base import default_pooling_type
|
||||
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
|
||||
|
||||
|
||||
|
@ -27,13 +27,14 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.interfaces import (SupportsQuant,
|
||||
default_pooling_type)
|
||||
from vllm.model_executor.models.utils import WeightsMapper
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsQuant
|
||||
from .interfaces_base import default_pooling_type
|
||||
|
||||
|
||||
class BertWithRopeEmbedding(nn.Module):
|
||||
|
||||
|
@ -20,7 +20,7 @@ from vllm.sequence import PoolerOutput
|
||||
from vllm.tasks import PoolingTask
|
||||
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
|
||||
|
||||
from .interfaces import default_pooling_type
|
||||
from .interfaces_base import default_pooling_type
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
@ -3,7 +3,7 @@
|
||||
|
||||
from collections.abc import Iterable, Mapping, MutableSequence
|
||||
from typing import (TYPE_CHECKING, ClassVar, Literal, Optional, Protocol,
|
||||
TypeVar, Union, overload, runtime_checkable)
|
||||
Union, overload, runtime_checkable)
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -641,23 +641,6 @@ def supports_cross_encoding(
|
||||
return is_pooling_model(model) and _supports_cross_encoding(model)
|
||||
|
||||
|
||||
_T = TypeVar("_T", bound=type[torch.nn.Module])
|
||||
|
||||
|
||||
def default_pooling_type(pooling_type: str):
|
||||
"""Set default_pooling_type decorator. """
|
||||
|
||||
def func(model: _T) -> _T:
|
||||
model.default_pooling_type = pooling_type # type: ignore
|
||||
return model
|
||||
|
||||
return func
|
||||
|
||||
|
||||
def get_default_pooling_type(model: Union[type[object], object]) -> str:
|
||||
return getattr(model, "default_pooling_type", "LAST")
|
||||
|
||||
|
||||
class SupportsQuant:
|
||||
"""The interface required for all models that support quantization."""
|
||||
|
||||
|
@ -144,6 +144,17 @@ class VllmModelForPooling(VllmModel[T_co], Protocol[T_co]):
|
||||
MRO of your model class.
|
||||
"""
|
||||
|
||||
default_pooling_type: ClassVar[str] = "LAST"
|
||||
"""
|
||||
Indicates the
|
||||
[vllm.model_executor.layers.pooler.PoolerConfig.pooling_type][]
|
||||
to use by default.
|
||||
|
||||
You can use the
|
||||
[vllm.model_executor.models.interfaces_base.default_pooling_type][]
|
||||
decorator to conveniently set this field.
|
||||
"""
|
||||
|
||||
pooler: Pooler
|
||||
"""The pooler is only called on TP rank 0."""
|
||||
|
||||
@ -165,3 +176,20 @@ def is_pooling_model(
|
||||
return False
|
||||
|
||||
return getattr(model, "is_pooling_model", False)
|
||||
|
||||
|
||||
_T = TypeVar("_T", bound=type[nn.Module])
|
||||
|
||||
|
||||
def default_pooling_type(pooling_type: str):
|
||||
"""Decorator to set `VllmModelForPooling.default_pooling_type`."""
|
||||
|
||||
def func(model: _T) -> _T:
|
||||
model.default_pooling_type = pooling_type # type: ignore
|
||||
return model
|
||||
|
||||
return func
|
||||
|
||||
|
||||
def get_default_pooling_type(model: Union[type[object], object]) -> str:
|
||||
return getattr(model, "default_pooling_type", "LAST")
|
||||
|
@ -31,7 +31,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsLoRA, SupportsPP, default_pooling_type
|
||||
from .interfaces import SupportsLoRA, SupportsPP
|
||||
from .interfaces_base import default_pooling_type
|
||||
from .utils import (is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
|
@ -26,7 +26,8 @@ from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.tasks import PoolingTask
|
||||
|
||||
from .interfaces import SupportsCrossEncoding, default_pooling_type
|
||||
from .interfaces import SupportsCrossEncoding
|
||||
from .interfaces_base import default_pooling_type
|
||||
from .utils import WeightsMapper, maybe_prefix
|
||||
|
||||
|
||||
|
@ -27,9 +27,6 @@ from transformers import BatchFeature
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.interfaces import (
|
||||
IsAttentionFree, MultiModalEmbeddings, SupportsMultiModalWithRawInput,
|
||||
default_pooling_type)
|
||||
from vllm.model_executor.models.utils import AutoWeightsLoader
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import (ImageItem, ModalityData,
|
||||
@ -43,6 +40,10 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import (IsAttentionFree, MultiModalEmbeddings,
|
||||
SupportsMultiModalWithRawInput)
|
||||
from .interfaces_base import default_pooling_type
|
||||
|
||||
|
||||
def _prithvi_field_config(hf_inputs: Mapping[str, torch.Tensor]):
|
||||
# This model receives in input a multi-dimensional tensor representing
|
||||
|
@ -18,7 +18,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsLoRA, SupportsPP, default_pooling_type
|
||||
from .interfaces import SupportsLoRA, SupportsPP
|
||||
from .interfaces_base import default_pooling_type
|
||||
from .qwen2 import Qwen2Model
|
||||
from .utils import AutoWeightsLoader, maybe_prefix
|
||||
|
||||
|
@ -25,11 +25,12 @@ from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.dynamic_module import (
|
||||
try_get_class_from_dynamic_module)
|
||||
|
||||
from .interfaces import (get_default_pooling_type, has_inner_state, has_noops,
|
||||
is_attention_free, is_hybrid, supports_cross_encoding,
|
||||
from .interfaces import (has_inner_state, has_noops, is_attention_free,
|
||||
is_hybrid, supports_cross_encoding,
|
||||
supports_multimodal, supports_multimodal_raw_input,
|
||||
supports_pp, supports_transcription, supports_v0_only)
|
||||
from .interfaces_base import is_pooling_model, is_text_generation_model
|
||||
from .interfaces_base import (get_default_pooling_type, is_pooling_model,
|
||||
is_text_generation_model)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
@ -22,7 +22,8 @@ from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper,
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .bert_with_rope import BertWithRope, JinaRobertaModel
|
||||
from .interfaces import SupportsCrossEncoding, default_pooling_type
|
||||
from .interfaces import SupportsCrossEncoding
|
||||
from .interfaces_base import default_pooling_type
|
||||
|
||||
|
||||
class RobertaEmbedding(nn.Module):
|
||||
|
Reference in New Issue
Block a user