[New Model] Support BertForTokenClassification / Named Entity Recognition (NER) task (#24872)

Signed-off-by: wang.yuqi <noooop@126.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
wang.yuqi
2025-09-18 23:22:01 +08:00
committed by GitHub
parent 67244c86f0
commit 5f696c33b1
11 changed files with 257 additions and 2 deletions

View File

@ -554,6 +554,17 @@ If your model is not in the above list, we will try to automatically convert the
For process-supervised reward models such as `peiyi9979/math-shepherd-mistral-7b-prm`, the pooling config should be set explicitly,
e.g.: `--override-pooler-config '{"pooling_type": "STEP", "step_tag_id": 123, "returned_token_ids": [456, 789]}'`.
#### Token Classification
These models primarily support the [`LLM.encode`](./pooling_models.md#llmencode) API.
| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) |
|--------------|--------|-------------------|-----------------------------|-----------------------------------------|---------------------|
| `BertForTokenClassification` | bert-based | `boltuix/NeuroBERT-NER` (see note), etc. | | | ✅︎ |
!!! note
Named Entity Recognition (NER) usage, please refer to <gh-file:examples/offline_inference/pooling/ner.py>, <gh-file:examples/online_serving/pooling/ner.py>.
[](){ #supported-mm-models }
## List of Multimodal Language Models

View File

@ -26,8 +26,14 @@ python examples/offline_inference/pooling/embed_jina_embeddings_v3.py
python examples/offline_inference/pooling/embed_matryoshka_fy.py
```
## Named Entity Recognition (NER) usage
```bash
python examples/offline_inference/pooling/ner.py
```
## Qwen3 reranker usage
```bash
python qwen3_reranker.py
python examples/offline_inference/pooling/qwen3_reranker.py
```

View File

@ -0,0 +1,54 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from https://huggingface.co/boltuix/NeuroBERT-NER
from argparse import Namespace
from vllm import LLM, EngineArgs
from vllm.utils import FlexibleArgumentParser
def parse_args():
parser = FlexibleArgumentParser()
parser = EngineArgs.add_cli_args(parser)
# Set example specific arguments
parser.set_defaults(
model="boltuix/NeuroBERT-NER",
runner="pooling",
enforce_eager=True,
trust_remote_code=True,
)
return parser.parse_args()
def main(args: Namespace):
# Sample prompts.
prompts = [
"Barack Obama visited Microsoft headquarters in Seattle on January 2025."
]
# Create an LLM.
llm = LLM(**vars(args))
tokenizer = llm.get_tokenizer()
label_map = llm.llm_engine.vllm_config.model_config.hf_config.id2label
# Run inference
outputs = llm.encode(prompts)
for prompt, output in zip(prompts, outputs):
logits = output.outputs.data
predictions = logits.argmax(dim=-1)
# Map predictions to labels
tokens = tokenizer.convert_ids_to_tokens(output.prompt_token_ids)
labels = [label_map[p.item()] for p in predictions]
# Print results
for token, label in zip(tokens, labels):
if token not in tokenizer.all_special_tokens:
print(f"{token:15}{label}")
if __name__ == "__main__":
args = parse_args()
main(args)

View File

@ -12,6 +12,12 @@ python examples/online_serving/pooling/cohere_rerank_client.py
python examples/online_serving/pooling/jinaai_rerank_client.py
```
## Named Entity Recognition (NER) usage
```bash
python examples/online_serving/pooling/ner.py
```
## Openai chat embedding for multimodal usage
```bash

View File

@ -0,0 +1,71 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from https://huggingface.co/boltuix/NeuroBERT-NER
"""
Example online usage of Pooling API for Named Entity Recognition (NER).
Run `vllm serve <model> --runner pooling`
to start up the server in vLLM. e.g.
vllm serve boltuix/NeuroBERT-NER
"""
import argparse
import requests
import torch
def post_http_request(prompt: dict, api_url: str) -> requests.Response:
headers = {"User-Agent": "Test Client"}
response = requests.post(api_url, headers=headers, json=prompt)
return response
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--model", type=str, default="boltuix/NeuroBERT-NER")
return parser.parse_args()
def main(args):
from transformers import AutoConfig, AutoTokenizer
api_url = f"http://{args.host}:{args.port}/pooling"
model_name = args.model
# Load tokenizer and config
tokenizer = AutoTokenizer.from_pretrained(model_name)
config = AutoConfig.from_pretrained(model_name)
label_map = config.id2label
# Input text
text = "Barack Obama visited Microsoft headquarters in Seattle on January 2025."
prompt = {"model": model_name, "input": text}
pooling_response = post_http_request(prompt=prompt, api_url=api_url)
# Run inference
output = pooling_response.json()["data"][0]
logits = torch.tensor(output["data"])
predictions = logits.argmax(dim=-1)
inputs = tokenizer(text, return_tensors="pt")
# Map predictions to labels
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
labels = [label_map[p.item()] for p in predictions]
assert len(tokens) == len(predictions)
# Print results
for token, label in zip(tokens, labels):
if token not in tokenizer.all_special_tokens:
print(f"{token:15}{label}")
if __name__ == "__main__":
args = parse_args()
main(args)

View File

@ -0,0 +1,39 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from transformers import AutoModelForTokenClassification
from tests.models.utils import softmax
@pytest.mark.parametrize("model", ["boltuix/NeuroBERT-NER"])
# The float32 is required for this tiny model to pass the test.
@pytest.mark.parametrize("dtype", ["float"])
@torch.inference_mode
def test_models(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
) -> None:
with vllm_runner(model, max_model_len=None, dtype=dtype) as vllm_model:
vllm_outputs = vllm_model.encode(example_prompts)
with hf_runner(model,
dtype=dtype,
auto_cls=AutoModelForTokenClassification) as hf_model:
tokenizer = hf_model.tokenizer
hf_outputs = []
for prompt in example_prompts:
inputs = tokenizer([prompt], return_tensors="pt")
inputs = hf_model.wrap_device(inputs)
output = hf_model.model(**inputs)
hf_outputs.append(softmax(output.logits[0]))
# check logits difference
for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
hf_output = torch.tensor(hf_output).cpu().float()
vllm_output = torch.tensor(vllm_output).cpu().float()
assert torch.allclose(hf_output, vllm_output, 1e-2)

View File

@ -414,6 +414,7 @@ _SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS = {
# [Cross-encoder]
"BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2"), # noqa: E501
"BertForTokenClassification": _HfExamplesInfo("boltuix/NeuroBERT-NER"),
"GteNewForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-multilingual-reranker-base", # noqa: E501
trust_remote_code=True,
hf_overrides={

View File

@ -943,6 +943,10 @@ class LLM:
considered legacy and may be deprecated in the future. You should
instead pass them via the `inputs` parameter.
"""
if self.supported_tasks == ["encode"] and pooling_task is None:
pooling_task = "encode"
if pooling_task is None:
if "embed" in self.supported_tasks:
pooling_task = "embed"

View File

@ -611,3 +611,55 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding,
positions=positions,
inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors)
@default_pooling_type("ALL")
class BertForTokenClassification(nn.Module):
is_pooling_model = True
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
self.head_dtype = vllm_config.model_config.head_dtype
self.num_labels = config.num_labels
self.bert = BertModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "bert"),
embedding_class=BertEmbedding)
self.classifier = nn.Linear(config.hidden_size,
config.num_labels,
dtype=self.head_dtype)
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.pooler = DispatchPooler({
"encode":
Pooler.for_encode(pooler_config),
})
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self)
loaded_params = loader.load_weights(weights)
return loaded_params
def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if token_type_ids is not None:
assert self.bert.config.vocab_size < (1 << TOKEN_TYPE_SHIFT)
assert input_ids is not None
_encode_token_type_ids(input_ids, token_type_ids)
hidden_states = self.bert(input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors)
hidden_states = hidden_states.to(self.head_dtype)
return self.classifier(hidden_states)

View File

@ -193,6 +193,7 @@ _EMBEDDING_MODELS = {
_CROSS_ENCODER_MODELS = {
"BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
"BertForTokenClassification": ("bert", "BertForTokenClassification"),
"GteNewForSequenceClassification": ("bert_with_rope",
"GteNewForSequenceClassification"),
"ModernBertForSequenceClassification": ("modernbert",

View File

@ -720,6 +720,15 @@ class FlexAttentionImpl(AttentionImpl):
(query, key, value),
)
query = query[:, :, :num_actual_tokens, :]
if ((key_tensor.size(-2) > num_actual_tokens)
or (value_tensor.size(-2) > num_actual_tokens)):
# In the encoder-only model with torch.compile,
# qkv might be padded, which might cause exception.
# see: https://github.com/vllm-project/vllm/pull/24872#discussion_r2353252290
key_tensor = key_tensor[:, :, :num_actual_tokens, :]
value_tensor = value_tensor[:, :, :num_actual_tokens, :]
else:
assert self.attn_type == AttentionType.DECODER
key_cache, value_cache = kv_cache.unbind(0)
@ -744,7 +753,8 @@ class FlexAttentionImpl(AttentionImpl):
(query, key_cache, value_cache),
)
query = query[:, :, :num_actual_tokens, :]
query = query[:, :, :num_actual_tokens, :]
# Doesn't work for now -> constraint violation
# torch._dynamo.try_mark_dynamic(query, 2)