[Model] Systematic support for fp32 head, pooling models part (#23810)
Signed-off-by: wang.yuqi <noooop@126.com>
@@ -9,6 +9,7 @@ import mteb
import numpy as np
import pytest
import requests
import torch

from tests.models.utils import (EmbedModelInfo, RerankModelInfo,
                                check_embeddings_close)
@@ -165,16 +166,19 @@ def mteb_test_embed_models(hf_runner,
                           vllm_extra_kwargs=None,
                           hf_model_callback=None,
                           atol=MTEB_EMBED_TOL):
    # A model family has many models with the same architecture,
    # and we don't need to test each one.
    if not model_info.enable_test:
        # A model family has many models with the same architecture,
        # and we don't need to test each one.
        pytest.skip("Skipping test.")

    example_prompts = ["The chef prepared a delicious meal."]
    # Test embed_dims, isnan and whether to use normalize
    example_prompts = ["The chef prepared a delicious meal." * 1000]

    # Allow vllm to test using the given dtype, such as float32
    vllm_extra_kwargs = vllm_extra_kwargs or {}
    vllm_extra_kwargs["dtype"] = model_info.dtype

    # Allow vllm to test using hf_overrides
    if model_info.hf_overrides is not None:
        vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides
@@ -186,21 +190,32 @@ def mteb_test_embed_models(hf_runner,

        model_config = vllm_model.llm.llm_engine.model_config

        # Confirm whether vllm is using the correct architecture
        if model_info.architecture:
            assert model_info.architecture in model_config.architectures

        # Confirm whether vllm uses the correct default_pooling_type, which
        # relates to whether chunked prefill and prefix caching are enabled
        assert (model_config._model_info.default_pooling_type ==
                model_info.default_pooling_type)

        vllm_main_score = run_mteb_embed_task(VllmMtebEncoder(vllm_model),
                                              MTEB_EMBED_TASKS)
        vllm_dtype = vllm_model.llm.llm_engine.model_config.dtype
        vllm_outputs = vllm_model.embed(example_prompts)

        # Test embed_dims, isnan and whether to use normalize
        vllm_outputs = vllm_model.embed(example_prompts,
                                        truncate_prompt_tokens=-1)
        assert not torch.any(torch.isnan(torch.tensor(vllm_outputs)))

    # Accelerate mteb test by setting
    # SentenceTransformers mteb score to a constant
    if model_info.mteb_score is None:
        with hf_runner(model_info.name,
                       is_sentence_transformer=True,
                       dtype="float32") as hf_model:

            # e.g. setting default parameters for the encode method of hf_runner
            if hf_model_callback is not None:
                hf_model_callback(hf_model)
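For context, this is roughly how a test module opts into the constant-score fast path used above; a hedged sketch in which the model entry and the score value are made up, with field names taken from how model_info is used in this helper:

from tests.models.utils import EmbedModelInfo

# Hypothetical entry: once a reference MTEB score has been recorded,
# setting mteb_score here lets mteb_test_embed_models skip the
# SentenceTransformers baseline run and compare against the constant.
EMBEDDING_MODELS = [
    EmbedModelInfo("BAAI/bge-m3",
                   architecture="XLMRobertaModel",
                   mteb_score=0.7873,
                   enable_test=True),
]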
@@ -299,14 +314,16 @@ def mteb_test_rerank_models(hf_runner,
                            hf_model_callback=None,
                            vllm_mteb_encoder=VllmMtebEncoder,
                            atol=MTEB_RERANK_TOL):
    # A model family has many models with the same architecture,
    # and we don't need to test each one.
    if not model_info.enable_test:
        # A model family has many models with the same architecture,
        # and we don't need to test each one.
        pytest.skip("Skipping test.")

    # Allow vllm to test using the given dtype, such as float32
    vllm_extra_kwargs = vllm_extra_kwargs or {}
    vllm_extra_kwargs["dtype"] = model_info.dtype

    # Allow vllm to test using hf_overrides
    if model_info.hf_overrides is not None:
        vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides
@@ -319,9 +336,15 @@ def mteb_test_rerank_models(hf_runner,

        model_config = vllm_model.llm.llm_engine.model_config

        # Confirm whether vllm is using the correct architecture
        if model_info.architecture:
            assert (model_info.architecture in model_config.architectures)

        # Score API is only enabled for num_labels == 1
        assert model_config.hf_config.num_labels == 1

        # Confirm whether vllm uses the correct default_pooling_type, which
        # relates to whether chunked prefill and prefix caching are enabled
        assert (model_config._model_info.default_pooling_type ==
                model_info.default_pooling_type)
@@ -330,6 +353,8 @@ def mteb_test_rerank_models(hf_runner,
                                          languages=MTEB_RERANK_LANGS)
        vllm_dtype = model_config.dtype

    # Accelerate mteb test by setting
    # SentenceTransformers mteb score to a constant
    if model_info.mteb_score is None:
        st_main_score, st_dtype = mteb_test_rerank_models_hf(
            hf_runner, model_info.name, hf_model_callback)
@@ -14,6 +14,7 @@ from .mteb_utils import VllmMtebEncoder, mteb_test_rerank_models

RERANK_MODELS = [
    LASTPoolingRerankModelInfo("BAAI/bge-reranker-v2-gemma",
                               architecture="GemmaForSequenceClassification",
                               mteb_score=0.33757,
                               hf_overrides={
                                   "architectures":
                                   ["GemmaForSequenceClassification"],
@@ -745,7 +745,7 @@ class ModelConfig:

        self.pooler_config = self._init_pooler_config()

        self.dtype = _get_and_verify_dtype(
        self.dtype: torch.dtype = _get_and_verify_dtype(
            self.model,
            self.hf_config,
            self.dtype,
@@ -1751,6 +1751,32 @@ class ModelConfig:
        # `llm as reranker` models defaults to not using pad_token.
        return getattr(self.hf_config, "use_pad_token", True)

    @property
    def head_dtype(self) -> torch.dtype:
        """
        "head" refers to the last Linear layer(s) of an LLM,
        such as the lm_head in a generation model,
        or the score or classifier in a classification model.

        The default head_dtype based on runner_type.\n
        - The pooling model defaults to using fp32 head,
        you can use --hf-overrides '{"head_dtype": "model"}' to disable it.\n
        - The generate model defaults to not using fp32 head,
        you can use --hf-overrides '{"head_dtype": "float32"}' to enable it.
        """
        head_dtype = _get_head_dtype(config=self.hf_config,
                                     dtype=self.dtype,
                                     runner_type=self.runner_type)

        if head_dtype not in current_platform.supported_dtypes:
            logger.warning_once(
                "The current platform does not support [%s] head dtype, "
                "fallback to model dtype [%s].", head_dtype, self.dtype)
            return self.dtype

        logger.debug_once("head dtype: %s", head_dtype)
        return head_dtype

    def get_and_verify_max_len(self, max_model_len: int):
        # Consider max_model_len in tokenizer_config only when
        # pooling models use absolute position_embedding.
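As a usage sketch (the model name is only illustrative and the snippet is not part of this diff), the overrides described in the docstring can be passed through the offline API as well as the CLI flag shown above:

from vllm import LLM

# Keep the head in the model dtype instead of the fp32 default
# for a pooling model.
llm = LLM(model="BAAI/bge-reranker-v2-gemma",
          hf_overrides={"head_dtype": "model"})

# Conversely, a generation model can opt into an fp32 lm_head with
# hf_overrides={"head_dtype": "float32"}.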
@@ -2893,6 +2919,31 @@ def _get_and_verify_dtype(
    return torch_dtype


def _get_head_dtype(config: PretrainedConfig, dtype: torch.dtype,
                    runner_type: str) -> torch.dtype:
    head_dtype: Optional[Union[str,
                               torch.dtype]] = getattr(config, "head_dtype",
                                                       None)

    if head_dtype == "model":
        return dtype
    elif isinstance(head_dtype, str):
        head_dtype = head_dtype.lower()
        if head_dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
            raise ValueError(f"Unknown dtype: {head_dtype!r}")
        return _STR_DTYPE_TO_TORCH_DTYPE[head_dtype]
    elif isinstance(head_dtype, torch.dtype):
        return head_dtype
    elif head_dtype is None:
        if torch.float32 not in current_platform.supported_dtypes:
            return dtype
        if runner_type == "pooling":
            return torch.float32
        return dtype
    else:
        raise ValueError(f"Unknown dtype: {head_dtype}")


def _get_and_verify_max_len(
    hf_config: PretrainedConfig,
    tokenizer_config: Optional[dict],
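To summarize the resolution order, here is a small standalone sketch; it restates the branches for illustration, assumes float32 is supported on the platform, and is not the vLLM function itself:

import torch

_STR_TO_DTYPE = {"float16": torch.float16, "float32": torch.float32}


def resolve_head_dtype(head_dtype, model_dtype, runner_type):
    """Illustrative mirror of the branch order in _get_head_dtype."""
    if head_dtype == "model":
        return model_dtype                        # explicit opt-out
    if isinstance(head_dtype, str):
        return _STR_TO_DTYPE[head_dtype.lower()]  # explicit dtype name
    if isinstance(head_dtype, torch.dtype):
        return head_dtype
    # Default: pooling runners get an fp32 head, generate runners keep
    # the model dtype.
    return torch.float32 if runner_type == "pooling" else model_dtype


assert resolve_head_dtype(None, torch.bfloat16, "pooling") == torch.float32
assert resolve_head_dtype(None, torch.bfloat16, "generate") == torch.bfloat16
assert resolve_head_dtype("model", torch.bfloat16, "pooling") == torch.bfloat16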
@@ -5,7 +5,7 @@ from collections.abc import Mapping, Set
from dataclasses import dataclass
from enum import IntEnum
from itertools import groupby
from typing import Callable, Optional, TypeVar, Union, cast
from typing import Callable, Optional, TypeVar, Union

import torch
import torch.nn as nn
@@ -362,14 +362,13 @@ class PoolerIdentity(PoolerActivation):
class PoolerNormalize(PoolerActivation):

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        x = F.normalize(pooled_data.float(), p=2, dim=-1)
        return x.to(pooled_data.dtype)
        return F.normalize(pooled_data, p=2, dim=-1)


class PoolerMultiLabelClassify(PoolerActivation):

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        return F.sigmoid(pooled_data.float()).to(pooled_data.dtype)
        return F.sigmoid(pooled_data)


class PoolerClassify(PoolerActivation):
@@ -394,9 +393,9 @@ class PoolerClassify(PoolerActivation):
                                  pooled_data.shape[-1])

        if num_labels < 2:
            return F.sigmoid(pooled_data.float()).to(pooled_data.dtype)
            return F.sigmoid(pooled_data)

        return F.softmax(pooled_data.float(), dim=-1).to(pooled_data.dtype)
        return F.softmax(pooled_data, dim=-1)


class LambdaPoolerActivation(PoolerActivation):
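The .float()/.to(pooled_data.dtype) round-trips can be dropped because the head now hands these activations fp32 inputs by default. A quick numeric sketch of why that matters, with made-up logits:

import torch
import torch.nn.functional as F

logits = torch.tensor([4.1234, -2.0567, 0.3141])

probs_fp32 = F.softmax(logits, dim=-1)
probs_bf16 = F.softmax(logits.to(torch.bfloat16), dim=-1).float()

# bf16 keeps only ~8 bits of mantissa, so the class probabilities drift
# noticeably compared to the fp32 result.
print((probs_fp32 - probs_bf16).abs().max())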
@@ -432,8 +431,9 @@ class EmbeddingPoolerHead(PoolerHead):
        from vllm.model_executor.models.adapters import _load_st_projector

        vllm_config = get_current_vllm_config()
        self.projector = _load_st_projector(
        self.projector: Optional[nn.Module] = _load_st_projector(
            vllm_config.model_config) if vllm_config else None
        self.head_dtype = vllm_config.model_config.head_dtype

    def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
                pooling_metadata: PoolingMetadata):
@@ -442,16 +442,11 @@ class EmbeddingPoolerHead(PoolerHead):
            pooled_data = torch.stack(pooled_data)
        # pooled_data shape: [batchsize, hidden_dimension]

        pooled_data = pooled_data.to(self.head_dtype)

        # Apply ST projector
        if self.projector is not None:
            projector = cast(nn.Module, self.projector)

            def _proj(x: torch.Tensor) -> torch.Tensor:
                orig_dtype = x.dtype
                y = projector(x.to(torch.float32))
                return y.to(orig_dtype)

            pooled_data = _proj(pooled_data)
            pooled_data = self.projector(pooled_data)
            # pooled_data shape: [batchsize, embedding_dimension]

        pooling_params = get_pooling_params(pooling_metadata)
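Condensed, the new embedding head amounts to the flow below; a simplified paraphrase with plain torch ops standing in for the vLLM pooler plumbing, not the actual implementation:

import torch
import torch.nn.functional as F
from typing import Optional


def embed_head(pooled: torch.Tensor,
               head_dtype: torch.dtype,
               projector: Optional[torch.nn.Module],
               normalize: bool) -> torch.Tensor:
    # [batch, hidden] activations arrive in the model dtype and are
    # promoted once to the head dtype (fp32 by default for pooling).
    pooled = pooled.to(head_dtype)
    if projector is not None:
        # The ST projector is now built in head_dtype, so no per-call casts.
        pooled = projector(pooled)
    return F.normalize(pooled, p=2, dim=-1) if normalize else pooled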
@@ -494,8 +489,18 @@ class RewardPoolerHead(PoolerHead):
    def __init__(self) -> None:
        super().__init__(activation=PoolerClassify(static_num_labels=False))

        from vllm.config import get_current_vllm_config
        vllm_config = get_current_vllm_config()
        self.head_dtype = vllm_config.model_config.head_dtype

    def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
                pooling_metadata: PoolingMetadata):

        if isinstance(pooled_data, list):
            pooled_data = [p.to(self.head_dtype) for p in pooled_data]
        else:
            pooled_data = pooled_data.to(self.head_dtype)

        pooling_params = get_pooling_params(pooling_metadata)

        # for softmax
@@ -641,6 +646,7 @@ class ClassifierPooler(Pooler):
        self.act_fn = act_fn or PoolerClassify()
        self.logit_bias: Optional[
            float] = vllm_config.model_config.pooler_config.logit_bias
        self.head_dtype = vllm_config.model_config.head_dtype

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"classify", "score"}
@@ -655,6 +661,8 @@ class ClassifierPooler(Pooler):
            pooled_data = torch.stack(pooled_data)
        # pooled_data shape: [batchsize, hidden_size]

        pooled_data = pooled_data.to(self.head_dtype)

        if self.classifier is not None:
            pooled_data = self.classifier(pooled_data)
        # pooled_data shape: [batchsize, num_labels]
@@ -62,7 +62,7 @@ def _load_st_projector(model_config: "ModelConfig") -> Optional[nn.Module]:
            linear = nn.Linear(layer_config.get("in_features", 768),
                               layer_config.get("out_features", 768),
                               bias=layer_config.get("bias", True),
                               dtype=torch.float32)
                               dtype=model_config.head_dtype)

            if not _load_dense_weights(linear, folder, model_config):
                continue
@@ -70,7 +70,7 @@ def _load_st_projector(model_config: "ModelConfig") -> Optional[nn.Module]:
            layers.append(linear)
            if act_name := layer_config.get("activation_function"):
                layers.append(get_act_fn(act_name))
        return nn.Sequential(*layers).to(dtype=torch.float32)
        return nn.Sequential(*layers).to(dtype=model_config.head_dtype)
    except Exception:
        logger.exception("ST projector loading failed")
@@ -105,15 +105,13 @@ def _load_dense_weights(linear: nn.Linear, folder: str,
        if weight_key in state_dict:
            weight_loader = getattr(linear.weight, "weight_loader",
                                    default_weight_loader)
            weight_loader(linear.weight,
                          state_dict[weight_key].to(torch.float32))
            weight_loader(linear.weight, state_dict[weight_key])

            bias_key = weight_key.replace("weight", "bias")
            if linear.bias is not None and bias_key in state_dict:
                bias_loader = getattr(linear.bias, "weight_loader",
                                      default_weight_loader)
                bias_loader(linear.bias,
                            state_dict[bias_key].to(torch.float32))
                bias_loader(linear.bias, state_dict[bias_key])
        return True
    except Exception:
        logger.exception("Failed to load %s", filename)
@@ -562,7 +562,9 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding,
        self.bert = BertPoolingModel(vllm_config=vllm_config,
                                     prefix=maybe_prefix(prefix, "bert"),
                                     embedding_class=BertEmbedding)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.classifier = nn.Linear(config.hidden_size,
                                    config.num_labels,
                                    dtype=vllm_config.model_config.head_dtype)

        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None
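The same two-step pattern repeats in the model files below: build the head layer in model_config.head_dtype and cast the hidden states right before it. A minimal sketch with plain nn.Linear and hypothetical sizes:

import torch
import torch.nn as nn

hidden_size, num_labels = 768, 2
head_dtype = torch.float32  # what head_dtype resolves to for pooling models

classifier = nn.Linear(hidden_size, num_labels, dtype=head_dtype)

hidden_states = torch.randn(4, hidden_size, dtype=torch.bfloat16)
logits = classifier(hidden_states.to(head_dtype))  # fp32 logits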
@@ -637,14 +637,14 @@ class GteNewForSequenceClassification(nn.Module, SupportsCrossEncoding):
        self.new = GteNewModel(vllm_config=vllm_config,
                               prefix=prefix,
                               add_pooling_layer=True)
        self.classifier = RowParallelLinear(config.hidden_size,
                                            config.num_labels,
                                            input_is_parallel=False,
                                            bias=True,
                                            quant_config=quant_config,
                                            prefix=maybe_prefix(
                                                prefix, "classifier"),
                                            return_bias=False)
        self.classifier = ReplicatedLinear(
            config.hidden_size,
            config.num_labels,
            bias=True,
            quant_config=quant_config,
            params_dtype=vllm_config.model_config.head_dtype,
            prefix=maybe_prefix(prefix, "classifier"),
            return_bias=False)

        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None
@@ -339,7 +339,10 @@ class GPT2ForSequenceClassification(nn.Module):
        config = vllm_config.model_config.hf_config
        self.transformer = GPT2Model(vllm_config=vllm_config,
                                     prefix=maybe_prefix(prefix, "gpt2"))
        self.score = nn.Linear(config.n_embd, config.num_labels, bias=False)
        self.score = nn.Linear(config.n_embd,
                               config.num_labels,
                               bias=False,
                               dtype=vllm_config.model_config.head_dtype)

        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None
@@ -348,7 +351,7 @@ class GPT2ForSequenceClassification(nn.Module):
            "encode":
            Pooler.for_encode(pooler_config),
            "classify":
            Pooler.for_classify(pooler_config, classifier=None),
            Pooler.for_classify(pooler_config, classifier=self.score),
        })

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
@@ -367,8 +370,7 @@ class GPT2ForSequenceClassification(nn.Module):
                                         position_ids=positions,
                                         inputs_embeds=inputs_embeds,
                                         intermediate_tensors=intermediate_tensors)
        logits = self.score(hidden_states)
        return logits
        return hidden_states


def _add_transformer_prefix(
@@ -423,13 +423,15 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
            delattr(self, attr)

        config = vllm_config.model_config.hf_config
        self.v_head = RowParallelLinear(
            config.hidden_size,
            1,
            bias=False,
            input_is_parallel=False,
            prefix=maybe_prefix(prefix, "v_head"),
        )
        self.head_dtype = vllm_config.model_config.head_dtype

        self.v_head = RowParallelLinear(config.hidden_size,
                                        1,
                                        bias=False,
                                        input_is_parallel=False,
                                        params_dtype=self.head_dtype,
                                        prefix=maybe_prefix(prefix, "v_head"),
                                        return_bias=False)

        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None
@@ -446,5 +448,6 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
    ) -> Union[torch.Tensor, IntermediateTensors]:
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
        logits, _ = self.v_head(hidden_states)
        hidden_states = hidden_states.to(self.head_dtype)
        logits = self.v_head(hidden_states)
        return logits
@@ -613,7 +613,7 @@ class JambaForSequenceClassification(JambaForCausalLM):
            config.hidden_size,
            num_labels,
            bias=score_bias,
            dtype=torch.float32,
            dtype=vllm_config.model_config.head_dtype,
        )

        pooler_config = vllm_config.model_config.pooler_config
@@ -5,9 +5,9 @@ from typing import Optional

import torch
import torch.nn as nn
from transformers import BatchFeature, PretrainedConfig
from transformers import BatchFeature

from vllm.config import VllmConfig
from vllm.config import ModelConfig, VllmConfig
from vllm.inputs import TokensPrompt
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -28,13 +28,17 @@ logger = init_logger(__name__)

class JinaVLScorer(nn.Module):

    def __init__(self, config: PretrainedConfig):
    def __init__(self, model_config: "ModelConfig"):
        super().__init__()
        config = model_config.hf_config
        head_dtype = model_config.head_dtype
        self.dense = ColumnParallelLinear(config.hidden_size,
                                          config.hidden_size,
                                          params_dtype=head_dtype,
                                          bias=True)
        self.out_proj = RowParallelLinear(config.hidden_size,
                                          config.num_labels,
                                          params_dtype=head_dtype,
                                          bias=True)

    def forward(self, x, **kwargs):
@@ -88,11 +92,10 @@ class JinaVLForSequenceClassification(Qwen2VLForConditionalGeneration,
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config,
                         prefix=maybe_prefix(prefix, "qwen2_vl"))
        config = vllm_config.model_config.hf_config
        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None

        self.score = JinaVLScorer(config)
        self.score = JinaVLScorer(vllm_config.model_config)
        self.pooler = DispatchPooler({
            "encode":
            Pooler.for_encode(pooler_config),
@@ -306,7 +306,9 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
        self.config = config
        self.model = ModernBertModel(vllm_config=vllm_config,
                                     prefix=maybe_prefix(prefix, "modernbert"))
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.classifier = nn.Linear(config.hidden_size,
                                    config.num_labels,
                                    dtype=vllm_config.model_config.head_dtype)
        self.pooling = ModernBertPooler(config)

        pooler_config = vllm_config.model_config.pooler_config
@@ -53,15 +53,18 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
        self.quant_config = quant_config
        self.model = Qwen2Model(vllm_config=vllm_config,
                                prefix=maybe_prefix(prefix, "model"))
        self.head_dtype = vllm_config.model_config.head_dtype

        self.score = nn.Sequential(
            ColumnParallelLinear(config.hidden_size,
                                 config.hidden_size,
                                 quant_config=quant_config,
                                 params_dtype=self.head_dtype,
                                 return_bias=False),
            nn.ReLU(),
            RowParallelLinear(config.hidden_size,
                              config.num_labels,
                              params_dtype=self.head_dtype,
                              quant_config=quant_config,
                              return_bias=False),
        )
@@ -80,6 +83,7 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
    ) -> Union[torch.Tensor, IntermediateTensors]:
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
        hidden_states = hidden_states.to(self.head_dtype)
        logits = self.score(hidden_states)
        return logits
@@ -8,7 +8,7 @@ import torch
from torch import nn
from transformers import RobertaConfig

from vllm.config import VllmConfig
from vllm.config import ModelConfig, VllmConfig
from vllm.model_executor.layers.pooler import (ClassifierPooler, CLSPool,
                                               DispatchPooler, Pooler)
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -73,10 +73,16 @@ class RobertaEmbedding(nn.Module):
class RobertaClassificationHead(nn.Module):
    """Head for sentence-level classification tasks."""

    def __init__(self, config: RobertaConfig):
    def __init__(self, model_config: "ModelConfig"):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
        config = model_config.hf_config
        head_dtype = model_config.head_dtype
        self.dense = nn.Linear(config.hidden_size,
                               config.hidden_size,
                               dtype=head_dtype)
        self.out_proj = nn.Linear(config.hidden_size,
                                  config.num_labels,
                                  dtype=head_dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # CLSPool has already been applied in `pooling`
@@ -184,7 +190,7 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
        self.roberta = BertModel(vllm_config=vllm_config,
                                 prefix=maybe_prefix(prefix, "bert"),
                                 embedding_class=RobertaEmbedding)
        self.classifier = RobertaClassificationHead(config)
        self.classifier = RobertaClassificationHead(vllm_config.model_config)

        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None