[New Model] Support BertForTokenClassification / Named Entity Recognition (NER) task (#24872)
Signed-off-by: wang.yuqi <noooop@126.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
@@ -554,6 +554,17 @@ If your model is not in the above list, we will try to automatically convert the

For process-supervised reward models such as `peiyi9979/math-shepherd-mistral-7b-prm`, the pooling config should be set explicitly,
e.g.: `--override-pooler-config '{"pooling_type": "STEP", "step_tag_id": 123, "returned_token_ids": [456, 789]}'`.
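As a sketch, serving such a PRM with this override could look like the following (the JSON values are the illustrative IDs from the line above, not real token IDs for this model):

```bash
vllm serve peiyi9979/math-shepherd-mistral-7b-prm \
  --override-pooler-config '{"pooling_type": "STEP", "step_tag_id": 123, "returned_token_ids": [456, 789]}'
```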
#### Token Classification

These models primarily support the [`LLM.encode`](./pooling_models.md#llmencode) API.

| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) |
|--------------|--------|-------------------|-----------------------------|-----------------------------------------|---------------------|
| `BertForTokenClassification` | bert-based | `boltuix/NeuroBERT-NER` (see note), etc. | | | ✅︎ |

!!! note
    For Named Entity Recognition (NER) usage, please refer to <gh-file:examples/offline_inference/pooling/ner.py> and <gh-file:examples/online_serving/pooling/ner.py>.

[](){ #supported-mm-models }

## List of Multimodal Language Models
@@ -26,8 +26,14 @@ python examples/offline_inference/pooling/embed_jina_embeddings_v3.py
python examples/offline_inference/pooling/embed_matryoshka_fy.py
```

## Named Entity Recognition (NER) usage

```bash
python examples/offline_inference/pooling/ner.py
```

## Qwen3 reranker usage

```bash
-python qwen3_reranker.py
+python examples/offline_inference/pooling/qwen3_reranker.py
```
examples/offline_inference/pooling/ner.py (new file, 54 lines)
@@ -0,0 +1,54 @@

```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from https://huggingface.co/boltuix/NeuroBERT-NER

from argparse import Namespace

from vllm import LLM, EngineArgs
from vllm.utils import FlexibleArgumentParser


def parse_args():
    parser = FlexibleArgumentParser()
    parser = EngineArgs.add_cli_args(parser)
    # Set example specific arguments
    parser.set_defaults(
        model="boltuix/NeuroBERT-NER",
        runner="pooling",
        enforce_eager=True,
        trust_remote_code=True,
    )
    return parser.parse_args()


def main(args: Namespace):
    # Sample prompts.
    prompts = [
        "Barack Obama visited Microsoft headquarters in Seattle on January 2025."
    ]

    # Create an LLM.
    llm = LLM(**vars(args))
    tokenizer = llm.get_tokenizer()
    label_map = llm.llm_engine.vllm_config.model_config.hf_config.id2label

    # Run inference
    outputs = llm.encode(prompts)

    for prompt, output in zip(prompts, outputs):
        logits = output.outputs.data
        predictions = logits.argmax(dim=-1)

        # Map predictions to labels
        tokens = tokenizer.convert_ids_to_tokens(output.prompt_token_ids)
        labels = [label_map[p.item()] for p in predictions]

        # Print results
        for token, label in zip(tokens, labels):
            if token not in tokenizer.all_special_tokens:
                print(f"{token:15} → {label}")


if __name__ == "__main__":
    args = parse_args()
    main(args)
```
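To try the example, run it as below; the printed tags are illustrative (the real ones come from the model's `id2label` mapping):

```bash
python examples/offline_inference/pooling/ner.py
# Illustrative output, one tag per non-special token, e.g.:
#   barack          → B-PER
#   obama           → I-PER
#   microsoft       → B-ORG
#   seattle         → B-LOC
```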
@@ -12,6 +12,12 @@ python examples/online_serving/pooling/cohere_rerank_client.py
python examples/online_serving/pooling/jinaai_rerank_client.py
```

## Named Entity Recognition (NER) usage

```bash
python examples/online_serving/pooling/ner.py
```

## Openai chat embedding for multimodal usage

```bash
examples/online_serving/pooling/ner.py (new file, 71 lines)
@@ -0,0 +1,71 @@

```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from https://huggingface.co/boltuix/NeuroBERT-NER

"""
Example online usage of Pooling API for Named Entity Recognition (NER).

Run `vllm serve <model> --runner pooling`
to start up the server in vLLM, e.g.

    vllm serve boltuix/NeuroBERT-NER
"""

import argparse

import requests
import torch


def post_http_request(prompt: dict, api_url: str) -> requests.Response:
    headers = {"User-Agent": "Test Client"}
    response = requests.post(api_url, headers=headers, json=prompt)
    return response


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--model", type=str, default="boltuix/NeuroBERT-NER")

    return parser.parse_args()


def main(args):
    from transformers import AutoConfig, AutoTokenizer

    api_url = f"http://{args.host}:{args.port}/pooling"
    model_name = args.model

    # Load tokenizer and config
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    config = AutoConfig.from_pretrained(model_name)
    label_map = config.id2label

    # Input text
    text = "Barack Obama visited Microsoft headquarters in Seattle on January 2025."
    prompt = {"model": model_name, "input": text}

    pooling_response = post_http_request(prompt=prompt, api_url=api_url)

    # Run inference
    output = pooling_response.json()["data"][0]
    logits = torch.tensor(output["data"])
    predictions = logits.argmax(dim=-1)
    inputs = tokenizer(text, return_tensors="pt")

    # Map predictions to labels
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    labels = [label_map[p.item()] for p in predictions]
    assert len(tokens) == len(predictions)

    # Print results
    for token, label in zip(tokens, labels):
        if token not in tokenizer.all_special_tokens:
            print(f"{token:15} → {label}")


if __name__ == "__main__":
    args = parse_args()
    main(args)
```
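The script posts to the server's `/pooling` route with a JSON body containing `model` and `input`; an equivalent raw request, as a sketch against a locally started server, would be:

```bash
curl -X POST http://localhost:8000/pooling \
  -H "Content-Type: application/json" \
  -d '{"model": "boltuix/NeuroBERT-NER", "input": "Barack Obama visited Microsoft headquarters in Seattle on January 2025."}'
```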
tests/models/language/pooling/test_token_classification.py (new file, 39 lines)
@@ -0,0 +1,39 @@

```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from transformers import AutoModelForTokenClassification

from tests.models.utils import softmax


@pytest.mark.parametrize("model", ["boltuix/NeuroBERT-NER"])
# The float32 is required for this tiny model to pass the test.
@pytest.mark.parametrize("dtype", ["float"])
@torch.inference_mode
def test_models(
    hf_runner,
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
) -> None:
    with vllm_runner(model, max_model_len=None, dtype=dtype) as vllm_model:
        vllm_outputs = vllm_model.encode(example_prompts)

    with hf_runner(model,
                   dtype=dtype,
                   auto_cls=AutoModelForTokenClassification) as hf_model:
        tokenizer = hf_model.tokenizer
        hf_outputs = []
        for prompt in example_prompts:
            inputs = tokenizer([prompt], return_tensors="pt")
            inputs = hf_model.wrap_device(inputs)
            output = hf_model.model(**inputs)
            hf_outputs.append(softmax(output.logits[0]))

    # check logits difference
    for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
        hf_output = torch.tensor(hf_output).cpu().float()
        vllm_output = torch.tensor(vllm_output).cpu().float()
        assert torch.allclose(hf_output, vllm_output, 1e-2)
```
@@ -414,6 +414,7 @@ _SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS = {

```python
    # [Cross-encoder]
    "BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2"),  # noqa: E501
    "BertForTokenClassification": _HfExamplesInfo("boltuix/NeuroBERT-NER"),
    "GteNewForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-multilingual-reranker-base",  # noqa: E501
                                                       trust_remote_code=True,
                                                       hf_overrides={
```
@@ -943,6 +943,10 @@ class LLM:

```python
            considered legacy and may be deprecated in the future. You should
            instead pass them via the `inputs` parameter.
        """

        if self.supported_tasks == ["encode"] and pooling_task is None:
            pooling_task = "encode"

        if pooling_task is None:
            if "embed" in self.supported_tasks:
                pooling_task = "embed"
```
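The effect of the new branch, sketched under the assumption of a local vLLM install: for a model whose only supported pooling task is "encode" (as with the token-classification models added here), callers no longer have to pass `pooling_task` explicitly.

```python
from vllm import LLM

llm = LLM(model="boltuix/NeuroBERT-NER", runner="pooling")
# supported_tasks == ["encode"], so pooling_task now defaults to "encode".
outputs = llm.encode(["Barack Obama visited Seattle."])
```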
@@ -611,3 +611,55 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding,

```python
                                  positions=positions,
                                  inputs_embeds=inputs_embeds,
                                  intermediate_tensors=intermediate_tensors)


@default_pooling_type("ALL")
class BertForTokenClassification(nn.Module):
    is_pooling_model = True

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.head_dtype = vllm_config.model_config.head_dtype
        self.num_labels = config.num_labels
        self.bert = BertModel(vllm_config=vllm_config,
                              prefix=maybe_prefix(prefix, "bert"),
                              embedding_class=BertEmbedding)
        self.classifier = nn.Linear(config.hidden_size,
                                    config.num_labels,
                                    dtype=self.head_dtype)

        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None

        self.pooler = DispatchPooler({
            "encode": Pooler.for_encode(pooler_config),
        })

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        loader = AutoWeightsLoader(self)
        loaded_params = loader.load_weights(weights)
        return loaded_params

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:

        if token_type_ids is not None:
            assert self.bert.config.vocab_size < (1 << TOKEN_TYPE_SHIFT)
            assert input_ids is not None
            _encode_token_type_ids(input_ids, token_type_ids)

        hidden_states = self.bert(input_ids=input_ids,
                                  positions=positions,
                                  inputs_embeds=inputs_embeds,
                                  intermediate_tensors=intermediate_tensors)

        hidden_states = hidden_states.to(self.head_dtype)
        return self.classifier(hidden_states)
```
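Since the class is decorated with `@default_pooling_type("ALL")`, pooling keeps one hidden state per token, and the linear head turns each of them into `num_labels` logits. A standalone shape sketch with illustrative sizes (not the model's real config):

```python
import torch
import torch.nn as nn

seq_len, hidden_size, num_labels = 16, 256, 9       # illustrative sizes
hidden_states = torch.randn(seq_len, hidden_size)   # "ALL" pooling: one vector per token
classifier = nn.Linear(hidden_size, num_labels)
logits = classifier(hidden_states)                  # shape [seq_len, num_labels]
labels = logits.argmax(dim=-1)                      # one label id per token
assert labels.shape == (seq_len,)
```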
@@ -193,6 +193,7 @@ _EMBEDDING_MODELS = {

```python
_CROSS_ENCODER_MODELS = {
    "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
    "BertForTokenClassification": ("bert", "BertForTokenClassification"),
    "GteNewForSequenceClassification": ("bert_with_rope",
                                        "GteNewForSequenceClassification"),
    "ModernBertForSequenceClassification": ("modernbert",
```
@@ -720,6 +720,15 @@ class FlexAttentionImpl(AttentionImpl):

```python
                (query, key, value),
            )

            query = query[:, :, :num_actual_tokens, :]
            if ((key_tensor.size(-2) > num_actual_tokens)
                    or (value_tensor.size(-2) > num_actual_tokens)):
                # In encoder-only models with torch.compile,
                # q/k/v may be padded, which can cause an exception.
                # See: https://github.com/vllm-project/vllm/pull/24872#discussion_r2353252290
                key_tensor = key_tensor[:, :, :num_actual_tokens, :]
                value_tensor = value_tensor[:, :, :num_actual_tokens, :]

        else:
            assert self.attn_type == AttentionType.DECODER
            key_cache, value_cache = kv_cache.unbind(0)
```

@@ -744,7 +753,8 @@

```python
                (query, key_cache, value_cache),
            )

            query = query[:, :, :num_actual_tokens, :]

            # Doesn't work for now -> constraint violation
            # torch._dynamo.try_mark_dynamic(query, 2)
```
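A minimal standalone sketch of the trimming guard (illustrative shapes; `num_actual_tokens` counts the real tokens before compile-time padding):

```python
import torch

num_actual_tokens = 5
# [batch, heads, padded_seq, head_dim]; padded_seq exceeds the actual token count
key_tensor = torch.randn(1, 8, 8, 64)
value_tensor = torch.randn(1, 8, 8, 64)
if (key_tensor.size(-2) > num_actual_tokens
        or value_tensor.size(-2) > num_actual_tokens):
    key_tensor = key_tensor[:, :, :num_actual_tokens, :]
    value_tensor = value_tensor[:, :, :num_actual_tokens, :]
assert key_tensor.size(-2) == value_tensor.size(-2) == num_actual_tokens
```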