Improve the output precision of embedding models (#19092)

Author: wang.yuqi
Date: 2025-06-04 19:48:57 +08:00
Committed by: GitHub
Parent: 8711bc5e68
Commit: 35cf32df30

8 changed files with 69 additions and 28 deletions

View File

@@ -56,14 +56,10 @@ def correctness_test_embed_models(hf_runner,
                     max_model_len=None,
                     **vllm_extra_kwargs) as vllm_model:
         vllm_outputs = vllm_model.encode(example_prompts)
-    vllm_dtype = vllm_model.model.llm_engine.model_config.dtype
-    model_dtype = getattr(
-        vllm_model.model.llm_engine.model_config.hf_config, "torch_dtype",
-        vllm_dtype)

     with hf_runner(
             model_info.name,
-            dtype=model_dtype,
+            dtype="float32",
             is_sentence_transformer=True,
     ) as hf_model:
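Note: the HF reference is now pinned to float32 instead of mirroring the model's serving dtype, so the expected embeddings are computed at full precision. A standalone sketch (not part of the commit) of how much a float16 round trip perturbs a value:

import torch

# Round-tripping through float16 keeps only ~3 decimal digits, so a
# half-precision reference bakes rounding error into the expected values.
x = torch.tensor([0.1234567], dtype=torch.float32)
print(x.item())                 # ~0.1234567 (full float32 precision)
print(x.half().float().item())  # 0.12347412109375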

View File

@@ -7,7 +7,6 @@ import numpy as np
 import pytest

 from tests.models.utils import EmbedModelInfo
-from vllm.model_executor.model_loader.utils import set_default_torch_dtype

 # Most models on the STS12 task (See #17175):
 # - Model implementation and minor changes in tensor dtype
@@ -104,17 +103,18 @@ def mteb_test_embed_models(hf_runner,
                                              MTEB_EMBED_TASKS)
         vllm_dtype = vllm_model.model.llm_engine.model_config.dtype

-    with set_default_torch_dtype(vllm_dtype) and hf_runner(
-            model_info.name, is_sentence_transformer=True,
-            dtype=vllm_dtype) as hf_model:
+    with hf_runner(model_info.name,
+                   is_sentence_transformer=True,
+                   dtype="float32") as hf_model:
         if hf_model_callback is not None:
             hf_model_callback(hf_model)

         st_main_score = run_mteb_embed_task(hf_model, MTEB_EMBED_TASKS)
+        st_dtype = next(hf_model.model.parameters()).dtype

-    print("VLLM:", vllm_main_score)
-    print("SentenceTransformers:", st_main_score)
+    print("VLLM:", vllm_dtype, vllm_main_score)
+    print("SentenceTransformers:", st_dtype, st_main_score)
     print("Difference:", st_main_score - vllm_main_score)

     assert st_main_score == pytest.approx(vllm_main_score, abs=MTEB_EMBED_TOL)
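The final assertion uses pytest.approx with an absolute tolerance. A quick standalone illustration of the semantics (the tolerance value below is made up for the example, not the suite's actual MTEB_EMBED_TOL):

import pytest

MTEB_EMBED_TOL = 1e-4  # hypothetical value, for illustration only

# Passes as long as |st_main_score - vllm_main_score| <= MTEB_EMBED_TOL.
assert 0.82504 == pytest.approx(0.82498, abs=MTEB_EMBED_TOL)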

View File

@@ -11,27 +11,21 @@ MODELS = [
     ########## BertModel
     EmbedModelInfo("thenlper/gte-large",
                    architecture="BertModel",
-                   dtype="float32",
                    enable_test=True),
     EmbedModelInfo("thenlper/gte-base",
                    architecture="BertModel",
-                   dtype="float32",
                    enable_test=False),
     EmbedModelInfo("thenlper/gte-small",
                    architecture="BertModel",
-                   dtype="float32",
                    enable_test=False),
     EmbedModelInfo("thenlper/gte-large-zh",
                    architecture="BertModel",
-                   dtype="float32",
                    enable_test=False),
     EmbedModelInfo("thenlper/gte-base-zh",
                    architecture="BertModel",
-                   dtype="float32",
                    enable_test=False),
     EmbedModelInfo("thenlper/gte-small-zh",
                    architecture="BertModel",
-                   dtype="float32",
                    enable_test=False),
     ########### NewModel
     EmbedModelInfo("Alibaba-NLP/gte-multilingual-base",
@@ -46,7 +40,6 @@ MODELS = [
     ########### Qwen2ForCausalLM
     EmbedModelInfo("Alibaba-NLP/gte-Qwen2-1.5B-instruct",
                    architecture="Qwen2ForCausalLM",
-                   dtype="float32",
                    enable_test=True),
     ########## ModernBertModel
     EmbedModelInfo("Alibaba-NLP/gte-modernbert-base",

View File

@@ -0,0 +1,46 @@
+# SPDX-License-Identifier: Apache-2.0
+import pytest
+
+from ...utils import EmbedModelInfo
+from .embed_utils import correctness_test_embed_models
+from .mteb_utils import mteb_test_embed_models
+
+MODELS = [
+    ########## BertModel
+    EmbedModelInfo("intfloat/e5-small",
+                   architecture="BertModel",
+                   enable_test=True),
+    EmbedModelInfo("intfloat/e5-base",
+                   architecture="BertModel",
+                   enable_test=False),
+    EmbedModelInfo("intfloat/e5-large",
+                   architecture="BertModel",
+                   enable_test=False),
+    EmbedModelInfo("intfloat/multilingual-e5-small",
+                   architecture="BertModel",
+                   enable_test=False),
+    ########## XLMRobertaModel
+    EmbedModelInfo("intfloat/multilingual-e5-base",
+                   architecture="XLMRobertaModel",
+                   enable_test=True),
+    EmbedModelInfo("intfloat/multilingual-e5-large",
+                   architecture="XLMRobertaModel",
+                   enable_test=False),
+    EmbedModelInfo("intfloat/multilingual-e5-large-instruct",
+                   architecture="XLMRobertaModel",
+                   enable_test=False),
+]
+
+
+@pytest.mark.parametrize("model_info", MODELS)
+def test_embed_models_mteb(hf_runner, vllm_runner,
+                           model_info: EmbedModelInfo) -> None:
+    mteb_test_embed_models(hf_runner, vllm_runner, model_info)
+
+
+@pytest.mark.parametrize("model_info", MODELS)
+def test_embed_models_correctness(hf_runner, vllm_runner,
+                                  model_info: EmbedModelInfo,
+                                  example_prompts) -> None:
+    correctness_test_embed_models(hf_runner, vllm_runner, model_info,
+                                  example_prompts)

View File

@@ -32,8 +32,7 @@ TEXTS_2 = [

 EMBEDDING_MODELS = [
     EmbedModelInfo("jinaai/jina-embeddings-v3",
                    architecture="XLMRobertaModel",
-                   is_matryoshka=True,
-                   dtype="float32")
+                   is_matryoshka=True)
 ]

View File

@@ -9,18 +9,15 @@ from .mteb_utils import mteb_test_embed_models

 MODELS = [
     EmbedModelInfo("nomic-ai/nomic-embed-text-v1",
                    architecture="NomicBertModel",
-                   dtype="float32",
                    enable_test=True),
     EmbedModelInfo("nomic-ai/nomic-embed-text-v1.5",
                    architecture="NomicBertModel",
-                   dtype="float32",
                    enable_test=False),
     EmbedModelInfo("nomic-ai/CodeRankEmbed",
                    architecture="NomicBertModel",
                    enable_test=False),
     EmbedModelInfo("nomic-ai/nomic-embed-text-v2-moe",
                    architecture="NomicBertModel",
-                   dtype="float32",
                    enable_test=True)
 ]

View File

@@ -414,10 +414,15 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        return self.model(input_ids=input_ids,
-                          position_ids=positions,
-                          inputs_embeds=inputs_embeds,
-                          intermediate_tensors=intermediate_tensors)
+        hidden_states = self.model(input_ids=input_ids,
+                                   position_ids=positions,
+                                   inputs_embeds=inputs_embeds,
+                                   intermediate_tensors=intermediate_tensors)
+
+        # convert the embedding output to float32,
+        # otherwise precision will be lost significantly
+        hidden_states = hidden_states.to(torch.float32)
+        return hidden_states

     def pooler(
         self,
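Stated in isolation, the pattern this hunk introduces is: run the encoder in its serving dtype, then upcast the hidden states before anything downstream consumes them. A minimal sketch under that assumption (the names below are illustrative, not vLLM's internal API):

import torch
from torch import nn

def embed(encoder: nn.Module, input_ids: torch.Tensor) -> torch.Tensor:
    hidden_states = encoder(input_ids)               # may run in float16/bfloat16
    hidden_states = hidden_states.to(torch.float32)  # upcast before pooling
    pooled = hidden_states.mean(dim=1)               # illustrative mean pooling
    return nn.functional.normalize(pooled, dim=-1)   # normalize in float32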

View File

@@ -432,7 +432,12 @@ class BertWithRope(nn.Module, SupportsV0Only, SupportsQuant):
         else:
             hidden_states = self.embeddings(input_ids=input_ids,
                                             token_type_ids=token_type_ids)
-        return self.encoder(positions, hidden_states)
+        hidden_states = self.encoder(positions, hidden_states)
+
+        # convert the embedding output to float32,
+        # otherwise precision will be lost significantly
+        hidden_states = hidden_states.to(torch.float32)
+        return hidden_states

     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
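Note that both BertEmbeddingModel and BertWithRope apply the upcast at the end of forward(), before the pooler runs, so every pooling path sees float32 activations regardless of the dtype the model is served in. This is also why the per-model dtype="float32" overrides removed from the test files above are no longer needed.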