[Model] Explicit default_pooling_type interface (#23736)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@ -28,8 +28,8 @@ from vllm.model_executor.pooling_metadata import PoolingMetadata
|
|||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.tasks import PoolingTask
|
from vllm.tasks import PoolingTask
|
||||||
|
|
||||||
from .interfaces import (SupportsCrossEncoding, SupportsQuant,
|
from .interfaces import SupportsCrossEncoding, SupportsQuant
|
||||||
default_pooling_type)
|
from .interfaces_base import default_pooling_type
|
||||||
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
|
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
|
||||||
|
|
||||||
|
|
||||||
|
@ -27,13 +27,14 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
|
|||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
VocabParallelEmbedding)
|
VocabParallelEmbedding)
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
from vllm.model_executor.models.interfaces import (SupportsQuant,
|
|
||||||
default_pooling_type)
|
|
||||||
from vllm.model_executor.models.utils import WeightsMapper
|
from vllm.model_executor.models.utils import WeightsMapper
|
||||||
from vllm.model_executor.utils import set_weight_attrs
|
from vllm.model_executor.utils import set_weight_attrs
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
|
|
||||||
|
from .interfaces import SupportsQuant
|
||||||
|
from .interfaces_base import default_pooling_type
|
||||||
|
|
||||||
|
|
||||||
class BertWithRopeEmbedding(nn.Module):
|
class BertWithRopeEmbedding(nn.Module):
|
||||||
|
|
||||||
|
@ -20,7 +20,7 @@ from vllm.sequence import PoolerOutput
|
|||||||
from vllm.tasks import PoolingTask
|
from vllm.tasks import PoolingTask
|
||||||
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
|
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
|
||||||
|
|
||||||
from .interfaces import default_pooling_type
|
from .interfaces_base import default_pooling_type
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
|
|
||||||
from collections.abc import Iterable, Mapping, MutableSequence
|
from collections.abc import Iterable, Mapping, MutableSequence
|
||||||
from typing import (TYPE_CHECKING, ClassVar, Literal, Optional, Protocol,
|
from typing import (TYPE_CHECKING, ClassVar, Literal, Optional, Protocol,
|
||||||
TypeVar, Union, overload, runtime_checkable)
|
Union, overload, runtime_checkable)
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -641,23 +641,6 @@ def supports_cross_encoding(
|
|||||||
return is_pooling_model(model) and _supports_cross_encoding(model)
|
return is_pooling_model(model) and _supports_cross_encoding(model)
|
||||||
|
|
||||||
|
|
||||||
_T = TypeVar("_T", bound=type[torch.nn.Module])
|
|
||||||
|
|
||||||
|
|
||||||
def default_pooling_type(pooling_type: str):
|
|
||||||
"""Set default_pooling_type decorator. """
|
|
||||||
|
|
||||||
def func(model: _T) -> _T:
|
|
||||||
model.default_pooling_type = pooling_type # type: ignore
|
|
||||||
return model
|
|
||||||
|
|
||||||
return func
|
|
||||||
|
|
||||||
|
|
||||||
def get_default_pooling_type(model: Union[type[object], object]) -> str:
|
|
||||||
return getattr(model, "default_pooling_type", "LAST")
|
|
||||||
|
|
||||||
|
|
||||||
class SupportsQuant:
|
class SupportsQuant:
|
||||||
"""The interface required for all models that support quantization."""
|
"""The interface required for all models that support quantization."""
|
||||||
|
|
||||||
|
@ -144,6 +144,17 @@ class VllmModelForPooling(VllmModel[T_co], Protocol[T_co]):
|
|||||||
MRO of your model class.
|
MRO of your model class.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
default_pooling_type: ClassVar[str] = "LAST"
|
||||||
|
"""
|
||||||
|
Indicates the
|
||||||
|
[vllm.model_executor.layers.pooler.PoolerConfig.pooling_type][]
|
||||||
|
to use by default.
|
||||||
|
|
||||||
|
You can use the
|
||||||
|
[vllm.model_executor.models.interfaces_base.default_pooling_type][]
|
||||||
|
decorator to conveniently set this field.
|
||||||
|
"""
|
||||||
|
|
||||||
pooler: Pooler
|
pooler: Pooler
|
||||||
"""The pooler is only called on TP rank 0."""
|
"""The pooler is only called on TP rank 0."""
|
||||||
|
|
||||||
@ -165,3 +176,20 @@ def is_pooling_model(
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
return getattr(model, "is_pooling_model", False)
|
return getattr(model, "is_pooling_model", False)
|
||||||
|
|
||||||
|
|
||||||
|
_T = TypeVar("_T", bound=type[nn.Module])
|
||||||
|
|
||||||
|
|
||||||
|
def default_pooling_type(pooling_type: str):
|
||||||
|
"""Decorator to set `VllmModelForPooling.default_pooling_type`."""
|
||||||
|
|
||||||
|
def func(model: _T) -> _T:
|
||||||
|
model.default_pooling_type = pooling_type # type: ignore
|
||||||
|
return model
|
||||||
|
|
||||||
|
return func
|
||||||
|
|
||||||
|
|
||||||
|
def get_default_pooling_type(model: Union[type[object], object]) -> str:
|
||||||
|
return getattr(model, "default_pooling_type", "LAST")
|
||||||
|
@ -31,7 +31,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|||||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
|
|
||||||
from .interfaces import SupportsLoRA, SupportsPP, default_pooling_type
|
from .interfaces import SupportsLoRA, SupportsPP
|
||||||
|
from .interfaces_base import default_pooling_type
|
||||||
from .utils import (is_pp_missing_parameter,
|
from .utils import (is_pp_missing_parameter,
|
||||||
make_empty_intermediate_tensors_factory, make_layers,
|
make_empty_intermediate_tensors_factory, make_layers,
|
||||||
maybe_prefix)
|
maybe_prefix)
|
||||||
|
@ -26,7 +26,8 @@ from vllm.model_executor.pooling_metadata import PoolingMetadata
|
|||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.tasks import PoolingTask
|
from vllm.tasks import PoolingTask
|
||||||
|
|
||||||
from .interfaces import SupportsCrossEncoding, default_pooling_type
|
from .interfaces import SupportsCrossEncoding
|
||||||
|
from .interfaces_base import default_pooling_type
|
||||||
from .utils import WeightsMapper, maybe_prefix
|
from .utils import WeightsMapper, maybe_prefix
|
||||||
|
|
||||||
|
|
||||||
|
@ -27,9 +27,6 @@ from transformers import BatchFeature
|
|||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
|
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
from vllm.model_executor.models.interfaces import (
|
|
||||||
IsAttentionFree, MultiModalEmbeddings, SupportsMultiModalWithRawInput,
|
|
||||||
default_pooling_type)
|
|
||||||
from vllm.model_executor.models.utils import AutoWeightsLoader
|
from vllm.model_executor.models.utils import AutoWeightsLoader
|
||||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||||
from vllm.multimodal.inputs import (ImageItem, ModalityData,
|
from vllm.multimodal.inputs import (ImageItem, ModalityData,
|
||||||
@ -43,6 +40,10 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
|||||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
|
|
||||||
|
from .interfaces import (IsAttentionFree, MultiModalEmbeddings,
|
||||||
|
SupportsMultiModalWithRawInput)
|
||||||
|
from .interfaces_base import default_pooling_type
|
||||||
|
|
||||||
|
|
||||||
def _prithvi_field_config(hf_inputs: Mapping[str, torch.Tensor]):
|
def _prithvi_field_config(hf_inputs: Mapping[str, torch.Tensor]):
|
||||||
# This model receives in input a multi-dimensional tensor representing
|
# This model receives in input a multi-dimensional tensor representing
|
||||||
|
@ -18,7 +18,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
|||||||
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
|
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
|
|
||||||
from .interfaces import SupportsLoRA, SupportsPP, default_pooling_type
|
from .interfaces import SupportsLoRA, SupportsPP
|
||||||
|
from .interfaces_base import default_pooling_type
|
||||||
from .qwen2 import Qwen2Model
|
from .qwen2 import Qwen2Model
|
||||||
from .utils import AutoWeightsLoader, maybe_prefix
|
from .utils import AutoWeightsLoader, maybe_prefix
|
||||||
|
|
||||||
|
@ -25,11 +25,12 @@ from vllm.logger import init_logger
|
|||||||
from vllm.transformers_utils.dynamic_module import (
|
from vllm.transformers_utils.dynamic_module import (
|
||||||
try_get_class_from_dynamic_module)
|
try_get_class_from_dynamic_module)
|
||||||
|
|
||||||
from .interfaces import (get_default_pooling_type, has_inner_state, has_noops,
|
from .interfaces import (has_inner_state, has_noops, is_attention_free,
|
||||||
is_attention_free, is_hybrid, supports_cross_encoding,
|
is_hybrid, supports_cross_encoding,
|
||||||
supports_multimodal, supports_multimodal_raw_input,
|
supports_multimodal, supports_multimodal_raw_input,
|
||||||
supports_pp, supports_transcription, supports_v0_only)
|
supports_pp, supports_transcription, supports_v0_only)
|
||||||
from .interfaces_base import is_pooling_model, is_text_generation_model
|
from .interfaces_base import (get_default_pooling_type, is_pooling_model,
|
||||||
|
is_text_generation_model)
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
@ -22,7 +22,8 @@ from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper,
|
|||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
|
|
||||||
from .bert_with_rope import BertWithRope, JinaRobertaModel
|
from .bert_with_rope import BertWithRope, JinaRobertaModel
|
||||||
from .interfaces import SupportsCrossEncoding, default_pooling_type
|
from .interfaces import SupportsCrossEncoding
|
||||||
|
from .interfaces_base import default_pooling_type
|
||||||
|
|
||||||
|
|
||||||
class RobertaEmbedding(nn.Module):
|
class RobertaEmbedding(nn.Module):
|
||||||
|
Reference in New Issue
Block a user