[Misc] Simplify PoolerOutput and move to v1/outputs (#25629)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung
Date: 2025-09-25 14:47:03 +08:00
Committer: GitHub
Parent: a676e668ee
Commit: 755ed7b05b
6 changed files with 34 additions and 82 deletions
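In short: the commit collapses the msgspec-based wrapper types into a plain type alias living in vllm/v1/outputs.py. A rough before/after sketch (the Old* classes below are simplified stand-ins for the deleted msgspec.Struct definitions shown later in this diff):

from typing import Union

import torch

# Before: two layers of wrapping (stubbed here without msgspec for brevity).
class PoolingSequenceGroupOutput:
    def __init__(self, data: torch.Tensor) -> None:
        self.data = data

class OldPoolerOutput:
    def __init__(self, outputs: list[PoolingSequenceGroupOutput]) -> None:
        self.outputs = outputs

# After: the alias this commit adds to vllm/v1/outputs.py.
PoolerOutput = Union[torch.Tensor, list[torch.Tensor]]

old = OldPoolerOutput([PoolingSequenceGroupOutput(torch.randn(8))])
new: PoolerOutput = [torch.randn(8)]
print(old.outputs[0].data.shape)  # two unwrapping steps before
print(new[0].shape)               # plain list indexing after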

View File

@@ -15,10 +15,10 @@ from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
-from vllm.sequence import ExecuteModelRequest, PoolerOutput
+from vllm.sequence import ExecuteModelRequest
from vllm.tasks import SupportedTask
from vllm.utils import make_async
-from vllm.v1.outputs import SamplerOutput
+from vllm.v1.outputs import PoolerOutput, SamplerOutput
from vllm.worker.worker_base import WorkerBase
logger = init_logger(__name__)

View File

@@ -16,9 +16,9 @@ from vllm.config import ModelConfig, PoolerConfig, get_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.models.adapters import _load_st_projector
from vllm.pooling_params import PoolingParams
-from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
from vllm.tasks import PoolingTask
-from vllm.utils import current_stream, resolve_obj_by_qualname
+from vllm.utils import resolve_obj_by_qualname
+from vllm.v1.outputs import PoolerOutput
from vllm.v1.pool.metadata import PoolingCursor, PoolingMetadata
logger = init_logger(__name__)
@@ -190,19 +190,6 @@ def get_cross_encoder_activation_function(config: PretrainedConfig):
return PoolerClassify()
-def build_output(
-        all_data: Union[torch.Tensor, list[torch.Tensor]], ) -> PoolerOutput:
-    # Pooling models D2H & synchronize occurs here
-    if isinstance(all_data, list):
-        all_data = [d.to("cpu", non_blocking=True) for d in all_data]
-    else:
-        all_data = all_data.to("cpu", non_blocking=True)
-    current_stream().synchronize()
-
-    all_outputs = [PoolingSequenceGroupOutput(data) for data in all_data]
-    return PoolerOutput(outputs=all_outputs)

class PoolingMethod(nn.Module, ABC):

    @staticmethod
@@ -556,7 +543,7 @@ class SimplePooler(Pooler):
    ) -> PoolerOutput:
        pooled_data = self.pooling(hidden_states, pooling_metadata)
        pooled_data = self.head(pooled_data, pooling_metadata)
-        return build_output(pooled_data)
+        return pooled_data
class StepPooler(Pooler):
@@ -607,7 +594,7 @@ class StepPooler(Pooler):
    ) -> PoolerOutput:
        pooled_data = self.extract_states(hidden_states, pooling_metadata)
        pooled_data = self.head(pooled_data, pooling_metadata)
-        return build_output(pooled_data)
+        return pooled_data
class ClassifierPooler(Pooler):
@@ -678,7 +665,7 @@ class ClassifierPooler(Pooler):
        ]

        # scores shape: [batchsize, num_labels]
-        return build_output(scores)
+        return scores
class DispatchPooler(Pooler):
@@ -708,7 +695,7 @@ class DispatchPooler(Pooler):
    ) -> PoolerOutput:
        poolers_by_task = self.poolers_by_task

-        outputs = list[PoolingSequenceGroupOutput]()
+        outputs = list[torch.Tensor]()

        offset = 0
        for task, group in groupby(get_tasks(pooling_metadata)):
            if not (pooler := poolers_by_task.get(task)):
@@ -722,10 +709,10 @@ class DispatchPooler(Pooler):
                pooling_metadata[offset:offset + num_items],
            )
-            outputs.extend(group_output.outputs)
+            outputs.extend(group_output)
            offset += num_items

-        return PoolerOutput(outputs)
+        return outputs

    def extra_repr(self) -> str:
        s = f"supported_task={self.get_supported_tasks()}"
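Net effect of this file's changes: poolers now return device-resident tensors directly, and the D2H copy that build_output used to perform moves into the model runner (last file in this commit). A runnable sketch of the new DispatchPooler-style accumulation, with plain callables standing in for real Pooler objects and PoolingMetadata omitted:

from itertools import groupby
from typing import Callable

import torch

def dispatch(
    poolers_by_task: dict[str, Callable[[torch.Tensor], list[torch.Tensor]]],
    tasks: list[str],
    hidden_states: torch.Tensor,
) -> list[torch.Tensor]:
    outputs: list[torch.Tensor] = []  # was list[PoolingSequenceGroupOutput]
    offset = 0
    for task, group in groupby(tasks):
        num_items = len(list(group))
        group_output = poolers_by_task[task](
            hidden_states[offset:offset + num_items])
        outputs.extend(group_output)  # no more group_output.outputs
        offset += num_items
    return outputs  # no PoolerOutput(...) wrapper

embeds = dispatch({"embed": lambda h: list(h)},  # identity stand-in pooler
                  ["embed", "embed"], torch.randn(2, 8))
print(len(embeds), embeds[0].shape)  # 2 torch.Size([8])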

View File

@@ -12,12 +12,12 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.pooler import (DispatchPooler, Pooler,
                                               PoolerHead, PoolerNormalize,
                                               PoolingParamsUpdate,
-                                               build_output, get_prompt_lens,
+                                               get_prompt_lens,
                                                get_prompt_token_ids)
from vllm.model_executor.models.llama import LlamaForCausalLM
-from vllm.sequence import PoolerOutput
from vllm.tasks import PoolingTask
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
+from vllm.v1.outputs import PoolerOutput
from vllm.v1.pool.metadata import PoolingMetadata
from .interfaces_base import default_pooling_type
@@ -212,7 +212,7 @@ class GritLMPooler(Pooler):
    ) -> PoolerOutput:
        pooled_data = self.pooling(hidden_states, pooling_metadata)
        pooled_data = self.head(pooled_data, pooling_metadata)
-        return build_output(pooled_data)
+        return pooled_data
@default_pooling_type("MEAN")

View File

@@ -11,7 +11,6 @@ if TYPE_CHECKING:
    from vllm.v1.worker.kv_connector_model_runner_mixin import (
        KVConnectorOutput)
else:
-    LoRARequest = Any
    KVConnectorOutput = Any

VLLM_TOKEN_ID_ARRAY_TYPE = "l"
@@ -48,29 +47,6 @@ class RequestMetrics:
    model_execute_time: Optional[float] = None

-class PoolingSequenceGroupOutput(
-        msgspec.Struct,
-        omit_defaults=True,  # type: ignore[call-arg]
-        array_like=True,  # type: ignore[call-arg]
-):
-    """The model output associated with a pooling sequence group."""
-
-    # Annotated as Any to be compatible with msgspec
-    # The actual type is in SequenceGroup.pooled_data
-    data: Any
-
-    def get_data_nbytes(self) -> int:
-        data: torch.Tensor = self.data
-        return data.nbytes
-
-    def __repr__(self) -> str:
-        return f"PoolingSequenceGroupOutput(data={self.data}"
-
-    def __eq__(self, other: object) -> bool:
-        if not isinstance(other, PoolingSequenceGroupOutput):
-            raise NotImplementedError()
-        return self.data == other.data

# cannot use msgspec.Struct here because Dynamo does not support it
@dataclass
class IntermediateTensors:
@@ -119,30 +95,6 @@ class IntermediateTensors:
        return f"IntermediateTensors(tensors={self.tensors})"

-class PoolerOutput(
-        msgspec.Struct,
-        omit_defaults=True,  # type: ignore[call-arg]
-        array_like=True):  # type: ignore[call-arg]
-    """The output from a pooling operation in the pooling model."""
-
-    outputs: list[PoolingSequenceGroupOutput]
-
-    def get_data_nbytes(self) -> int:
-        return sum(o.get_data_nbytes() for o in self.outputs)
-
-    def __getitem__(self, idx: int) -> PoolingSequenceGroupOutput:
-        return self.outputs[idx]
-
-    def __setitem__(self, idx: int, value: PoolingSequenceGroupOutput):
-        self.outputs[idx] = value
-
-    def __len__(self):
-        return len(self.outputs)
-
-    def __eq__(self, other: object):
-        return isinstance(other,
-                          self.__class__) and self.outputs == other.outputs

class ExecuteModelRequest(
        msgspec.Struct,
        array_like=True,  # type: ignore[call-arg]
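Nothing replaces these deleted helpers one-for-one; built-in list and tensor operations cover them, as the model-runner diff below shows for nbytes. A quick illustration with made-up tensors:

import torch

outputs: list[torch.Tensor] = [torch.randn(4), torch.randn(9)]

total_nbytes = sum(o.nbytes for o in outputs)  # was PoolerOutput.get_data_nbytes()
first = outputs[0]                             # was __getitem__
count = len(outputs)                           # was __len__
print(total_nbytes, count, first.shape)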

View File

@@ -3,7 +3,7 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
-from typing import TYPE_CHECKING, NamedTuple, Optional
+from typing import TYPE_CHECKING, NamedTuple, Optional, Union
import torch
@@ -65,6 +65,11 @@ class LogprobsTensors(NamedTuple):
)
+# [num_reqs, <dynamic>]
+# The shape of each element depends on the pooler used
+PoolerOutput = Union[torch.Tensor, list[torch.Tensor]]

@dataclass
class SamplerOutput:
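The new alias admits either a single batched tensor or a ragged list, which is what the [num_reqs, <dynamic>] comment alludes to. An illustrative sketch; the concrete pooler semantics named in the comments are assumptions, not taken from this diff:

from typing import Union

import torch

PoolerOutput = Union[torch.Tensor, list[torch.Tensor]]

# One tensor when every request yields the same shape,
# e.g. fixed-size sentence embeddings (assumed example):
stacked: PoolerOutput = torch.randn(4, 768)  # [num_reqs, hidden]

# A list when shapes differ per request,
# e.g. token-level pooling over prompts of different lengths:
ragged: PoolerOutput = [torch.randn(n, 768) for n in (3, 7, 5)]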

View File

@@ -52,13 +52,14 @@ from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargsItem,
from vllm.multimodal.utils import group_mm_kwargs_by_modality
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingType
-from vllm.sequence import IntermediateTensors, PoolerOutput
+from vllm.sequence import IntermediateTensors
from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                        GiB_bytes, cdiv, check_use_alibi, get_dtype_size,
                        is_pin_memory_available,
                        length_from_prompt_token_ids_or_embeds, round_up,
                        supports_dynamo)
+from vllm.utils.jsontree import json_map_leaves
from vllm.v1.attention.backends.flash_attn import AttentionMetadata
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
from vllm.v1.attention.backends.utils import (
@@ -79,7 +80,7 @@ from vllm.v1.kv_cache_interface import (AttentionSpec,
# yapf: enable
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
                             DraftTokenIds, LogprobsLists, LogprobsTensors,
-                             ModelRunnerOutput, SamplerOutput)
+                             ModelRunnerOutput, PoolerOutput, SamplerOutput)
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs
from vllm.v1.sample.metadata import SamplingMetadata
@@ -1823,15 +1824,22 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                device=hidden_states.device)

        seq_lens_cpu = self.seq_lens.cpu[:self.input_batch.num_reqs]

-        # Pooling models D2H & synchronize occurs in pooler.py:build_output
-        raw_pooler_output = self.model.pooler(
-            hidden_states=hidden_states, pooling_metadata=pooling_metadata)
+        model = cast(VllmModelForPooling, self.model)
+        raw_pooler_output: PoolerOutput = model.pooler(
+            hidden_states=hidden_states,
+            pooling_metadata=pooling_metadata,
+        )
+        raw_pooler_output = json_map_leaves(
+            lambda x: x.to("cpu", non_blocking=True),
+            raw_pooler_output,
+        )
+        self._sync_device()

        pooler_output: list[Optional[torch.Tensor]] = []
        for raw_output, seq_len, prompt_len in zip(
                raw_pooler_output, seq_lens_cpu, pooling_metadata.prompt_lens):
-            output = raw_output.data if seq_len == prompt_len else None
+            output = raw_output if seq_len == prompt_len else None
            pooler_output.append(output)

        return ModelRunnerOutput(
@@ -3233,7 +3241,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        for task in self.get_supported_pooling_tasks():
            # Run a full batch with each task to ensure none of them OOMs
            output = self._dummy_pooler_run_task(hidden_states, task)
-            output_size[task] = output.get_data_nbytes()
+            output_size[task] = sum(o.nbytes for o in output)
            del output  # Allow GC

        max_task = max(output_size.items(), key=lambda x: x[1])[0]
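The json_map_leaves call plus the single _sync_device() above is the runner-side replacement for build_output's D2H copy. A self-contained sketch of that pattern, with map_leaves as a stand-in for vllm.utils.jsontree.json_map_leaves (assumed here to apply a function to every tensor leaf):

import torch

def map_leaves(fn, value):
    # Stand-in for vllm.utils.jsontree.json_map_leaves, handling only
    # the Tensor / list[Tensor] cases that PoolerOutput can take.
    if isinstance(value, list):
        return [map_leaves(fn, v) for v in value]
    return fn(value)

def pooler_output_to_cpu(output):
    # Queue non-blocking copies for every leaf, then synchronize once,
    # mirroring the single self._sync_device() call in the diff above.
    moved = map_leaves(lambda t: t.to("cpu", non_blocking=True), output)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return moved

print(pooler_output_to_cpu([torch.randn(3), torch.randn(5)])[0].device)  # cpu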