mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[Feature] Add support for naver/splade-v3 (BERT-based sparse embedding model) (#26339)
Signed-off-by: gjgjos <gjgjos@naver.com> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
This commit is contained in:
122
tests/models/language/pooling/test_splade_sparse_pooler.py
Normal file
122
tests/models/language/pooling/test_splade_sparse_pooler.py
Normal file
@ -0,0 +1,122 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import types
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.model_executor.models.bert import (
|
||||
BertMLMHead,
|
||||
SPLADESparsePooler,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
# 1) Functional test: SPLADE formula correctness (no HF download needed)
|
||||
# ---------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize("B,T,H,V", [(2, 3, 5, 7)])
|
||||
def test_splade_pooler_matches_reference_formula(B, T, H, V):
|
||||
"""Ensure SPLADESparsePooler forward() matches the mathematical formula:
|
||||
log1p(relu(logits)) -> max over sequence length (after masking)."""
|
||||
torch.manual_seed(0)
|
||||
|
||||
# Prepare [B] sequences of shape [T, H]
|
||||
hs_list = [torch.randn(T, H) for _ in range(B)]
|
||||
|
||||
# Simulate PoolingMetadata (only required fields)
|
||||
prompt_lens = [T, T - 1]
|
||||
token_ids = torch.tensor(
|
||||
[
|
||||
[101, 5, 102], # Batch 0: [CLS], token, [SEP]
|
||||
[101, 6, 6], # Batch 1: [CLS], token, token (last token ignored)
|
||||
],
|
||||
dtype=torch.long,
|
||||
)
|
||||
meta = types.SimpleNamespace(prompt_lens=prompt_lens, prompt_token_ids=token_ids)
|
||||
|
||||
# MLM head (prefer BertMLMHead, fallback to Linear if unavailable)
|
||||
try:
|
||||
mlm_head = BertMLMHead(hidden_size=H, vocab_size=V, layer_norm_eps=1e-12)
|
||||
except Exception:
|
||||
mlm_head = nn.Linear(H, V, bias=True)
|
||||
|
||||
# Forward pass through SPLADE pooler
|
||||
pooler = SPLADESparsePooler(mlm_head=mlm_head, pooling="max", remove_cls_sep=True)
|
||||
pooled = pooler(hidden_states=hs_list, pooling_metadata=meta) # list of [V]
|
||||
|
||||
# Basic output checks
|
||||
assert isinstance(pooled, list) and len(pooled) == B
|
||||
for vec in pooled:
|
||||
assert vec.shape == (V,)
|
||||
assert torch.isfinite(vec).all()
|
||||
assert (vec >= 0).all(), "SPLADE outputs must be non-negative."
|
||||
|
||||
# Reference implementation for comparison
|
||||
def ref_one(hs: torch.Tensor, L: int, tid_row: torch.Tensor) -> torch.Tensor:
|
||||
keep = torch.ones(L, dtype=torch.bool)
|
||||
if L > 0 and tid_row[0].item() == 101: # remove CLS
|
||||
keep[0] = False
|
||||
if L > 0 and tid_row[L - 1].item() == 102: # remove SEP
|
||||
keep[L - 1] = False
|
||||
|
||||
valid = hs[:L][keep[:L]]
|
||||
if valid.numel() == 0:
|
||||
return torch.zeros(V, dtype=torch.float32)
|
||||
|
||||
logits = mlm_head(valid) # [L', V]
|
||||
scores = torch.log1p(torch.relu(logits)) # [L', V]
|
||||
return scores.max(dim=0).values.to(torch.float32)
|
||||
|
||||
torch.testing.assert_close(
|
||||
pooled[0],
|
||||
ref_one(hs_list[0], prompt_lens[0], token_ids[0]),
|
||||
rtol=1e-4,
|
||||
atol=1e-4,
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
pooled[1],
|
||||
ref_one(hs_list[1], prompt_lens[1], token_ids[1]),
|
||||
rtol=1e-4,
|
||||
atol=1e-4,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
# 2) Integration smoke test: end-to-end embedding path wiring
|
||||
# ---------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.cpu_model
|
||||
def test_bert_splade_sparse_embed_smoke(vllm_runner, monkeypatch):
|
||||
"""Ensure BertSpladeSparseEmbeddingModel loads and produces sparse embeddings."""
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
MODEL_ID = "hf-internal-testing/tiny-random-bert"
|
||||
hf_overrides = {"architectures": ["BertSpladeSparseEmbeddingModel"]}
|
||||
|
||||
# Enforce CPU-only execution (optional)
|
||||
monkeypatch.setenv("CUDA_VISIBLE_DEVICES", "")
|
||||
monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False")
|
||||
|
||||
tok = AutoTokenizer.from_pretrained(MODEL_ID)
|
||||
vocab_size = tok.vocab_size
|
||||
|
||||
# The embed path should route through SPLADESparsePooler
|
||||
with vllm_runner(
|
||||
MODEL_ID,
|
||||
runner="pooling",
|
||||
max_model_len=64,
|
||||
hf_overrides=hf_overrides,
|
||||
) as vm:
|
||||
outs = vm.embed(["hello world", "splade sparse test"])
|
||||
|
||||
# Basic sanity checks
|
||||
assert len(outs) == 2
|
||||
assert outs[0].shape[0] == vocab_size
|
||||
assert outs[1].shape[0] == vocab_size
|
||||
assert np.isfinite(outs[0]).all() and (outs[0] >= 0).all()
|
||||
assert np.isfinite(outs[1]).all() and (outs[1] >= 0).all()
|
@ -486,6 +486,9 @@ _EMBEDDING_EXAMPLE_MODELS = {
|
||||
"RobertaModel": _HfExamplesInfo("sentence-transformers/stsb-roberta-base-v2"),
|
||||
"RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1"),
|
||||
"XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-small"),
|
||||
"BertSpladeSparseEmbeddingModel": _HfExamplesInfo(
|
||||
"naver/splade-v3", is_available_online=False
|
||||
),
|
||||
# [Multimodal]
|
||||
"CLIPModel": _HfExamplesInfo("openai/clip-vit-base-patch32"),
|
||||
"LlavaNextForConditionalGeneration": _HfExamplesInfo("royokong/e5-v"),
|
||||
|
@ -572,6 +572,220 @@ def _decode_token_type_ids(input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return token_type_ids
|
||||
|
||||
|
||||
class BertMLMHead(nn.Module):
|
||||
def __init__(
|
||||
self, hidden_size: int, vocab_size: int, layer_norm_eps: float = 1e-12
|
||||
):
|
||||
super().__init__()
|
||||
self.dense = nn.Linear(hidden_size, hidden_size)
|
||||
self.activation = nn.GELU()
|
||||
self.layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
|
||||
self.decoder = nn.Linear(hidden_size, vocab_size, bias=True)
|
||||
|
||||
def tie_weights_with_embeddings(self, embeddings_weight: torch.Tensor):
|
||||
self.decoder.weight = embeddings_weight
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
x = self.dense(hidden_states)
|
||||
x = self.activation(x)
|
||||
x = self.layer_norm(x)
|
||||
logits = self.decoder(x)
|
||||
return logits
|
||||
|
||||
|
||||
class SPLADESparsePooler(Pooler):
|
||||
"""
|
||||
SPLADE sparse pooling:
|
||||
logits = mlm_head(hidden_states)
|
||||
-> log1p(relu(logits))
|
||||
-> (max|sum over L)
|
||||
-> [V]
|
||||
|
||||
Padding is masked with an attention mask,
|
||||
[CLS]/[SEP] is removed (selected),
|
||||
and then pooled.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mlm_head: nn.Module,
|
||||
cls_token_id: Optional[int] = 101,
|
||||
sep_token_id: Optional[int] = 102,
|
||||
pooling: str = "max",
|
||||
remove_cls_sep: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
assert pooling in ("max", "sum")
|
||||
self.mlm_head = mlm_head
|
||||
self.cls_token_id = cls_token_id
|
||||
self.sep_token_id = sep_token_id
|
||||
self.pooling = pooling
|
||||
self.remove_cls_sep = remove_cls_sep
|
||||
|
||||
def get_supported_tasks(self) -> Set[PoolingTask]:
|
||||
return {"embed"}
|
||||
|
||||
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
|
||||
return PoolingParamsUpdate(requires_token_ids=True)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> torch.Tensor:
|
||||
assert isinstance(hidden_states, torch.Tensor) and hidden_states.dim() == 2
|
||||
|
||||
lens_tensor: torch.Tensor = pooling_metadata.prompt_lens
|
||||
lens: list[int] = lens_tensor.tolist()
|
||||
B: int = len(lens)
|
||||
|
||||
token_ids = pooling_metadata.prompt_token_ids
|
||||
offset = 0
|
||||
pooled_list: list[torch.Tensor] = []
|
||||
|
||||
for i in range(B):
|
||||
L = int(lens[i])
|
||||
hs = hidden_states[offset : offset + L]
|
||||
|
||||
start_idx = 0
|
||||
end_idx = L
|
||||
if self.remove_cls_sep and token_ids is not None:
|
||||
if (
|
||||
self.cls_token_id is not None
|
||||
and token_ids[i, 0].item() == self.cls_token_id
|
||||
):
|
||||
start_idx = 1
|
||||
if (
|
||||
self.sep_token_id is not None
|
||||
and token_ids[i, L - 1].item() == self.sep_token_id
|
||||
):
|
||||
end_idx = max(start_idx, L - 1)
|
||||
|
||||
if end_idx <= start_idx:
|
||||
V = int(self.mlm_head.decoder.out_features)
|
||||
pooled_list.append(hs.new_zeros((V,)))
|
||||
offset += L
|
||||
continue
|
||||
|
||||
logits_i = self.mlm_head(hs[start_idx:end_idx])
|
||||
scores_i = torch.log1p(torch.relu(logits_i))
|
||||
|
||||
if self.pooling == "sum":
|
||||
pooled_i = scores_i.sum(dim=0)
|
||||
else: # "max"
|
||||
pooled_i = scores_i.max(dim=0).values
|
||||
|
||||
pooled_list.append(pooled_i.contiguous())
|
||||
offset += L
|
||||
|
||||
return torch.stack(pooled_list, dim=0).contiguous()
|
||||
|
||||
|
||||
@default_pooling_type("CLS")
|
||||
class BertSpladeSparseEmbeddingModel(BertEmbeddingModel):
|
||||
"""
|
||||
BertEmbeddingModel + SPLADE sparse embedding.
|
||||
- Make logits by self.mlm_head
|
||||
- pooler: SPLADESparsePooler(mlm_head...)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, *, vllm_config: VllmConfig, prefix: str = "", splade_pooling: str = "max"
|
||||
):
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||
cfg = vllm_config.model_config.hf_config
|
||||
|
||||
# MLM head
|
||||
self.mlm_head = BertMLMHead(
|
||||
hidden_size=cfg.hidden_size,
|
||||
vocab_size=cfg.vocab_size,
|
||||
layer_norm_eps=getattr(cfg, "layer_norm_eps", 1e-12),
|
||||
)
|
||||
|
||||
self._splade_pooling = splade_pooling
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
assert pooler_config is not None
|
||||
self.pooler = self._build_pooler(pooler_config)
|
||||
|
||||
def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
|
||||
cfg = self.model.config
|
||||
|
||||
if not hasattr(self, "mlm_head"):
|
||||
self.mlm_head = BertMLMHead(
|
||||
hidden_size=cfg.hidden_size,
|
||||
vocab_size=cfg.vocab_size,
|
||||
layer_norm_eps=getattr(cfg, "layer_norm_eps", 1e-12),
|
||||
)
|
||||
|
||||
pooling_mode = getattr(self, "_splade_pooling", "max")
|
||||
|
||||
cls_id = getattr(cfg, "cls_token_id", None)
|
||||
sep_id = getattr(cfg, "sep_token_id", None)
|
||||
|
||||
return DispatchPooler(
|
||||
{
|
||||
"encode": Pooler.for_encode(pooler_config),
|
||||
"embed": SPLADESparsePooler(
|
||||
mlm_head=self.mlm_head,
|
||||
cls_token_id=cls_id,
|
||||
sep_token_id=sep_id,
|
||||
pooling=pooling_mode, # "max" or "sum"
|
||||
remove_cls_sep=True,
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||
if not hasattr(self, "mlm_head"):
|
||||
cfg = self.model.config
|
||||
self.mlm_head = BertMLMHead(
|
||||
hidden_size=cfg.hidden_size,
|
||||
vocab_size=cfg.vocab_size,
|
||||
layer_norm_eps=getattr(cfg, "layer_norm_eps", 1e-12),
|
||||
)
|
||||
|
||||
def _strip(name: str) -> str:
|
||||
for p in ("model.", "bert."):
|
||||
if name.startswith(p):
|
||||
name = name[len(p) :]
|
||||
return name
|
||||
|
||||
weights_list = list(weights)
|
||||
model_side: list[tuple[str, torch.Tensor]] = []
|
||||
mlm_side: list[tuple[str, torch.Tensor]] = []
|
||||
|
||||
for k, w in weights_list:
|
||||
name = _strip(k)
|
||||
if name.startswith("cls.predictions."):
|
||||
mlm_side.append((name, w))
|
||||
else:
|
||||
model_side.append((name, w))
|
||||
|
||||
loaded: set[str] = set()
|
||||
loaded_model = self.model.load_weights(model_side)
|
||||
loaded.update({"model." + n for n in loaded_model})
|
||||
|
||||
if mlm_side:
|
||||
name_map = {
|
||||
"cls.predictions.transform.dense.weight": "mlm_head.dense.weight",
|
||||
"cls.predictions.transform.dense.bias": "mlm_head.dense.bias",
|
||||
("cls.predictions.transform.LayerNorm.weight"): (
|
||||
"mlm_head.layer_norm.weight"
|
||||
),
|
||||
("cls.predictions.transform.LayerNorm.bias"): (
|
||||
"mlm_head.layer_norm.bias"
|
||||
),
|
||||
"cls.predictions.decoder.weight": "mlm_head.decoder.weight",
|
||||
"cls.predictions.decoder.bias": "mlm_head.decoder.bias",
|
||||
}
|
||||
remapped = [(name_map[n], w) for n, w in mlm_side if n in name_map]
|
||||
if remapped:
|
||||
loaded_mlm = AutoWeightsLoader(self).load_weights(remapped)
|
||||
loaded.update(loaded_mlm)
|
||||
|
||||
return loaded
|
||||
|
||||
|
||||
@default_pooling_type("CLS")
|
||||
class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, SupportsQuant):
|
||||
"""A model that uses Bert to provide embedding functionalities.
|
||||
|
@ -172,6 +172,7 @@ _TEXT_GENERATION_MODELS = {
|
||||
_EMBEDDING_MODELS = {
|
||||
# [Text-only]
|
||||
"BertModel": ("bert", "BertEmbeddingModel"),
|
||||
"BertSpladeSparseEmbeddingModel": ("bert", "BertSpladeSparseEmbeddingModel"),
|
||||
"DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
|
||||
"Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
|
||||
"Gemma3TextModel": ("gemma3", "Gemma3Model"),
|
||||
|
Reference in New Issue
Block a user