[Model] Systematic support for fp32 head, pooling models part (#23810)

Signed-off-by: wang.yuqi <noooop@126.com>
wang.yuqi
2025-09-09 22:29:50 +08:00
committed by GitHub
parent a55cf41a09
commit 19332c0479
14 changed files with 166 additions and 61 deletions

View File

@ -9,6 +9,7 @@ import mteb
import numpy as np
import pytest
import requests
import torch
from tests.models.utils import (EmbedModelInfo, RerankModelInfo,
check_embeddings_close)
@ -165,16 +166,19 @@ def mteb_test_embed_models(hf_runner,
vllm_extra_kwargs=None,
hf_model_callback=None,
atol=MTEB_EMBED_TOL):
# A model family has many models with the same architecture,
# and we don't need to test each one.
if not model_info.enable_test:
# A model family has many models with the same architecture,
# and we don't need to test each one.
pytest.skip("Skipping test.")
example_prompts = ["The chef prepared a delicious meal."]
# Check embedding dims, NaN values, and whether normalization is applied
example_prompts = ["The chef prepared a delicious meal." * 1000]
# Allow vLLM to run the test with the given dtype, e.g. float32
vllm_extra_kwargs = vllm_extra_kwargs or {}
vllm_extra_kwargs["dtype"] = model_info.dtype
# Allow vLLM to run the test with hf_overrides
if model_info.hf_overrides is not None:
vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides
@ -186,21 +190,32 @@ def mteb_test_embed_models(hf_runner,
model_config = vllm_model.llm.llm_engine.model_config
# Confirm that vLLM is using the correct architecture
if model_info.architecture:
assert model_info.architecture in model_config.architectures
# Confirm that vLLM uses the correct default_pooling_type, which
# in turn affects whether chunked prefill and prefix caching can be enabled
assert (model_config._model_info.default_pooling_type ==
model_info.default_pooling_type)
vllm_main_score = run_mteb_embed_task(VllmMtebEncoder(vllm_model),
MTEB_EMBED_TASKS)
vllm_dtype = vllm_model.llm.llm_engine.model_config.dtype
vllm_outputs = vllm_model.embed(example_prompts)
# Check embedding dims, NaN values, and whether normalization is applied
vllm_outputs = vllm_model.embed(example_prompts,
truncate_prompt_tokens=-1)
assert not torch.any(torch.isnan(torch.tensor(vllm_outputs)))
# Speed up the MTEB test by pinning the
# SentenceTransformers MTEB score to a constant
if model_info.mteb_score is None:
with hf_runner(model_info.name,
is_sentence_transformer=True,
dtype="float32") as hf_model:
# e.g. to set default parameters for hf_runner's encode method
if hf_model_callback is not None:
hf_model_callback(hf_model)
@ -299,14 +314,16 @@ def mteb_test_rerank_models(hf_runner,
hf_model_callback=None,
vllm_mteb_encoder=VllmMtebEncoder,
atol=MTEB_RERANK_TOL):
# A model family has many models with the same architecture,
# and we don't need to test each one.
if not model_info.enable_test:
# A model family has many models with the same architecture,
# and we don't need to test each one.
pytest.skip("Skipping test.")
# Allow vLLM to run the test with the given dtype, e.g. float32
vllm_extra_kwargs = vllm_extra_kwargs or {}
vllm_extra_kwargs["dtype"] = model_info.dtype
# Allow vLLM to run the test with hf_overrides
if model_info.hf_overrides is not None:
vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides
@ -319,9 +336,15 @@ def mteb_test_rerank_models(hf_runner,
model_config = vllm_model.llm.llm_engine.model_config
# Confirm that vLLM is using the correct architecture
if model_info.architecture:
assert (model_info.architecture in model_config.architectures)
# Score API is only enabled for num_labels == 1
assert model_config.hf_config.num_labels == 1
# Confirm that vLLM uses the correct default_pooling_type, which
# in turn affects whether chunked prefill and prefix caching can be enabled
assert (model_config._model_info.default_pooling_type ==
model_info.default_pooling_type)
@ -330,6 +353,8 @@ def mteb_test_rerank_models(hf_runner,
languages=MTEB_RERANK_LANGS)
vllm_dtype = model_config.dtype
# Speed up the MTEB test by pinning the
# SentenceTransformers MTEB score to a constant
if model_info.mteb_score is None:
st_main_score, st_dtype = mteb_test_rerank_models_hf(
hf_runner, model_info.name, hf_model_callback)
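For reference, a hedged sketch of how a test entry opts into these code paths. The field names follow the EmbedModelInfo usage in these tests (the exact signature lives in tests/models/utils.py), and the model name and score are only illustrative placeholders, not measured results.

from tests.models.utils import EmbedModelInfo

EMBED_MODELS = [
    EmbedModelInfo(
        "intfloat/e5-small-v2",        # illustrative model name
        architecture="BertModel",      # checked against model_config.architectures
        dtype="float32",               # forwarded via vllm_extra_kwargs["dtype"]
        mteb_score=0.742,              # placeholder constant baseline: skips the SentenceTransformers pass
        enable_test=True,              # only one model per family needs to run
    ),
]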

View File

@ -14,6 +14,7 @@ from .mteb_utils import VllmMtebEncoder, mteb_test_rerank_models
RERANK_MODELS = [
LASTPoolingRerankModelInfo("BAAI/bge-reranker-v2-gemma",
architecture="GemmaForSequenceClassification",
mteb_score=0.33757,
hf_overrides={
"architectures":
["GemmaForSequenceClassification"],

View File

@ -745,7 +745,7 @@ class ModelConfig:
self.pooler_config = self._init_pooler_config()
self.dtype = _get_and_verify_dtype(
self.dtype: torch.dtype = _get_and_verify_dtype(
self.model,
self.hf_config,
self.dtype,
@ -1751,6 +1751,32 @@ class ModelConfig:
# `LLM as reranker` models default to not using pad_token.
return getattr(self.hf_config, "use_pad_token", True)
@property
def head_dtype(self) -> torch.dtype:
"""
"head" refers to the last Linear layer(s) of an LLM,
such as the lm_head in a generation model,
or the score or classifier in a classification model.
The default head_dtype based on runner_type.\n
- The pooling model defaults to using fp32 head,
you can use --hf-overrides '{"head_dtype": "model"}' to disable it.\n
- The generate model defaults to not using fp32 head,
you can use --hf-overrides '{"head_dtype": "float32"}' to enable it.
"""
head_dtype = _get_head_dtype(config=self.hf_config,
dtype=self.dtype,
runner_type=self.runner_type)
if head_dtype not in current_platform.supported_dtypes:
logger.warning_once(
"The current platform does not support [%s] head dtype, "
"fallback to model dtype [%s].", head_dtype, self.dtype)
return self.dtype
logger.debug_once("head dtype: %s", head_dtype)
return head_dtype
def get_and_verify_max_len(self, max_model_len: int):
# Consider max_model_len in tokenizer_config only when
# pooling models use absolute position_embedding.
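A minimal usage sketch of the two overrides described in the head_dtype docstring above. It assumes the standard offline LLM entrypoint with its hf_overrides and task arguments; the model names are illustrative.

from vllm import LLM

# Pooling model: keep the head in the model dtype instead of the fp32 default.
embedder = LLM(model="intfloat/e5-small-v2",
               task="embed",
               hf_overrides={"head_dtype": "model"})

# Generation model: opt in to an fp32 head, as the docstring describes.
generator = LLM(model="Qwen/Qwen2-0.5B-Instruct",
                hf_overrides={"head_dtype": "float32"})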
@ -2893,6 +2919,31 @@ def _get_and_verify_dtype(
return torch_dtype
def _get_head_dtype(config: PretrainedConfig, dtype: torch.dtype,
runner_type: str) -> torch.dtype:
head_dtype: Optional[Union[str,
torch.dtype]] = getattr(config, "head_dtype",
None)
if head_dtype == "model":
return dtype
elif isinstance(head_dtype, str):
head_dtype = head_dtype.lower()
if head_dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
raise ValueError(f"Unknown dtype: {head_dtype!r}")
return _STR_DTYPE_TO_TORCH_DTYPE[head_dtype]
elif isinstance(head_dtype, torch.dtype):
return head_dtype
elif head_dtype is None:
if torch.float32 not in current_platform.supported_dtypes:
return dtype
if runner_type == "pooling":
return torch.float32
return dtype
else:
raise ValueError(f"Unknown dtype: {head_dtype}")
def _get_and_verify_max_len(
hf_config: PretrainedConfig,
tokenizer_config: Optional[dict],
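To make the precedence concrete, here is a small self-contained sketch (not vLLM code) that mirrors the logic of _get_head_dtype above, with the platform's supported dtypes passed in explicitly:

import torch
from types import SimpleNamespace

_STR_TO_DTYPE = {"float16": torch.float16, "bfloat16": torch.bfloat16,
                 "float32": torch.float32}

def resolve_head_dtype(config, dtype, runner_type,
                       supported=(torch.float16, torch.float32)):
    head_dtype = getattr(config, "head_dtype", None)
    if head_dtype == "model":
        return dtype                               # explicit opt-out: reuse the model dtype
    if isinstance(head_dtype, str):
        return _STR_TO_DTYPE[head_dtype.lower()]   # unknown strings raise here
    if isinstance(head_dtype, torch.dtype):
        return head_dtype
    # Default: pooling runners get an fp32 head when the platform supports it.
    if torch.float32 in supported and runner_type == "pooling":
        return torch.float32
    return dtype

assert resolve_head_dtype(SimpleNamespace(), torch.float16, "pooling") is torch.float32
assert resolve_head_dtype(SimpleNamespace(), torch.float16, "generate") is torch.float16
assert resolve_head_dtype(SimpleNamespace(head_dtype="model"), torch.float16, "pooling") is torch.float16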

View File

@ -5,7 +5,7 @@ from collections.abc import Mapping, Set
from dataclasses import dataclass
from enum import IntEnum
from itertools import groupby
from typing import Callable, Optional, TypeVar, Union, cast
from typing import Callable, Optional, TypeVar, Union
import torch
import torch.nn as nn
@ -362,14 +362,13 @@ class PoolerIdentity(PoolerActivation):
class PoolerNormalize(PoolerActivation):
def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
x = F.normalize(pooled_data.float(), p=2, dim=-1)
return x.to(pooled_data.dtype)
return F.normalize(pooled_data, p=2, dim=-1)
class PoolerMultiLabelClassify(PoolerActivation):
def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
return F.sigmoid(pooled_data.float()).to(pooled_data.dtype)
return F.sigmoid(pooled_data)
class PoolerClassify(PoolerActivation):
@ -394,9 +393,9 @@ class PoolerClassify(PoolerActivation):
pooled_data.shape[-1])
if num_labels < 2:
return F.sigmoid(pooled_data.float()).to(pooled_data.dtype)
return F.sigmoid(pooled_data)
return F.softmax(pooled_data.float(), dim=-1).to(pooled_data.dtype)
return F.softmax(pooled_data, dim=-1)
class LambdaPoolerActivation(PoolerActivation):
@ -432,8 +431,9 @@ class EmbeddingPoolerHead(PoolerHead):
from vllm.model_executor.models.adapters import _load_st_projector
vllm_config = get_current_vllm_config()
self.projector = _load_st_projector(
self.projector: Optional[nn.Module] = _load_st_projector(
vllm_config.model_config) if vllm_config else None
self.head_dtype = vllm_config.model_config.head_dtype
def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
pooling_metadata: PoolingMetadata):
@ -442,16 +442,11 @@ class EmbeddingPoolerHead(PoolerHead):
pooled_data = torch.stack(pooled_data)
# pooled_data shape: [batchsize, hidden_dimension]
pooled_data = pooled_data.to(self.head_dtype)
# Apply ST projector
if self.projector is not None:
projector = cast(nn.Module, self.projector)
def _proj(x: torch.Tensor) -> torch.Tensor:
orig_dtype = x.dtype
y = projector(x.to(torch.float32))
return y.to(orig_dtype)
pooled_data = _proj(pooled_data)
pooled_data = self.projector(pooled_data)
# pooled_data shape: [batchsize, embedding_dimension]
pooling_params = get_pooling_params(pooling_metadata)
@ -494,8 +489,18 @@ class RewardPoolerHead(PoolerHead):
def __init__(self) -> None:
super().__init__(activation=PoolerClassify(static_num_labels=False))
from vllm.config import get_current_vllm_config
vllm_config = get_current_vllm_config()
self.head_dtype = vllm_config.model_config.head_dtype
def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
pooling_metadata: PoolingMetadata):
if isinstance(pooled_data, list):
pooled_data = [p.to(self.head_dtype) for p in pooled_data]
else:
pooled_data = pooled_data.to(self.head_dtype)
pooling_params = get_pooling_params(pooling_metadata)
# for softmax
@ -641,6 +646,7 @@ class ClassifierPooler(Pooler):
self.act_fn = act_fn or PoolerClassify()
self.logit_bias: Optional[
float] = vllm_config.model_config.pooler_config.logit_bias
self.head_dtype = vllm_config.model_config.head_dtype
def get_supported_tasks(self) -> Set[PoolingTask]:
return {"classify", "score"}
@ -655,6 +661,8 @@ class ClassifierPooler(Pooler):
pooled_data = torch.stack(pooled_data)
# pooled_data shape: [batchsize, hidden_size]
pooled_data = pooled_data.to(self.head_dtype)
if self.classifier is not None:
pooled_data = self.classifier(pooled_data)
# pooled_data shape: [batchsize, num_labels]
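The removed .float() round-trips in the activations above are safe to drop because the pooler heads now move pooled_data to head_dtype (fp32 by default for pooling models) before the activation runs. A minimal sketch of the resulting numerics, assuming the default head_dtype:

import torch
import torch.nn.functional as F

head_dtype = torch.float32                    # default for pooling runners
pooled_fp16 = torch.randn(4, 3, dtype=torch.float16)
pooled = pooled_fp16.to(head_dtype)           # done once in the pooler head
probs = F.softmax(pooled, dim=-1)             # softmax already runs in fp32
assert probs.dtype is head_dtype              # no per-activation cast needed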

View File

@ -62,7 +62,7 @@ def _load_st_projector(model_config: "ModelConfig") -> Optional[nn.Module]:
linear = nn.Linear(layer_config.get("in_features", 768),
layer_config.get("out_features", 768),
bias=layer_config.get("bias", True),
dtype=torch.float32)
dtype=model_config.head_dtype)
if not _load_dense_weights(linear, folder, model_config):
continue
@ -70,7 +70,7 @@ def _load_st_projector(model_config: "ModelConfig") -> Optional[nn.Module]:
layers.append(linear)
if act_name := layer_config.get("activation_function"):
layers.append(get_act_fn(act_name))
return nn.Sequential(*layers).to(dtype=torch.float32)
return nn.Sequential(*layers).to(dtype=model_config.head_dtype)
except Exception:
logger.exception("ST projector loading failed")
@ -105,15 +105,13 @@ def _load_dense_weights(linear: nn.Linear, folder: str,
if weight_key in state_dict:
weight_loader = getattr(linear.weight, "weight_loader",
default_weight_loader)
weight_loader(linear.weight,
state_dict[weight_key].to(torch.float32))
weight_loader(linear.weight, state_dict[weight_key])
bias_key = weight_key.replace("weight", "bias")
if linear.bias is not None and bias_key in state_dict:
bias_loader = getattr(linear.bias, "weight_loader",
default_weight_loader)
bias_loader(linear.bias,
state_dict[bias_key].to(torch.float32))
bias_loader(linear.bias, state_dict[bias_key])
return True
except Exception:
logger.exception("Failed to load %s", filename)

View File

@ -562,7 +562,9 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding,
self.bert = BertPoolingModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "bert"),
embedding_class=BertEmbedding)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
self.classifier = nn.Linear(config.hidden_size,
config.num_labels,
dtype=vllm_config.model_config.head_dtype)
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None

View File

@ -637,14 +637,14 @@ class GteNewForSequenceClassification(nn.Module, SupportsCrossEncoding):
self.new = GteNewModel(vllm_config=vllm_config,
prefix=prefix,
add_pooling_layer=True)
self.classifier = RowParallelLinear(config.hidden_size,
config.num_labels,
input_is_parallel=False,
bias=True,
quant_config=quant_config,
prefix=maybe_prefix(
prefix, "classifier"),
return_bias=False)
self.classifier = ReplicatedLinear(
config.hidden_size,
config.num_labels,
bias=True,
quant_config=quant_config,
params_dtype=vllm_config.model_config.head_dtype,
prefix=maybe_prefix(prefix, "classifier"),
return_bias=False)
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None

View File

@ -339,7 +339,10 @@ class GPT2ForSequenceClassification(nn.Module):
config = vllm_config.model_config.hf_config
self.transformer = GPT2Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "gpt2"))
self.score = nn.Linear(config.n_embd, config.num_labels, bias=False)
self.score = nn.Linear(config.n_embd,
config.num_labels,
bias=False,
dtype=vllm_config.model_config.head_dtype)
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
@ -348,7 +351,7 @@ class GPT2ForSequenceClassification(nn.Module):
"encode":
Pooler.for_encode(pooler_config),
"classify":
Pooler.for_classify(pooler_config, classifier=None),
Pooler.for_classify(pooler_config, classifier=self.score),
})
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
@ -367,8 +370,7 @@ class GPT2ForSequenceClassification(nn.Module):
position_ids=positions,
inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors)
logits = self.score(hidden_states)
return logits
return hidden_states
def _add_transformer_prefix(

View File

@ -423,13 +423,15 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
delattr(self, attr)
config = vllm_config.model_config.hf_config
self.v_head = RowParallelLinear(
config.hidden_size,
1,
bias=False,
input_is_parallel=False,
prefix=maybe_prefix(prefix, "v_head"),
)
self.head_dtype = vllm_config.model_config.head_dtype
self.v_head = RowParallelLinear(config.hidden_size,
1,
bias=False,
input_is_parallel=False,
params_dtype=self.head_dtype,
prefix=maybe_prefix(prefix, "v_head"),
return_bias=False)
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
@ -446,5 +448,6 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
logits, _ = self.v_head(hidden_states)
hidden_states = hidden_states.to(self.head_dtype)
logits = self.v_head(hidden_states)
return logits

View File

@ -613,7 +613,7 @@ class JambaForSequenceClassification(JambaForCausalLM):
config.hidden_size,
num_labels,
bias=score_bias,
dtype=torch.float32,
dtype=vllm_config.model_config.head_dtype,
)
pooler_config = vllm_config.model_config.pooler_config

View File

@ -5,9 +5,9 @@ from typing import Optional
import torch
import torch.nn as nn
from transformers import BatchFeature, PretrainedConfig
from transformers import BatchFeature
from vllm.config import VllmConfig
from vllm.config import ModelConfig, VllmConfig
from vllm.inputs import TokensPrompt
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@ -28,13 +28,17 @@ logger = init_logger(__name__)
class JinaVLScorer(nn.Module):
def __init__(self, config: PretrainedConfig):
def __init__(self, model_config: "ModelConfig"):
super().__init__()
config = model_config.hf_config
head_dtype = model_config.head_dtype
self.dense = ColumnParallelLinear(config.hidden_size,
config.hidden_size,
params_dtype=head_dtype,
bias=True)
self.out_proj = RowParallelLinear(config.hidden_size,
config.num_labels,
params_dtype=head_dtype,
bias=True)
def forward(self, x, **kwargs):
@ -88,11 +92,10 @@ class JinaVLForSequenceClassification(Qwen2VLForConditionalGeneration,
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "qwen2_vl"))
config = vllm_config.model_config.hf_config
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.score = JinaVLScorer(config)
self.score = JinaVLScorer(vllm_config.model_config)
self.pooler = DispatchPooler({
"encode":
Pooler.for_encode(pooler_config),

View File

@ -306,7 +306,9 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
self.config = config
self.model = ModernBertModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "modernbert"))
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
self.classifier = nn.Linear(config.hidden_size,
config.num_labels,
dtype=vllm_config.model_config.head_dtype)
self.pooling = ModernBertPooler(config)
pooler_config = vllm_config.model_config.pooler_config

View File

@ -53,15 +53,18 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
self.quant_config = quant_config
self.model = Qwen2Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.head_dtype = vllm_config.model_config.head_dtype
self.score = nn.Sequential(
ColumnParallelLinear(config.hidden_size,
config.hidden_size,
quant_config=quant_config,
params_dtype=self.head_dtype,
return_bias=False),
nn.ReLU(),
RowParallelLinear(config.hidden_size,
config.num_labels,
params_dtype=self.head_dtype,
quant_config=quant_config,
return_bias=False),
)
@ -80,6 +83,7 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
hidden_states = hidden_states.to(self.head_dtype)
logits = self.score(hidden_states)
return logits
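The per-model changes in this commit all follow the same pattern shown here for Qwen2RewardBaseModel: build the scoring head in head_dtype and cast the hidden states to that dtype right before the head runs. A minimal, self-contained sketch of that pattern (the class name and sizes are illustrative, not vLLM code):

import torch
import torch.nn as nn

class TinyRewardHead(nn.Module):
    def __init__(self, hidden_size: int, head_dtype: torch.dtype):
        super().__init__()
        self.head_dtype = head_dtype
        # The head itself is created directly in head_dtype (fp32 by default).
        self.score = nn.Linear(hidden_size, 1, bias=False, dtype=head_dtype)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Cast the (possibly fp16/bf16) hidden states before the head runs.
        return self.score(hidden_states.to(self.head_dtype))

head = TinyRewardHead(hidden_size=64, head_dtype=torch.float32)
out = head(torch.randn(2, 64, dtype=torch.float16))
assert out.dtype is torch.float32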

View File

@ -8,7 +8,7 @@ import torch
from torch import nn
from transformers import RobertaConfig
from vllm.config import VllmConfig
from vllm.config import ModelConfig, VllmConfig
from vllm.model_executor.layers.pooler import (ClassifierPooler, CLSPool,
DispatchPooler, Pooler)
from vllm.model_executor.layers.vocab_parallel_embedding import (
@ -73,10 +73,16 @@ class RobertaEmbedding(nn.Module):
class RobertaClassificationHead(nn.Module):
"""Head for sentence-level classification tasks."""
def __init__(self, config: RobertaConfig):
def __init__(self, model_config: "ModelConfig"):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
config = model_config.hf_config
head_dtype = model_config.head_dtype
self.dense = nn.Linear(config.hidden_size,
config.hidden_size,
dtype=head_dtype)
self.out_proj = nn.Linear(config.hidden_size,
config.num_labels,
dtype=head_dtype)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# CLSPool has already been applied in `pooling`
@ -184,7 +190,7 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
self.roberta = BertModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "bert"),
embedding_class=RobertaEmbedding)
self.classifier = RobertaClassificationHead(config)
self.classifier = RobertaClassificationHead(vllm_config.model_config)
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None