[Bugfix][v1] Fix step pooler implementation and step pooling usage in v1 (#19956)

Signed-off-by: Isotr0py <2037008807@qq.com>
Isotr0py
2025-06-24 02:38:06 +08:00
committed by GitHub
parent 68aaeb3749
commit 61f4fc5dc6
14 changed files with 164 additions and 40 deletions

View File

@@ -1027,13 +1027,13 @@ class VllmRunner:
req_outputs = self.model.classify(prompts)
return [req_output.outputs.probs for req_output in req_outputs]
def encode(self,
prompts: list[str],
images: Optional[PromptImageInput] = None,
videos: Optional[PromptVideoInput] = None,
audios: Optional[PromptAudioInput] = None,
*args,
**kwargs) -> list[list[float]]:
def embed(self,
prompts: list[str],
images: Optional[PromptImageInput] = None,
videos: Optional[PromptVideoInput] = None,
audios: Optional[PromptAudioInput] = None,
*args,
**kwargs) -> list[list[float]]:
inputs = self.get_inputs(prompts,
images=images,
videos=videos,
@@ -1042,6 +1042,10 @@ class VllmRunner:
req_outputs = self.model.embed(inputs, *args, **kwargs)
return [req_output.outputs.embedding for req_output in req_outputs]
def encode(self, prompts: list[str]) -> list[list[float]]:
req_outputs = self.model.encode(prompts)
return [req_output.outputs.data for req_output in req_outputs]
def score(
self,
text_1: Union[str, list[str]],
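
For reference, the split above gives the test runner two distinct pooling entry points: embed() wraps LLM.embed() and returns one embedding vector per prompt, while encode() is redefined to wrap LLM.encode() and return the raw pooled data (for example per-step reward scores). A minimal usage sketch, not part of the diff; the model ids are illustrative:

def test_embed_vs_encode(vllm_runner):
    # embed(): embedding model, one vector per prompt (illustrative model id)
    with vllm_runner("BAAI/bge-base-en-v1.5", task="embed") as vllm_model:
        embeddings = vllm_model.embed(["vLLM is a fast inference engine."])
        assert len(embeddings) == 1

    # encode(): raw pooled output, e.g. step rewards from a process reward model
    with vllm_runner("Qwen/Qwen2.5-Math-PRM-7B", max_model_len=1024) as vllm_model:
        rewards = vllm_model.encode(["First step.<extra_0>Second step.<extra_0>"])
        assert len(rewards) == 1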

View File

@@ -29,8 +29,8 @@ def test_model_loading_with_params(vllm_runner):
revision=REVISION,
dtype="float16",
max_model_len=MAX_MODEL_LEN) as vllm_model:
output = vllm_model.encode("Write a short story about a robot that"
" dreams for the first time.\n")
output = vllm_model.embed("Write a short story about a robot that"
" dreams for the first time.\n")
model_config = vllm_model.model.llm_engine.model_config
model_tokenizer = vllm_model.model.llm_engine.tokenizer
@@ -67,8 +67,8 @@ def test_roberta_model_loading_with_params(vllm_runner):
revision=REVISION_ROBERTA,
dtype="float16",
max_model_len=MAX_MODEL_LEN) as vllm_model:
output = vllm_model.encode("Write a short story about a robot that"
" dreams for the first time.\n")
output = vllm_model.embed("Write a short story about a robot that"
" dreams for the first time.\n")
model_config = vllm_model.model.llm_engine.model_config
model_tokenizer = vllm_model.model.llm_engine.tokenizer
@@ -105,8 +105,8 @@ def test_facebook_roberta_model_loading_with_params(vllm_runner):
with vllm_runner(model_name=model_name,
dtype="float16",
max_model_len=MAX_MODEL_LEN) as vllm_model:
output = vllm_model.encode("Write a short story about a robot that"
" dreams for the first time.\n")
output = vllm_model.embed("Write a short story about a robot that"
" dreams for the first time.\n")
model_tokenizer = vllm_model.model.llm_engine.tokenizer
assert model_tokenizer.tokenizer_id == model_name

View File

@@ -55,7 +55,7 @@ def correctness_test_embed_models(hf_runner,
task="embed",
max_model_len=None,
**vllm_extra_kwargs) as vllm_model:
vllm_outputs = vllm_model.encode(example_prompts)
vllm_outputs = vllm_model.embed(example_prompts)
with hf_runner(
model_info.name,

View File

@@ -89,7 +89,7 @@ def test_models(
task="embed",
max_model_len=512,
**vllm_extra_kwargs) as vllm_model:
vllm_outputs = vllm_model.encode(example_prompts)
vllm_outputs = vllm_model.embed(example_prompts)
check_embeddings_close(
embeddings_0_lst=hf_outputs,

View File

@@ -98,11 +98,11 @@ def test_matryoshka(
if dimensions not in matryoshka_dimensions:
with pytest.raises(ValueError):
vllm_model.encode(
vllm_model.embed(
example_prompts,
pooling_params=PoolingParams(dimensions=dimensions))
else:
vllm_outputs = vllm_model.encode(
vllm_outputs = vllm_model.embed(
example_prompts,
pooling_params=PoolingParams(dimensions=dimensions))

View File

@@ -0,0 +1,104 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
import torch.nn.functional as F
from transformers import AutoModel
from vllm.platforms import current_platform
from ....conftest import HfRunner
@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass
@pytest.fixture
def math_step_prompts():
# ruff: noqa: E501
data = {
"system":
"Please reason step by step, and put your final answer within \\boxed{}. ",
"query":
"Sue lives in a fun neighborhood. One weekend, the neighbors decided to play a prank on Sue. On Friday morning, the neighbors placed 18 pink plastic flamingos out on Sue's front yard. On Saturday morning, the neighbors took back one third of the flamingos, painted them white, and put these newly painted white flamingos back out on Sue's front yard. Then, on Sunday morning, they added another 18 pink plastic flamingos to the collection. At noon on Sunday, how many more pink plastic flamingos were out than white plastic flamingos?",
"response": [
"To find out how many more pink plastic flamingos were out than white plastic flamingos at noon on Sunday, we can break down the problem into steps. First, on Friday, the neighbors start with 18 pink plastic flamingos.",
"On Saturday, they take back one third of the flamingos. Since there were 18 flamingos, (1/3 \\times 18 = 6) flamingos are taken back. So, they have (18 - 6 = 12) flamingos left in their possession. Then, they paint these 6 flamingos white and put them back out on Sue's front yard. Now, Sue has the original 12 pink flamingos plus the 6 new white ones. Thus, by the end of Saturday, Sue has (12 + 6 = 18) pink flamingos and 6 white flamingos.",
"On Sunday, the neighbors add another 18 pink plastic flamingos to Sue's front yard. By the end of Sunday morning, Sue has (18 + 18 = 36) pink flamingos and still 6 white flamingos.",
"To find the difference, subtract the number of white flamingos from the number of pink flamingos: (36 - 6 = 30). Therefore, at noon on Sunday, there were 30 more pink plastic flamingos out than white plastic flamingos. The answer is (\\boxed{30}).",
],
}
answer = "<extra_0>".join(data['response']) + "<extra_0>"
prompt = f"<im_start>system\n{data['system']}<im_end>\n<im_start>user\n{data['query']}<im_end>\n<im_start>assistant\n{answer}<im_end><|endoftext|>"
return [prompt]
def step_reward_patch_hf_model(hf_model: HfRunner):
# Patch the hf_runner to use the step reward function
def make_step_rewards(logits: torch.Tensor,
token_masks: torch.Tensor) -> list[list[float]]:
probabilities = F.softmax(logits, dim=-1)
probabilities = probabilities * token_masks.unsqueeze(-1)
all_scores_res: list[list[float]] = []
for i in range(probabilities.size(0)):
sample = probabilities[i] # seq_len, num_labels
positive_probs = sample[sample != 0].view(-1, 2)
non_zero_elements_list = positive_probs.cpu().tolist()
all_scores_res.append(non_zero_elements_list)
return all_scores_res
def reward(prompts: list[str]) -> list[list[float]]:
input_ids = hf_model.tokenizer(prompts, return_tensors="pt").input_ids
input_ids = hf_model.wrap_device(input_ids)
outputs = hf_model.model(input_ids=input_ids)
step_sep_id = hf_model.tokenizer.encode("<extra_0>")[0]
token_masks = (input_ids == step_sep_id)
return make_step_rewards(outputs[0], token_masks)
hf_model.reward = reward # type: ignore[attr-defined]
return hf_model
@pytest.mark.parametrize(
"model",
[
pytest.param("Qwen/Qwen2.5-Math-PRM-7B",
marks=[pytest.mark.core_model, pytest.mark.cpu_model]),
],
)
@pytest.mark.parametrize("dtype", ["half"])
def test_prm_models(
hf_runner,
vllm_runner,
math_step_prompts,
model: str,
dtype: str,
monkeypatch,
) -> None:
if current_platform.is_rocm():
# ROCm Triton FA does not currently support sliding window attention
# switch to use ROCm CK FA backend
monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False")
with vllm_runner(model, max_model_len=1024, dtype=dtype) as vllm_model:
vllm_outputs = vllm_model.encode(math_step_prompts)
with hf_runner(model, dtype=dtype, auto_cls=AutoModel) as hf_model:
hf_model = step_reward_patch_hf_model(hf_model)
hf_outputs = hf_model.reward(math_step_prompts)
# check logits difference
for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
hf_output = torch.tensor(hf_output)
vllm_output = torch.tensor(vllm_output)
assert torch.allclose(hf_output, vllm_output, 1e-2)
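
To make the HF reference above concrete, here is a toy walk-through of make_step_rewards with assumed shapes: one sequence of five tokens, two labels, and step separators at two positions.

import torch
import torch.nn.functional as F

logits = torch.randn(1, 5, 2)      # (batch, seq_len, num_labels), assumed toy values
token_masks = torch.tensor([[False, False, True, False, True]])  # <extra_0> positions

probabilities = F.softmax(logits, dim=-1) * token_masks.unsqueeze(-1)
sample = probabilities[0]                          # (seq_len, num_labels)
positive_probs = sample[sample != 0].view(-1, 2)   # one (p0, p1) row per step
print(positive_probs.tolist())                     # two rows, one per <extra_0>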

View File

@@ -98,7 +98,7 @@ def _run_test(
max_model_len=8192) as vllm_model:
tokenizer = vllm_model.model.get_tokenizer()
texts = [
# this is necessary because vllm_model.encode will not apply any
# this is necessary because vllm_model.embed will not apply any
# templating to the prompt, and therefore lacks an image_pad
# token unless one is inserted beforehand (the (28,28) image
# above is converted to an image pad token by the chat template).
@@ -109,7 +109,7 @@ def _run_test(
# vllm will replace the pad token with the actual image,
# which may be a placeholder image, later.
]
vllm_outputs = vllm_model.encode(texts, images=input_images)
vllm_outputs = vllm_model.embed(texts, images=input_images)
hf_outputs = []
with hf_runner(model,

View File

@@ -68,7 +68,7 @@ def _run_test(
dtype=dtype,
max_model_len=4096,
enforce_eager=True) as vllm_model:
vllm_outputs = vllm_model.encode(input_texts, images=input_images)
vllm_outputs = vllm_model.embed(input_texts, images=input_images)
with hf_runner(model, dtype=dtype,
auto_cls=AutoModelForImageTextToText) as hf_model:

View File

@@ -46,7 +46,7 @@ def _run_test(
# will hurt multiprocessing backend with fork method (the default method).
with vllm_runner(model, task="embed", dtype=dtype,
enforce_eager=True) as vllm_model:
vllm_outputs = vllm_model.encode(input_texts, images=input_images)
vllm_outputs = vllm_model.embed(input_texts, images=input_images)
# use eager mode for hf runner, since phi3_v didn't work with flash_attn
hf_model_kwargs = {"_attn_implementation": "eager"}

View File

@@ -161,7 +161,7 @@ def test_4bit_bnb_embedding_model(
dtype=dtype,
gpu_memory_utilization=0.5,
quantization="bitsandbytes") as vllm_model:
vllm_outputs = vllm_model.encode(example_prompts)
vllm_outputs = vllm_model.embed(example_prompts)
check_embeddings_close(
embeddings_0_lst=hf_outputs,
embeddings_1_lst=vllm_outputs,

View File

@@ -239,25 +239,24 @@ class StepPool(SimplePooler):
prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
prompt_token_ids = self.get_prompt_token_ids(pooling_metadata)
pooled_data: list[torch.Tensor] = []
pooled_data_lst = list[torch.Tensor]()
if isinstance(hidden_states, list):
for req_state, prompt_len in zip(hidden_states, prompt_lens):
assert prompt_len == req_state.shape[0], \
"partial prefill not supported with mean pooling"
pooled_data = hidden_states
"partial prefill not supported with step pooling"
pooled_data_lst = hidden_states
else:
offset = 0
for prompt_len in prompt_lens:
pooled_data_i = hidden_states[offset:offset + prompt_len]
offset += prompt_len
pooled_data.append(pooled_data_i)
pooled_data_lst.append(pooled_data_i)
pooled_data = []
pooled_data = list[torch.Tensor]()
returned_token_ids = self.returned_token_ids
step_tag_id = self.step_tag_id
for data, token_id in zip(pooled_data, prompt_token_ids):
for data, token_id in zip(pooled_data_lst, prompt_token_ids):
if returned_token_ids is not None and len(returned_token_ids) > 0:
data = data[:, returned_token_ids]
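
The renames above are the substance of the fix: previously pooled_data was reset to an empty list right before the for data, token_id in zip(pooled_data, prompt_token_ids) loop, so the per-request states collected earlier were discarded; keeping them in pooled_data_lst lets the loop actually consume them. Below is a hedged sketch of the per-request selection the loop performs; the returned_token_ids slice is taken from this hunk, while the step_tag_id filter is an assumption about the part of the loop the hunk truncates.

from typing import Optional

import torch

def select_step_scores(data: torch.Tensor,
                       token_id: torch.Tensor,
                       returned_token_ids: Optional[list[int]],
                       step_tag_id: Optional[int]) -> torch.Tensor:
    # data: (prompt_len, num_labels) pooled states for one request;
    # token_id: that request's prompt token ids (1-D tensor).
    if returned_token_ids is not None and len(returned_token_ids) > 0:
        data = data[:, returned_token_ids]      # keep only the labels of interest
    if step_tag_id is not None:
        data = data[token_id == step_tag_id]    # keep only step-separator rows
    return data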

View File

@@ -489,6 +489,12 @@ def supports_cross_encoding(
return is_pooling_model(model) and _supports_cross_encoding(model)
def has_step_pooler(model: Union[type[object], object]) -> bool:
"""Check if the model uses step pooler."""
return is_pooling_model(model) and any(
type(module).__name__ == "StepPool" for module in model.modules())
class SupportsQuant:
"""The interface required for all models that support quantization."""

View File

@@ -59,14 +59,15 @@ class CachedRequestState:
class InputBatch:
def __init__(
self,
max_num_reqs: int,
max_model_len: int,
max_num_batched_tokens: int,
device: torch.device,
pin_memory: bool,
vocab_size: int,
block_sizes: list[int], # The block_size of each kv cache group
self,
max_num_reqs: int,
max_model_len: int,
max_num_batched_tokens: int,
device: torch.device,
pin_memory: bool,
vocab_size: int,
block_sizes: list[int], # The block_size of each kv cache group
logits_processing_needs_token_ids: bool = False,
):
self.max_num_reqs = max_num_reqs
self.max_model_len = max_model_len
@@ -74,6 +75,8 @@ class InputBatch:
self.device = device
self.pin_memory = pin_memory
self.vocab_size = vocab_size
self.logits_processing_needs_token_ids = (
logits_processing_needs_token_ids)
self._req_ids: list[Optional[str]] = []
self.req_id_to_index: dict[str, int] = {}
@@ -579,9 +582,14 @@ class InputBatch:
copy_slice(self.repetition_penalties_cpu_tensor,
self.repetition_penalties, num_reqs)
# The prompt tokens are used only for applying penalties during
# the sampling process. Hence copy these tensors only when
# there are requests which need penalties to be applied.
needs_prompt_token_ids = (not self.no_penalties or
(self.num_reqs > 0
and self.logits_processing_needs_token_ids))
if needs_prompt_token_ids:
# The prompt tokens are used only for applying penalties or
# step pooling during the sampling/pooling process.
# Hence copy these tensors only when there are requests which
# need penalties/step_pooler to be applied.
prompt_token_ids = self._make_prompt_token_ids_tensor()
else:
prompt_token_ids = None
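
The new keyword defaults to False, so existing InputBatch callers are unchanged; a step-pooling model opts in after its weights are loaded, which makes the copy above materialize prompt_token_ids even when no penalties are requested. A hedged construction sketch; the module path and sizes are assumptions:

import torch

from vllm.v1.worker.gpu_input_batch import InputBatch  # assumed module path

batch = InputBatch(
    max_num_reqs=8,
    max_model_len=1024,
    max_num_batched_tokens=2048,
    device=torch.device("cpu"),
    pin_memory=False,
    vocab_size=32000,             # illustrative vocabulary size
    block_sizes=[16],             # block_size of each kv cache group
    logits_processing_needs_token_ids=True,
)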

View File

@@ -33,6 +33,7 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
from vllm.model_executor.models.interfaces import has_step_pooler
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.multimodal.utils import group_mm_inputs_by_modality
@@ -1708,6 +1709,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
model_loader.load_weights(self.model,
model_config=self.model_config)
if has_step_pooler(self.model):
self.input_batch.logits_processing_needs_token_ids = True
if self.lora_config:
self.model = self.load_lora_model(self.model,
self.model_config,