Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
[New Model]: Support Qwen3 Embedding & Reranker (#19260)
@@ -387,18 +387,19 @@ See [this page](./pooling_models.md) for more information on how to use pooling

Specified using `--task embed`.

| Architecture | Models | Example HF Models | [LoRA][lora-adapter] | [PP][distributed-serving] |
|--------------------------------------------------------|---------------------|---------------------------------------------------------------------------------------------------------------------|----------------------|---------------------------|
| `BertModel` | BERT-based | `BAAI/bge-base-en-v1.5`, `Snowflake/snowflake-arctic-embed-xs`, etc. | | |
| `Gemma2Model` | Gemma 2-based | `BAAI/bge-multilingual-gemma2`, etc. | ✅︎ | |
| `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm` | ✅︎ | ✅︎ |
| `GteModel` | Arctic-Embed-2.0-M | `Snowflake/snowflake-arctic-embed-m-v2.0` | | |
| `GteNewModel` | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-base`, etc. | | |
| `ModernBertModel` | ModernBERT-based | `Alibaba-NLP/gte-modernbert-base`, etc. | | |
| `NomicBertModel` | Nomic BERT | `nomic-ai/nomic-embed-text-v1`, `nomic-ai/nomic-embed-text-v2-moe`, `Snowflake/snowflake-arctic-embed-m-long`, etc. | | |
| `LlamaModel`, `LlamaForCausalLM`, `MistralModel`, etc. | Llama-based | `intfloat/e5-mistral-7b-instruct`, etc. | ✅︎ | ✅︎ |
| `Qwen2Model`, `Qwen2ForCausalLM` | Qwen2-based | `ssmits/Qwen2-7B-Instruct-embed-base` (see note), `Alibaba-NLP/gte-Qwen2-7B-instruct` (see note), etc. | ✅︎ | ✅︎ |
| `Qwen3Model`, `Qwen3ForCausalLM` | Qwen3-based | `Qwen/Qwen3-Embedding-0.6B`, etc. | ✅︎ | ✅︎ |
| `RobertaModel`, `RobertaForMaskedLM` | RoBERTa-based | `sentence-transformers/all-roberta-large-v1`, etc. | | |
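
These embedding models can be used directly with the offline pooling API. A minimal sketch (model choice and prompt are illustrative):

```python
# Minimal sketch: embed a prompt with the newly supported Qwen3 embedding model.
from vllm import LLM

model = LLM(model="Qwen/Qwen3-Embedding-0.6B", task="embed")

(output,) = model.embed(["The capital of China is Beijing."])
print(len(output.outputs.embedding))  # embedding dimensionality
```
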
!!! note
    `ssmits/Qwen2-7B-Instruct-embed-base` has an improperly defined Sentence Transformers config.

@@ -450,12 +451,19 @@ If your model is not in the above list, we will try to automatically convert the

Specified using `--task score`.

| Architecture | Models | Example HF Models |
|---------------------------------------|-------------------|--------------------------------------------------------------------------------------|
| `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. |
| `Qwen3ForSequenceClassification` | Qwen3-based | `tomaarsen/Qwen3-Reranker-0.6B-seq-cls`, `Qwen/Qwen3-Reranker-0.6B` (see note), etc. |
| `RobertaForSequenceClassification` | RoBERTa-based | `cross-encoder/quora-roberta-base`, etc. |
| `XLMRobertaForSequenceClassification` | XLM-RoBERTa-based | `BAAI/bge-reranker-v2-m3`, etc. |
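
The converted Qwen3 reranker checkpoint works with the plain offline score API; a minimal sketch mirroring the test suite:

```python
# Minimal sketch: score query/document pairs with the converted checkpoint.
from vllm import LLM

model = LLM(model="tomaarsen/Qwen3-Reranker-0.6B-seq-cls", task="score")

outputs = model.score(
    "What is the capital of France?",
    ["The capital of Brazil is Brasilia.", "The capital of France is Paris."],
)
print([output.outputs.score for output in outputs])
```
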
!!! note
    Load the official original `Qwen3 Reranker` with the following command. More information can be found at: <gh-file:examples/offline_inference/qwen3_reranker.py>.

    ```bash
    vllm serve Qwen/Qwen3-Reranker-0.6B --hf_overrides '{"architectures": ["Qwen3ForSequenceClassification"],"classifier_from_token": ["no", "yes"],"is_original_qwen3_reranker": true}'
    ```
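
Once the server is up, the reranker can be queried over HTTP. A minimal client sketch, assuming the default port 8000 and vLLM's `/score` route (the texts are illustrative):

```python
# Hypothetical client for the served reranker; port, route, and field
# names are assumptions based on vLLM's score API.
import requests

response = requests.post(
    "http://localhost:8000/score",
    json={
        "model": "Qwen/Qwen3-Reranker-0.6B",
        "text_1": "What is the capital of China?",
        "text_2": ["The capital of China is Beijing.", "Explain gravity"],
    },
)
print(response.json())
```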

[](){ #supported-mm-models }

## List of Multimodal Language Models

examples/offline_inference/qwen3_reranker.py (new file, 77 lines)
@@ -0,0 +1,77 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501

from vllm import LLM

model_name = "Qwen/Qwen3-Reranker-0.6B"

# What is the difference between the official original version and one
# that has been converted into a sequence classification model?
# Qwen3-Reranker is a language model that performs reranking using the
# logits of the "no" and "yes" tokens.
# It needs to compute the logits of 151669 tokens, making this method
# extremely inefficient, not to mention incompatible with the vllm score API.
# A method for converting the original model into a sequence classification
# model was proposed. See: https://huggingface.co/Qwen/Qwen3-Reranker-0.6B/discussions/3
# Models converted offline with this method are not only more efficient
# and compatible with the vllm score API, but also make the init parameters
# more concise, for example:
# model = LLM(model="tomaarsen/Qwen3-Reranker-0.6B-seq-cls", task="score")

# If you want to load the official original version, the init parameters are
# as follows.

model = LLM(
    model=model_name,
    task="score",
    hf_overrides={
        "architectures": ["Qwen3ForSequenceClassification"],
        "classifier_from_token": ["no", "yes"],
        "is_original_qwen3_reranker": True,
    },
)

# Why hf_overrides is needed for the official original version:
# vllm converts it to Qwen3ForSequenceClassification when loaded for
# better performance.
# - First, `"architectures": ["Qwen3ForSequenceClassification"]` manually
#   routes the model to Qwen3ForSequenceClassification.
# - Then, the vectors corresponding to classifier_from_token are extracted
#   from lm_head via `"classifier_from_token": ["no", "yes"]`.
# - Third, these two vectors are converted into a single vector; this
#   conversion logic is enabled by `"is_original_qwen3_reranker": True`.
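# In effect, the converted checkpoint stores a single classifier vector:
#   score.weight = lm_head.weight["yes"] - lm_head.weight["no"]
# so one scalar logit carries the same information as comparing the two
# token logits, mirroring the subtraction performed at load time.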

# Please use the query_template and document_template to format the query and
# document for better reranker results.

prefix = '<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be "yes" or "no".<|im_end|>\n<|im_start|>user\n'
suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"

query_template = "{prefix}<Instruct>: {instruction}\n<Query>: {query}\n"
document_template = "<Document>: {doc}{suffix}"

if __name__ == "__main__":
    instruction = (
        "Given a web search query, retrieve relevant passages that answer the query"
    )

    queries = [
        "What is the capital of China?",
        "Explain gravity",
    ]

    documents = [
        "The capital of China is Beijing.",
        "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun.",
    ]

    queries = [
        query_template.format(prefix=prefix, instruction=instruction, query=query)
        for query in queries
    ]
    documents = [document_template.format(doc=doc, suffix=suffix) for doc in documents]

    outputs = model.score(queries, documents)

    print([output.outputs.score for output in outputs])

@@ -45,6 +45,15 @@ MODELS = [
    EmbedModelInfo("Alibaba-NLP/gte-modernbert-base",
                   architecture="ModernBertModel",
                   enable_test=True),
    ########## Qwen3ForCausalLM
    EmbedModelInfo("Qwen/Qwen3-Embedding-0.6B",
                   architecture="Qwen3ForCausalLM",
                   dtype="float32",
                   enable_test=True),
    EmbedModelInfo("Qwen/Qwen3-Embedding-4B",
                   architecture="Qwen3ForCausalLM",
                   dtype="float32",
                   enable_test=False),
]

tests/models/language/pooling/test_qwen3_reranker.py (new file, 87 lines)
@@ -0,0 +1,87 @@
# SPDX-License-Identifier: Apache-2.0
import pytest

model_name = "Qwen/Qwen3-Reranker-4B"

text_1 = "What is the capital of France?"
texts_2 = [
    "The capital of Brazil is Brasilia.",
    "The capital of France is Paris.",
]


def vllm_reranker(model_name):
    from vllm import LLM

    model = LLM(model=model_name,
                task="score",
                hf_overrides={
                    "architectures": ["Qwen3ForSequenceClassification"],
                    "classifier_from_token": ["no", "yes"],
                    "is_original_qwen3_reranker": True,
                },
                dtype="float32")

    outputs = model.score(text_1, texts_2)

    return [output.outputs.score for output in outputs]


def hf_reranker(model_name):
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left')
    model = AutoModelForCausalLM.from_pretrained(model_name).eval()

    token_false_id = tokenizer.convert_tokens_to_ids("no")
    token_true_id = tokenizer.convert_tokens_to_ids("yes")

    max_length = 8192

    def process_inputs(pairs):
        inputs = tokenizer(pairs,
                           padding=False,
                           truncation='longest_first',
                           return_attention_mask=False,
                           max_length=max_length)
        inputs = tokenizer.pad(inputs,
                               padding=True,
                               return_tensors="pt",
                               max_length=max_length)
        for key in inputs:
            inputs[key] = inputs[key].to(model.device)
        return inputs

    @torch.no_grad()
    def compute_logits(inputs):
        batch_scores = model(**inputs).logits[:, -1, :]
        true_vector = batch_scores[:, token_true_id]
        false_vector = batch_scores[:, token_false_id]
        batch_scores = torch.stack([false_vector, true_vector], dim=1)
        batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
        scores = batch_scores[:, 1].exp().tolist()
        return scores

    pairs = [(text_1, texts_2[0]), (text_1, texts_2[1])]
    inputs = process_inputs(pairs)
    scores = compute_logits(inputs)

    return scores


@pytest.mark.parametrize("model_name", [model_name])
def test_model(model_name):
    hf_outputs = hf_reranker(model_name)
    vllm_outputs = vllm_reranker(model_name)

    assert hf_outputs[0] == pytest.approx(vllm_outputs[0], rel=0.01)
    assert hf_outputs[1] == pytest.approx(vllm_outputs[1], rel=0.01)

tests/models/language/pooling/test_qwen3_reranker_seq_cls.py (new file, 73 lines)
@@ -0,0 +1,73 @@
# SPDX-License-Identifier: Apache-2.0
import pytest

model_name = "tomaarsen/Qwen3-Reranker-0.6B-seq-cls"

text_1 = "What is the capital of France?"
texts_2 = [
    "The capital of Brazil is Brasilia.",
    "The capital of France is Paris.",
]


def vllm_reranker(model_name):
    from vllm import LLM

    model = LLM(model=model_name, task="score")
    outputs = model.score(text_1, texts_2)

    return [output.outputs.score for output in outputs]


def hf_reranker(model_name):
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left')
    model = AutoModelForCausalLM.from_pretrained(model_name).eval()

    token_false_id = tokenizer.convert_tokens_to_ids("no")
    token_true_id = tokenizer.convert_tokens_to_ids("yes")

    max_length = 8192

    def process_inputs(pairs):
        inputs = tokenizer(pairs,
                           padding=False,
                           truncation='longest_first',
                           return_attention_mask=False,
                           max_length=max_length)
        inputs = tokenizer.pad(inputs,
                               padding=True,
                               return_tensors="pt",
                               max_length=max_length)
        for key in inputs:
            inputs[key] = inputs[key].to(model.device)
        return inputs

    @torch.no_grad()
    def compute_logits(inputs):
        batch_scores = model(**inputs).logits[:, -1, :]
        true_vector = batch_scores[:, token_true_id]
        false_vector = batch_scores[:, token_false_id]
        batch_scores = torch.stack([false_vector, true_vector], dim=1)
        batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
        scores = batch_scores[:, 1].exp().tolist()
        return scores

    pairs = [(text_1, texts_2[0]), (text_1, texts_2[1])]
    inputs = process_inputs(pairs)
    scores = compute_logits(inputs)

    return scores


@pytest.mark.parametrize("model_name", [model_name])
def test_model(model_name):
    hf_outputs = hf_reranker(model_name)
    vllm_outputs = vllm_reranker(model_name)

    assert hf_outputs[0] == pytest.approx(vllm_outputs[0], rel=0.01)
    assert hf_outputs[1] == pytest.approx(vllm_outputs[1], rel=0.01)

@@ -238,6 +238,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
    "Qwen2MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen1.5-MoE-A2.7B-Chat"),
    "Qwen3ForCausalLM": _HfExamplesInfo("Qwen/Qwen3-8B"),
    "Qwen3MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen3-30B-A3B"),
    "Qwen3ForSequenceClassification": _HfExamplesInfo("tomaarsen/Qwen3-Reranker-0.6B-seq-cls"),  # noqa: E501
    "RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b"),
    "StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b",  # noqa: E501
                                                v0_only=True),

@@ -38,13 +38,15 @@ from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput

from .interfaces import SupportsCrossEncoding, SupportsLoRA, SupportsPP
from .qwen2 import Qwen2MLP as Qwen3MLP
from .qwen2 import Qwen2Model
from .utils import AutoWeightsLoader, PPMissingLayer, maybe_prefix

@@ -319,3 +321,122 @@ class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
                           if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(weights)


class Qwen3ForSequenceClassification(nn.Module, SupportsLoRA,
                                     SupportsCrossEncoding):

    def __init__(
        self,
        vllm_config: "VllmConfig",
        prefix: str = "",
    ) -> None:
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        pooler_config = vllm_config.model_config.pooler_config

        self.vllm_config = vllm_config
        self.config = config
        self.quant_config = quant_config
        self.prefix = prefix
        self.model = Qwen3Model(vllm_config=vllm_config,
                                prefix=maybe_prefix(prefix, "model"))
        self.score = RowParallelLinear(config.hidden_size,
                                       config.num_labels,
                                       quant_config=quant_config,
                                       input_is_parallel=False,
                                       bias=False,
                                       prefix=maybe_prefix(prefix, "score"))

        self._pooler = Pooler.from_config_with_defaults(
            pooler_config,
            pooling_type=PoolingType.LAST,
            normalize=False,
            softmax=True)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        return self.model(input_ids=input_ids,
                          positions=positions,
                          inputs_embeds=inputs_embeds,
                          intermediate_tensors=intermediate_tensors)

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        hidden_states = self._pooler.extract_states(hidden_states,
                                                    pooling_metadata)
        logits, _ = self.score(hidden_states)
        pooled_data = self._pooler.head(logits, pooling_metadata)
        pooled_outputs = [
            self._pooler.build_output(data.squeeze(-1)) for data in pooled_data
        ]
        return PoolerOutput(outputs=pooled_outputs)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        is_original_qwen3_reranker = getattr(self.config,
                                             "is_original_qwen3_reranker",
                                             False)

        if not is_original_qwen3_reranker:
            loader = AutoWeightsLoader(self)
            return loader.load_weights(weights)

        return self.load_weights_from_original_qwen3_reranker(weights)

    def load_weights_from_original_qwen3_reranker(
            self, weights: Iterable[tuple[str, torch.Tensor]]):
        tokens = getattr(self.config, "classifier_from_token", None)
        assert tokens is not None and len(tokens) == 2, \
            ("Trying to load the original Qwen3 Reranker? See: "
             "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/qwen3_reranker.py")

        self.config.num_labels = 1
        model_config = self.vllm_config.model_config

        device = self.score.weight.device
        self.score = RowParallelLinear(self.config.hidden_size,
                                       self.config.num_labels,
                                       quant_config=self.quant_config,
                                       input_is_parallel=False,
                                       bias=False,
                                       prefix=maybe_prefix(
                                           self.prefix, "score")).to(device)

        if self.config.tie_word_embeddings:
            self.lm_head = self.model.embed_tokens
        else:
            self.lm_head = ParallelLMHead(self.config.vocab_size,
                                          self.config.hidden_size,
                                          quant_config=self.quant_config,
                                          prefix=maybe_prefix(
                                              self.prefix, "lm_head"))

        loader = AutoWeightsLoader(self)
        loaded_weights = loader.load_weights(weights)

        from vllm.transformers_utils.tokenizer import get_tokenizer
        tokenizer = get_tokenizer(
            model_config.tokenizer,
            revision=model_config.tokenizer_revision,
            tokenizer_mode=model_config.tokenizer_mode,
            trust_remote_code=model_config.trust_remote_code)

        a = tokenizer.convert_tokens_to_ids(tokens[0])
        b = tokenizer.convert_tokens_to_ids(tokens[1])
        weight = self.lm_head.weight.data[b].to(
            device) - self.lm_head.weight.data[a].to(device)
        self.score.weight.data.copy_(weight)

        del self.lm_head
        loaded_weights.add("classifier.weight")
        loaded_weights.discard("lm_head.weight")

@@ -172,6 +172,7 @@ _CROSS_ENCODER_MODELS = {
                                     "RobertaForSequenceClassification"),
    "ModernBertForSequenceClassification": ("modernbert",
                                            "ModernBertForSequenceClassification"),
    "Qwen3ForSequenceClassification": ("qwen3", "Qwen3ForSequenceClassification"),  # noqa: E501
}

_MULTIMODAL_MODELS = {